Files
orchard/backend/app/auth.py
Mondo Diaz 2acb0aefb2 Add configurable admin password via environment variable
- Add ORCHARD_ADMIN_PASSWORD env var to set initial admin password
- When set, admin user created without forced password change
- Add AWS Secrets Manager support for stage/prod deployments
- Add .env file support for local docker development
- Add Helm chart auth config (adminPassword, existingSecret, secretsManager)

Environments configured:
- Local: .env file or defaults to changeme123
- Feature/dev: orchardtest123 (hardcoded in values-dev.yaml)
- Stage: AWS Secrets Manager (orchard-stage-creds)
- Prod: AWS Secrets Manager (orch-prod-creds)
2026-01-27 17:23:08 +00:00

1224 lines
38 KiB
Python

"""Authentication service for Orchard.
Handles password hashing, session management, API key operations, and JWT validation.
"""
import hashlib
import secrets
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from .models import User, Session as UserSession, APIKey
from .config import get_settings
logger = logging.getLogger(__name__)

# Password hashing context. The bcrypt cost factor is pinned to 12 explicitly
# rather than inherited from passlib's default, so the work factor cannot
# change silently with a library upgrade. deprecated="auto" marks every scheme
# other than the default as deprecated.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
    bcrypt__rounds=12,
)

# Prefix that identifies Orchard API keys (checked on Bearer tokens).
API_KEY_PREFIX = "orch_"

# Lifetime of a login session, in hours.
SESSION_DURATION_HOURS = 24

# Minimum accepted password length, enforced by validate_password_strength.
MIN_PASSWORD_LENGTH = 8
class PasswordTooShortError(ValueError):
    """Signal that a candidate password is below the minimum length.

    Derives from ValueError, so handlers catching ValueError also catch it.
    """
def validate_password_strength(password: str) -> None:
    """Check ``password`` against the password policy.

    The only rule is a minimum length of MIN_PASSWORD_LENGTH characters.
    An empty or None password is rejected too.

    Raises:
        PasswordTooShortError: if the password is missing or too short.
    """
    long_enough = bool(password) and len(password) >= MIN_PASSWORD_LENGTH
    if long_enough:
        return
    raise PasswordTooShortError(
        f"Password must be at least {MIN_PASSWORD_LENGTH} characters"
    )
def hash_password(password: str) -> str:
    """Return the bcrypt hash of ``password`` produced by ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when ``plain_password`` matches ``hashed_password``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def hash_token(token: str) -> str:
    """Return the hex SHA-256 digest of ``token``.

    Session tokens and API keys are persisted only in this hashed form.
    """
    digest = hashlib.sha256()
    digest.update(token.encode())
    return digest.hexdigest()
def generate_session_token() -> str:
    """Return a new random URL-safe session token built from 32 random bytes."""
    token = secrets.token_urlsafe(32)
    return token
def generate_api_key() -> str:
    """Return a freshly generated API key.

    Format: API_KEY_PREFIX followed by 64 hex characters (32 random bytes).
    """
    return API_KEY_PREFIX + secrets.token_hex(32)
class AuthService:
    """Authentication service for user management and session handling.

    Wraps a SQLAlchemy session. Mutating methods commit before returning. If a
    commit fails, the session is rolled back and the error is re-raised, so
    the session stays usable for the caller.
    """

    # bcrypt hash that is verified when the user is unknown or has no local
    # password. Every login attempt then costs the same bcrypt work.
    _DUMMY_PASSWORD_HASH = (
        "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYA1vQ9S9sXa"
    )

    def __init__(self, db: Session):
        self.db = db

    def _commit(self) -> None:
        """Commit the session; on failure roll back and re-raise."""
        try:
            self.db.commit()
        except Exception:
            # A failed commit leaves the SQLAlchemy session unusable until it
            # is rolled back. Restore it before propagating the error.
            self.db.rollback()
            raise

    # --- User Operations ---

    def create_user(
        self,
        username: str,
        password: Optional[str] = None,
        email: Optional[str] = None,
        is_admin: bool = False,
        must_change_password: bool = False,
    ) -> User:
        """Create and persist a new user account.

        Args:
            username: Login name for the account.
            password: Plaintext password to hash. None creates an account
                without a local password (used for SSO/JWT-provisioned users).
            email: Optional email address.
            is_admin: Whether the account gets system admin rights.
            must_change_password: Force a password change on next login.

        Returns:
            The persisted and refreshed User.

        Raises:
            Any database error from the commit (for example a duplicate
            username, if the schema enforces uniqueness). The session is
            rolled back before re-raising.
        """
        user = User(
            username=username,
            password_hash=hash_password(password) if password else None,
            email=email,
            is_admin=is_admin,
            must_change_password=must_change_password,
        )
        self.db.add(user)
        self._commit()
        self.db.refresh(user)
        return user

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the user with ``username``, or None."""
        return self.db.query(User).filter(User.username == username).first()

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Return the user with primary key ``user_id``, or None."""
        return self.db.query(User).filter(User.id == user_id).first()

    def authenticate_user(self, username: str, password: str) -> Optional[User]:
        """Check a username/password pair.

        Returns the user on success. Returns None if the user is unknown, has
        no local password (OIDC-only), is inactive, or the password is wrong.
        """
        user = self.get_user_by_username(username)
        stored_hash = user.password_hash if user and user.password_hash else None

        # Always run exactly one bcrypt verification, even for unknown users
        # or users without a local password. Response timing then does not
        # reveal which usernames exist.
        password_valid = verify_password(
            password, stored_hash or self._DUMMY_PASSWORD_HASH
        )

        if stored_hash is None:
            return None  # unknown user, or OIDC-only account
        if not user.is_active or not password_valid:
            return None
        return user

    def change_password(self, user: User, new_password: str) -> None:
        """Set a new password chosen by the user and revoke all their sessions.

        Raises:
            PasswordTooShortError: if ``new_password`` fails the policy.
        """
        validate_password_strength(new_password)
        user.password_hash = hash_password(new_password)
        user.must_change_password = False
        # Revoke sessions in the same transaction as the password update. A
        # failure then cannot leave the new password active while old
        # sessions survive.
        self._delete_sessions_query(user)
        self._commit()

    def update_last_login(self, user: User) -> None:
        """Record the current UTC time as the user's last login."""
        user.last_login = datetime.now(timezone.utc)
        self._commit()

    def list_users(self, include_inactive: bool = False) -> list[User]:
        """List users ordered by username, active ones only by default."""
        query = self.db.query(User)
        if not include_inactive:
            query = query.filter(User.is_active.is_(True))
        return query.order_by(User.username).all()

    def set_user_active(self, user: User, is_active: bool) -> None:
        """Enable or disable a user account."""
        user.is_active = is_active
        self._commit()

    def set_user_admin(self, user: User, is_admin: bool) -> None:
        """Grant or revoke system admin privileges."""
        user.is_admin = is_admin
        self._commit()

    def reset_user_password(self, user: User, new_password: str) -> None:
        """Admin reset of a user's password.

        Flags the account for a password change on next login and revokes all
        of the user's sessions.

        Raises:
            PasswordTooShortError: if ``new_password`` fails the policy.
        """
        validate_password_strength(new_password)
        user.password_hash = hash_password(new_password)
        user.must_change_password = True
        # Same transaction as the password update (see change_password).
        self._delete_sessions_query(user)
        self._commit()

    # --- Session Operations ---

    def create_session(
        self,
        user: User,
        user_agent: Optional[str] = None,
        ip_address: Optional[str] = None,
    ) -> tuple[UserSession, str]:
        """Create a session for ``user`` lasting SESSION_DURATION_HOURS.

        Returns:
            ``(session, token)``. ``token`` is the plaintext value to send to
            the client. Only its SHA-256 hash is stored, so it cannot be
            recovered later.
        """
        token = generate_session_token()
        session = UserSession(
            user_id=user.id,
            token_hash=hash_token(token),
            expires_at=datetime.now(timezone.utc)
            + timedelta(hours=SESSION_DURATION_HOURS),
            user_agent=user_agent,
            ip_address=ip_address,
        )
        self.db.add(session)
        self._commit()
        self.db.refresh(session)
        return session, token

    def get_session_by_token(self, token: str) -> Optional[UserSession]:
        """Look up a live session by its plaintext token.

        An expired session is deleted and None is returned. A live session has
        its ``last_accessed`` timestamp refreshed.
        """
        session = (
            self.db.query(UserSession)
            .filter(UserSession.token_hash == hash_token(token))
            .first()
        )
        if not session:
            return None

        now = datetime.now(timezone.utc)
        # NOTE(review): this comparison assumes expires_at comes back
        # timezone-aware; a naive value would raise TypeError. Confirm the
        # column type.
        if session.expires_at < now:
            self.db.delete(session)
            self._commit()
            return None

        session.last_accessed = now
        self._commit()
        return session

    def delete_session(self, session: UserSession) -> None:
        """Delete a single session (logout)."""
        self.db.delete(session)
        self._commit()

    def _delete_sessions_query(self, user: User) -> int:
        """Issue a bulk delete of ``user``'s sessions without committing."""
        return (
            self.db.query(UserSession)
            .filter(UserSession.user_id == user.id)
            .delete()
        )

    def delete_user_sessions(self, user: User) -> int:
        """Delete all sessions of ``user``; return how many were deleted."""
        count = self._delete_sessions_query(user)
        self._commit()
        return count

    def cleanup_expired_sessions(self) -> int:
        """Delete every expired session; return how many were deleted."""
        count = (
            self.db.query(UserSession)
            .filter(UserSession.expires_at < datetime.now(timezone.utc))
            .delete()
        )
        self._commit()
        return count

    # --- API Key Operations ---

    def create_api_key(
        self,
        user: User,
        name: str,
        description: Optional[str] = None,
        scopes: Optional[list[str]] = None,
        expires_at: Optional[datetime] = None,
    ) -> tuple[APIKey, str]:
        """Create an API key owned by ``user``.

        ``scopes`` defaults to ``["read", "write"]``.

        Returns:
            ``(api_key, key)``. ``key`` is the plaintext value shown to the
            user once. Only its SHA-256 hash is stored.
        """
        key = generate_api_key()
        api_key = APIKey(
            key_hash=hash_token(key),
            name=name,
            user_id=user.username,  # Legacy field
            owner_id=user.id,
            description=description,
            scopes=scopes or ["read", "write"],
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self._commit()
        self.db.refresh(api_key)
        return api_key, key

    def get_api_key_by_key(self, key: str) -> Optional[APIKey]:
        """Look up an API key by its plaintext value.

        Returns None for values without the API key prefix, unknown keys and
        expired keys. A valid key has its ``last_used`` timestamp refreshed.
        """
        if not key.startswith(API_KEY_PREFIX):
            return None

        api_key = (
            self.db.query(APIKey)
            .filter(APIKey.key_hash == hash_token(key))
            .first()
        )
        if not api_key:
            return None

        now = datetime.now(timezone.utc)
        if api_key.expires_at and api_key.expires_at < now:
            return None

        api_key.last_used = now
        self._commit()
        return api_key

    def get_api_key_by_id(self, key_id: str) -> Optional[APIKey]:
        """Return the API key with primary key ``key_id``, or None."""
        return self.db.query(APIKey).filter(APIKey.id == key_id).first()

    def list_user_api_keys(self, user: User) -> list[APIKey]:
        """List the API keys owned by ``user``, newest first."""
        return (
            self.db.query(APIKey)
            .filter(APIKey.owner_id == user.id)
            .order_by(APIKey.created_at.desc())
            .all()
        )

    def delete_api_key(self, api_key: APIKey) -> None:
        """Delete an API key."""
        self.db.delete(api_key)
        self._commit()

    def get_user_from_api_key(self, key: str) -> Optional[User]:
        """Return the active owner of a valid API key.

        Returns None if the key is invalid or expired, has no owner, or the
        owner is missing or inactive.
        """
        api_key = self.get_api_key_by_key(key)
        if not api_key or not api_key.owner_id:
            return None

        user = self.get_user_by_id(api_key.owner_id)
        if not user or not user.is_active:
            return None
        return user
def create_default_admin(db: Session) -> Optional[User]:
    """Create the default ``admin`` account when no users exist yet.

    The password comes from ``settings.admin_password`` (the
    ORCHARD_ADMIN_PASSWORD environment variable). When that is unset, the
    password falls back to ``changeme123`` and the account must change it on
    first login.

    Args:
        db: Database session.

    Returns:
        The created admin User. Returns None if users already exist, or if
        another process inserted the admin first (several replicas starting
        at the same time).
    """
    from sqlalchemy.exc import IntegrityError

    if db.query(User).count() > 0:
        return None

    settings = get_settings()
    configured_password = settings.admin_password
    password = configured_password if configured_password else "changeme123"
    # Only force a password change when the well-known default is in use.
    must_change = not configured_password

    auth_service = AuthService(db)
    try:
        admin = auth_service.create_user(
            username="admin",
            password=password,
            is_admin=True,
            must_change_password=must_change,
        )
    except IntegrityError:
        # The count check and the insert are not atomic. If another worker
        # created "admin" in between, the insert fails on a constraint
        # (presumably the username's uniqueness). Treat that as "already
        # exists" rather than crashing startup.
        db.rollback()
        logger.info("Default admin user was created concurrently; skipping")
        return None

    if configured_password:
        logger.info("Created default admin user with configured password")
    else:
        logger.info("Created default admin user with default password (changeme123)")
    return admin
# --- JWT Validation ---
def validate_jwt_token(token: str) -> Optional[dict]:
    """Validate a JWT and return its decoded claims.

    Asymmetric algorithms (RS*, ES*, PS*) are verified against the key in
    ``settings.jwt_jwks_url`` whose ``kid`` matches the token header. HMAC
    algorithms (HS*) use ``settings.jwt_secret``. Only
    ``settings.jwt_algorithm`` is accepted. Issuer and audience are checked
    when configured. Uses python-jose.

    Returns:
        The claims dict. Returns None when JWT auth is disabled, python-jose
        is missing, configuration is incomplete, or the token is invalid or
        expired.
    """
    settings = get_settings()
    if not settings.jwt_enabled:
        return None

    try:
        from jose import jwt, JWTError, ExpiredSignatureError
        from jose.exceptions import JWTClaimsError
    except ImportError:
        logger.warning("python-jose not installed, JWT authentication disabled")
        return None

    algorithm = settings.jwt_algorithm
    try:
        if algorithm.startswith(("RS", "ES", "PS")):
            # Asymmetric signature: fetch the public key from the JWKS.
            if not settings.jwt_jwks_url:
                logger.error("JWT JWKS URL not configured for asymmetric algorithm")
                return None

            # Parse the header before any network I/O. Malformed tokens are
            # then rejected cheaply and logged as client errors, not as JWKS
            # failures.
            try:
                kid = jwt.get_unverified_header(token).get("kid")
            except JWTError as e:
                logger.debug(f"Invalid JWT header: {e}")
                return None

            try:
                import httpx

                # NOTE(review): the JWKS is fetched on every call; consider
                # caching it with a TTL if this runs per request.
                response = httpx.get(settings.jwt_jwks_url, timeout=10.0)
                response.raise_for_status()
                jwks = response.json()
                key = next(
                    (k for k in jwks.get("keys", []) if k.get("kid") == kid),
                    None,
                )
            except Exception as e:
                logger.error(f"Failed to get signing key from JWKS: {e}")
                return None

            if not key:
                logger.error(f"No matching key found in JWKS for kid: {kid}")
                return None
        else:
            # HS256/HS384/HS512: shared secret.
            if not settings.jwt_secret:
                logger.error("JWT secret not configured for HMAC algorithm")
                return None
            key = settings.jwt_secret

        decode_kwargs = {
            "algorithms": [algorithm],
            "options": {},
        }
        if settings.jwt_issuer:
            decode_kwargs["issuer"] = settings.jwt_issuer
        if settings.jwt_audience:
            decode_kwargs["audience"] = settings.jwt_audience

        return jwt.decode(token, key, **decode_kwargs)

    except ExpiredSignatureError:
        logger.debug("JWT token expired")
        return None
    except JWTClaimsError as e:
        logger.debug(f"JWT claims error: {e}")
        return None
    except JWTError as e:
        logger.debug(f"Invalid JWT token: {e}")
        return None
    except Exception as e:
        logger.error(f"JWT validation error: {e}")
        return None
def get_or_create_user_from_jwt(db: Session, payload: dict) -> Optional[User]:
    """Return the local user for a validated JWT, provisioning it if needed.

    The username is the value of the claim named by
    ``settings.jwt_username_claim``, used verbatim. An email claim stays a
    full email address. Unknown usernames are auto-provisioned as non-admin
    accounts without a local password (SSO auto-provisioning).

    Args:
        db: Database session.
        payload: Claims returned by validate_jwt_token.

    Returns:
        The active User. Returns None if the claim is missing, the user is
        disabled, or provisioning fails.
    """
    settings = get_settings()
    claim_name = settings.jwt_username_claim
    username = payload.get(claim_name)
    if not username:
        logger.warning(f"JWT missing username claim: {claim_name}")
        return None
    # NOTE(review): the claim is matched against all local usernames. A token
    # whose claim equals a local account name (e.g. "admin") therefore
    # authenticates as that account. Confirm the issuer is trusted for that.

    auth_service = AuthService(db)
    user = auth_service.get_user_by_username(username)
    if user is None:
        logger.info(f"Auto-provisioning user from JWT: {username}")
        try:
            return auth_service.create_user(
                username=username,
                password=None,  # No password for SSO users
                email=payload.get("email"),
                is_admin=False,
                must_change_password=False,
            )
        except Exception as e:
            # The failed commit must be rolled back before the session can be
            # used again. A concurrent request may have provisioned the same
            # user, so look it up once more before giving up.
            db.rollback()
            user = auth_service.get_user_by_username(username)
            if user is None:
                logger.error(f"Failed to auto-provision JWT user: {e}")
                return None

    if not user.is_active:
        logger.debug(f"JWT user {username} is inactive")
        return None
    return user
# --- FastAPI Dependencies ---
from fastapi import Depends, HTTPException, status, Cookie, Header
from .database import get_db
# Cookie name for session token
SESSION_COOKIE_NAME = "orchard_session"


def get_current_user_optional(
    db: Session = Depends(get_db),
    session_token: Optional[str] = Cookie(None, alias=SESSION_COOKIE_NAME),
    authorization: Optional[str] = Header(None),
) -> Optional[User]:
    """Resolve the caller from a session cookie, API key, or JWT.

    Returns None if no valid credentials are present. It never raises for
    unauthenticated requests.

    Methods are tried in this order:
      1. Session cookie (web UI).
      2. ``Authorization: Bearer <key>`` where the key starts with
         API_KEY_PREFIX.
      3. ``Authorization: Bearer <jwt>`` when JWT auth is enabled.

    The ``Bearer`` scheme name is matched case-insensitively, as RFC 7235
    requires for authentication schemes.
    """
    auth_service = AuthService(db)

    # 1. Session cookie (web UI)
    if session_token:
        session = auth_service.get_session_by_token(session_token)
        if session:
            user = auth_service.get_user_by_id(str(session.user_id))
            if user and user.is_active:
                return user

    # 2./3. Bearer credentials
    if not authorization:
        return None
    scheme, _, token = authorization.partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = token.strip()
    if not token:
        return None

    if token.startswith(API_KEY_PREFIX):
        return auth_service.get_user_from_api_key(token)

    if get_settings().jwt_enabled:
        payload = validate_jwt_token(token)
        if payload:
            return get_or_create_user_from_jwt(db, payload)
    return None
def get_current_user(
    user: Optional[User] = Depends(get_current_user_optional),
) -> User:
    """Return the authenticated caller or reject the request.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when no valid
            credentials were supplied.
    """
    if user:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated",
        headers={"WWW-Authenticate": "Bearer"},
    )
def require_admin(
    user: User = Depends(get_current_user),
) -> User:
    """Return the authenticated caller only if they are a system admin.

    Raises:
        HTTPException: 403 when the caller lacks admin privileges.
    """
    if user.is_admin:
        return user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin privileges required",
    )
def get_auth_service(db: Session = Depends(get_db)) -> AuthService:
    """FastAPI dependency returning an AuthService bound to get_db's session."""
    service = AuthService(db)
    return service
# --- Authorization ---
# Access levels in order of increasing privilege
ACCESS_LEVELS = ["read", "write", "admin"]


def get_access_level_rank(level: str) -> int:
    """Return the rank of ``level`` in ACCESS_LEVELS, or -1 if unknown."""
    try:
        return ACCESS_LEVELS.index(level)
    except ValueError:
        return -1


def has_sufficient_access(user_level: str, required_level: str) -> bool:
    """Return True if ``user_level`` meets or exceeds ``required_level``.

    Levels are hierarchical: admin > write > read. An unknown level on either
    side never grants access (fail closed). This covers a misspelled
    requirement and two unknown levels that would otherwise both rank -1 and
    compare as sufficient.
    """
    user_rank = get_access_level_rank(user_level)
    required_rank = get_access_level_rank(required_level)
    if user_rank < 0 or required_rank < 0:
        return False
    return user_rank >= required_rank
class AuthorizationService:
    """Resolve and manage per-project access for users."""

    def __init__(self, db: Session):
        self.db = db

    def _user_permission_query(self, project_id: str, username: str):
        """Build a query for one user's AccessPermission rows on one project."""
        from .models import AccessPermission

        return self.db.query(AccessPermission).filter(
            AccessPermission.project_id == project_id,
            AccessPermission.user_id == username,
        )

    def get_user_access_level(
        self, project_id: str, user: Optional[User]
    ) -> Optional[str]:
        """Return the strongest access level ``user`` holds on a project.

        Resolution order:
          1. Unknown project -> None.
          2. Anonymous caller -> "read" for public projects, else None.
          3. System admin, or the user named in ``created_by`` -> "admin".
          4. An explicit, unexpired permission row -> that row's level.
          5. Otherwise -> "read" for public projects, else None.
        """
        from .models import Project

        project = self.db.query(Project).filter(Project.id == project_id).first()
        if not project:
            return None

        public_level = "read" if project.is_public else None
        if not user:
            return public_level
        if user.is_admin or project.created_by == user.username:
            return "admin"

        grant = self._user_permission_query(project_id, user.username).first()
        if not grant:
            return public_level

        has_expired = bool(grant.expires_at) and (
            grant.expires_at < datetime.now(timezone.utc)
        )
        return public_level if has_expired else grant.level

    def check_access(
        self,
        project_id: str,
        user: Optional[User],
        required_level: str,
    ) -> bool:
        """Return True if ``user`` holds at least ``required_level`` on the project."""
        level = self.get_user_access_level(project_id, user)
        return bool(level) and has_sufficient_access(level, required_level)

    def grant_access(
        self,
        project_id: str,
        username: str,
        level: str,
        expires_at: Optional[datetime] = None,
    ) -> "AccessPermission":
        """Create or overwrite ``username``'s permission on a project.

        An existing row has its level and expiry replaced. Otherwise a new row
        is inserted.
        """
        from .models import AccessPermission

        existing = self._user_permission_query(project_id, username).first()
        if existing:
            existing.level = level
            existing.expires_at = expires_at
            self.db.commit()
            return existing

        created = AccessPermission(
            project_id=project_id,
            user_id=username,
            level=level,
            expires_at=expires_at,
        )
        self.db.add(created)
        self.db.commit()
        self.db.refresh(created)
        return created

    def revoke_access(self, project_id: str, username: str) -> bool:
        """Delete ``username``'s permission on a project; True if a row was removed."""
        deleted = self._user_permission_query(project_id, username).delete()
        self.db.commit()
        return deleted > 0

    def list_project_permissions(self, project_id: str) -> list:
        """Return every AccessPermission row for a project."""
        from .models import AccessPermission

        query = self.db.query(AccessPermission).filter(
            AccessPermission.project_id == project_id
        )
        return query.all()
def get_authorization_service(db: Session = Depends(get_db)) -> AuthorizationService:
    """FastAPI dependency returning an AuthorizationService bound to get_db's session."""
    service = AuthorizationService(db)
    return service
class ProjectAccessChecker:
    """FastAPI dependency that enforces a minimum access level on a project.

    The project is identified by name through the ``project`` route
    parameter.
    """

    def __init__(self, required_level: str = "read"):
        self.required_level = required_level

    def __call__(
        self,
        project: str,
        db: Session = Depends(get_db),
        current_user: Optional[User] = Depends(get_current_user_optional),
    ) -> Optional[User]:
        """Verify the caller's access to ``project`` and return the caller.

        Returns:
            The current user. This is None for anonymous access to a public
            project at "read" level.

        Raises:
            HTTPException: 404 if the project does not exist. 401 if an
                anonymous caller lacks access. 403 if an authenticated caller
                lacks ``required_level``.
        """
        # check_project_access implements exactly these lookup and 404/401/403
        # rules. Delegating keeps the two entry points from drifting apart.
        check_project_access(db, project, current_user, self.required_level)
        return current_user
# Ready-made route dependencies, one per access level.
require_project_read = ProjectAccessChecker(required_level="read")
require_project_write = ProjectAccessChecker(required_level="write")
require_project_admin = ProjectAccessChecker(required_level="admin")
def check_project_access(
    db: Session,
    project_name: str,
    user: Optional[User],
    required_level: str = "read",
) -> "Project":
    """Load a project by name and ensure ``user`` may act on it.

    Intended for direct use inside route handlers.

    Args:
        db: Database session.
        project_name: Name of the project.
        user: Current user, or None for anonymous callers.
        required_level: Minimum access level ("read", "write" or "admin").

    Returns:
        The Project when access is granted.

    Raises:
        HTTPException 404: the project does not exist.
        HTTPException 401: an anonymous caller lacks access.
        HTTPException 403: an authenticated caller lacks ``required_level``.
    """
    from .models import Project

    project = db.query(Project).filter(Project.name == project_name).first()
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Project '{project_name}' not found",
        )

    authorized = AuthorizationService(db).check_access(
        str(project.id), user, required_level
    )
    if authorized:
        return project

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required for private project",
            headers={"WWW-Authenticate": "Bearer"},
        )
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Insufficient permissions. Required: {required_level}",
    )
# --- OIDC Configuration Service ---
class OIDCConfig:
    """OIDC provider settings; round-trips through a plain dict for storage."""

    DEFAULT_SCOPES = ("openid", "profile", "email")

    def __init__(
        self,
        enabled: bool = False,
        issuer_url: str = "",
        client_id: str = "",
        client_secret: str = "",
        scopes: Optional[list[str]] = None,
        auto_create_users: bool = True,
        admin_group: str = "",  # Group/role that grants admin access
    ):
        """Build a configuration.

        Args:
            enabled: Whether OIDC login is switched on.
            issuer_url: Provider issuer URL. A trailing "/" is stripped.
            client_id: OAuth client ID.
            client_secret: OAuth client secret.
            scopes: Requested scopes, as a list or a space-separated string.
                Empty or None means ``DEFAULT_SCOPES``. The list is copied, so
                later changes to the caller's list do not affect this config.
            auto_create_users: Provision unknown users on first login.
            admin_group: Group/role name that grants admin rights.
        """
        self.enabled = enabled
        self.issuer_url = issuer_url.rstrip("/") if issuer_url else ""
        self.client_id = client_id
        self.client_secret = client_secret
        if isinstance(scopes, str):
            # A space-separated string (the OAuth wire format) would otherwise
            # be joined character by character when building the auth URL.
            scopes = scopes.split()
        self.scopes = list(scopes) if scopes else list(self.DEFAULT_SCOPES)
        self.auto_create_users = auto_create_users
        self.admin_group = admin_group

    @property
    def discovery_url(self) -> str:
        """Return the ``.well-known/openid-configuration`` URL, or "" without an issuer."""
        if not self.issuer_url:
            return ""
        return f"{self.issuer_url}/.well-known/openid-configuration"

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of all settings (includes the secret)."""
        return {
            "enabled": self.enabled,
            "issuer_url": self.issuer_url,
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scopes": list(self.scopes),
            "auto_create_users": self.auto_create_users,
            "admin_group": self.admin_group,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "OIDCConfig":
        """Build a config from a dict like ``to_dict()`` output; missing keys use defaults."""
        return cls(
            enabled=data.get("enabled", False),
            issuer_url=data.get("issuer_url", ""),
            client_id=data.get("client_id", ""),
            client_secret=data.get("client_secret", ""),
            scopes=data.get("scopes"),
            auto_create_users=data.get("auto_create_users", True),
            admin_group=data.get("admin_group", ""),
        )
class OIDCConfigService:
    """Load and persist the OIDC configuration as JSON in the AuthSettings table."""

    OIDC_CONFIG_KEY = "oidc_config"

    def __init__(self, db: Session):
        self.db = db

    def _get_setting(self):
        """Return the AuthSettings row holding the OIDC config, or None."""
        from .models import AuthSettings

        return (
            self.db.query(AuthSettings)
            .filter(AuthSettings.key == self.OIDC_CONFIG_KEY)
            .first()
        )

    def get_config(self) -> OIDCConfig:
        """Return the stored OIDC configuration.

        A default (disabled) OIDCConfig is returned when nothing is stored or
        the stored value cannot be parsed. Unparseable values are logged
        rather than raised, so a corrupt row cannot break every request.
        """
        import json

        setting = self._get_setting()
        if not setting:
            return OIDCConfig()

        try:
            data = json.loads(setting.value)
        except (TypeError, json.JSONDecodeError):
            # TypeError covers a NULL value column.
            logger.warning("Stored OIDC configuration is not valid JSON; using defaults")
            return OIDCConfig()

        if not isinstance(data, dict):
            logger.warning("Stored OIDC configuration is not a JSON object; using defaults")
            return OIDCConfig()

        try:
            return OIDCConfig.from_dict(data)
        except (KeyError, AttributeError, TypeError) as e:
            # e.g. a non-string issuer_url.
            logger.warning(f"Stored OIDC configuration has invalid values ({e}); using defaults")
            return OIDCConfig()

    def save_config(self, config: OIDCConfig) -> None:
        """Insert or update the stored OIDC configuration and commit."""
        from .models import AuthSettings
        import json

        serialized = json.dumps(config.to_dict())
        setting = self._get_setting()
        if setting:
            setting.value = serialized
            setting.updated_at = datetime.now(timezone.utc)
        else:
            setting = AuthSettings(
                key=self.OIDC_CONFIG_KEY,
                value=serialized,
            )
            self.db.add(setting)
        self.db.commit()

    def is_enabled(self) -> bool:
        """Return True if OIDC is enabled and has both an issuer URL and a client ID."""
        config = self.get_config()
        return config.enabled and bool(config.issuer_url) and bool(config.client_id)
def get_oidc_config_service(db: Session = Depends(get_db)) -> OIDCConfigService:
    """FastAPI dependency returning an OIDCConfigService bound to get_db's session."""
    service = OIDCConfigService(db)
    return service
# --- OIDC Authentication Flow ---
class OIDCService:
    """Drives the OIDC authorization-code login flow for one configuration."""

    def __init__(self, db: Session, config: OIDCConfig):
        self.db = db
        self.config = config
        # Discovery document, fetched lazily and cached for this instance.
        self._discovery_doc: Optional[dict] = None

    def _commit(self) -> None:
        """Commit the session; on failure roll back and re-raise."""
        try:
            self.db.commit()
        except Exception:
            # A failed commit leaves the session unusable until rolled back.
            self.db.rollback()
            raise

    def get_discovery_document(self) -> Optional[dict]:
        """Return the provider's discovery document, fetching it on first use.

        Returns None (and logs) when no issuer is configured or the fetch
        fails. Failures are not cached, so the next call retries.
        """
        if self._discovery_doc:
            return self._discovery_doc
        if not self.config.discovery_url:
            return None
        try:
            import httpx

            response = httpx.get(self.config.discovery_url, timeout=10.0)
            response.raise_for_status()
            self._discovery_doc = response.json()
            return self._discovery_doc
        except Exception as e:
            logger.error(f"Failed to fetch OIDC discovery document: {e}")
            return None

    def get_authorization_url(self, redirect_uri: str, state: str) -> Optional[str]:
        """Build the provider authorization URL for the code flow.

        Returns None if discovery fails or has no ``authorization_endpoint``.
        """
        discovery = self.get_discovery_document()
        if not discovery:
            return None

        auth_endpoint = discovery.get("authorization_endpoint")
        if not auth_endpoint:
            logger.error("No authorization_endpoint in discovery document")
            return None

        import urllib.parse

        # NOTE(review): no nonce or PKCE parameters are sent; consider adding
        # them for replay/code-interception protection.
        params = {
            "client_id": self.config.client_id,
            "response_type": "code",
            "scope": " ".join(self.config.scopes),
            "redirect_uri": redirect_uri,
            "state": state,
        }
        # Some providers publish an endpoint that already has a query string.
        separator = "&" if "?" in auth_endpoint else "?"
        return f"{auth_endpoint}{separator}{urllib.parse.urlencode(params)}"

    def exchange_code_for_tokens(
        self, code: str, redirect_uri: str
    ) -> Optional[dict]:
        """POST the authorization code to the token endpoint and return its JSON.

        Client credentials are sent in the form body. Returns None on any
        failure.
        """
        discovery = self.get_discovery_document()
        if not discovery:
            return None

        token_endpoint = discovery.get("token_endpoint")
        if not token_endpoint:
            logger.error("No token_endpoint in discovery document")
            return None

        try:
            import httpx

            response = httpx.post(
                token_endpoint,
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": redirect_uri,
                    "client_id": self.config.client_id,
                    "client_secret": self.config.client_secret,
                },
                timeout=10.0,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Failed to exchange code for tokens: {e}")
            return None

    def validate_id_token(
        self, id_token: str, access_token: Optional[str] = None
    ) -> Optional[dict]:
        """Verify an ID token's RS256 signature and claims; return the claims.

        The signing key is the JWKS entry whose ``kid`` matches the token
        header. The audience must equal the client ID. The issuer must equal
        the configured issuer URL or the discovery document's ``issuer``. The
        latter is the exact value providers put in ``iss``, and it may keep a
        trailing "/" that the configured URL has stripped.

        Args:
            id_token: Raw ID token from the token endpoint.
            access_token: Access token issued with it, if available. When
                given, an ``at_hash`` claim is verified against it. When
                omitted, ``at_hash`` is not checked, because python-jose would
                otherwise reject every ID token that carries it.

        Returns:
            The claims dict, or None if validation fails.
        """
        discovery = self.get_discovery_document()
        if not discovery:
            return None

        try:
            from jose import jwt, JWTError
            import httpx
        except ImportError as e:
            logger.error(f"Error validating ID token: {e}")
            return None

        try:
            jwks_uri = discovery.get("jwks_uri")
            if not jwks_uri:
                logger.error("No jwks_uri in discovery document")
                return None

            response = httpx.get(jwks_uri, timeout=10.0)
            response.raise_for_status()
            jwks = response.json()

            kid = jwt.get_unverified_header(id_token).get("kid")
            signing_key = next(
                (k for k in jwks.get("keys", []) if k.get("kid") == kid),
                None,
            )
            if not signing_key:
                logger.error(f"No matching key found in JWKS for kid: {kid}")
                return None

            accepted_issuers = [self.config.issuer_url]
            discovered_issuer = discovery.get("issuer")
            if discovered_issuer and discovered_issuer not in accepted_issuers:
                accepted_issuers.append(discovered_issuer)

            return jwt.decode(
                id_token,
                signing_key,
                algorithms=["RS256"],
                audience=self.config.client_id,
                issuer=accepted_issuers,
                access_token=access_token,
                options={"verify_at_hash": access_token is not None},
            )
        except JWTError as e:
            logger.error(f"ID token validation failed: {e}")
            return None
        except Exception as e:
            logger.error(f"Error validating ID token: {e}")
            return None

    @staticmethod
    def _claim_as_list(claims: dict, name: str) -> list:
        """Return claim ``name`` as a list.

        A bare string becomes a one-item list, so membership tests match
        whole names rather than substrings. Missing or other-typed values
        give ``[]``.
        """
        value = claims.get(name)
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple, set)):
            return list(value)
        return []

    def get_or_create_user(self, id_token_claims: dict) -> Optional[User]:
        """Map validated ID-token claims to a local user, creating one if allowed.

        Lookup order:
          1. A user already linked to this (issuer, ``sub``) pair.
          2. A user with the same email, which is then linked to this
             identity. This step is skipped when the token explicitly states
             ``email_verified: false``.
          3. A new user, if ``auto_create_users`` is on. The username is the
             email's local part (or the subject), suffixed with part of the
             subject on collision. Admin rights are granted when
             ``admin_group`` appears in the ``groups`` or ``roles`` claim.

        Returns:
            The active User. Returns None if the token has no ``sub``, the
            matched user is disabled, or auto-creation is off for an unknown
            identity.
        """
        subject = id_token_claims.get("sub")
        email = id_token_claims.get("email")
        if not subject:
            logger.error("No 'sub' claim in ID token")
            return None

        now = datetime.now(timezone.utc)

        # 1. Existing user linked to this OIDC identity.
        user = (
            self.db.query(User)
            .filter(
                User.oidc_subject == subject,
                User.oidc_issuer == self.config.issuer_url,
            )
            .first()
        )
        if user:
            if not user.is_active:
                logger.warning(f"Rejecting OIDC login for inactive user: {user.username}")
                return None
            user.last_login = now
            self._commit()
            return user

        # 2. Link to an existing account by email. Skip this if the provider
        #    says the address is unverified: linking on an unverified email
        #    would let anyone who claims that address take over the account.
        verified = id_token_claims.get("email_verified")
        email_unverified = verified is False or (
            isinstance(verified, str) and verified.lower() == "false"
        )
        if email and not email_unverified:
            user = self.db.query(User).filter(User.email == email).first()
            if user:
                if not user.is_active:
                    logger.warning(f"Rejecting OIDC login for inactive user: {user.username}")
                    return None
                # NOTE(review): this replaces any OIDC identity already linked
                # to the account; confirm re-linking is intended.
                user.oidc_subject = subject
                user.oidc_issuer = self.config.issuer_url
                user.last_login = now
                self._commit()
                logger.info(f"Linked OIDC identity to existing user: {user.username}")
                return user

        # 3. Auto-create.
        if not self.config.auto_create_users:
            logger.warning(f"Auto-creation disabled, rejecting new OIDC user: {subject}")
            return None

        username = email.split("@")[0] if email else subject
        if self.db.query(User).filter(User.username == username).first():
            # Append part of the subject to make the username unique.
            username = f"{username}_{subject[:8]}"

        is_admin = False
        if self.config.admin_group:
            is_admin = (
                self.config.admin_group in self._claim_as_list(id_token_claims, "groups")
                or self.config.admin_group in self._claim_as_list(id_token_claims, "roles")
            )

        user = User(
            username=username,
            email=email,
            password_hash=None,  # OIDC users don't have passwords
            oidc_subject=subject,
            oidc_issuer=self.config.issuer_url,
            is_admin=is_admin,
            must_change_password=False,
        )
        self.db.add(user)
        self._commit()
        self.db.refresh(user)

        logger.info(f"Created new OIDC user: {username} (admin={is_admin})")
        return user