1209 lines
38 KiB
Python
1209 lines
38 KiB
Python
"""Authentication service for Orchard.
|
|
|
|
Handles password hashing, session management, API key operations, and JWT validation.
|
|
"""
|
|
|
|
import hashlib
|
|
import secrets
|
|
import logging
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
from passlib.context import CryptContext
|
|
from sqlalchemy.orm import Session
|
|
|
|
from .models import User, Session as UserSession, APIKey
|
|
from .config import get_settings
|
|
|
|
logger = logging.getLogger(__name__)

# Password hashing context. The bcrypt cost factor (12) is pinned explicitly so
# the work factor cannot drift silently if passlib's default ever changes.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=12)

# Prefix that identifies Orchard API keys; Bearer tokens starting with it are
# treated as API keys rather than JWTs.
API_KEY_PREFIX = "orch_"

# Lifetime of a newly created login session, in hours.
SESSION_DURATION_HOURS = 24

# Minimum accepted password length (enforced by validate_password_strength).
MIN_PASSWORD_LENGTH = 8
|
|
|
|
|
|
class PasswordTooShortError(ValueError):
    """Signal that a candidate password is shorter than the required minimum.

    Derives from ``ValueError`` so generic validation handlers still catch it.
    """
|
|
|
|
|
|
def validate_password_strength(password: str) -> None:
    """Ensure *password* satisfies the minimum-length policy.

    Empty (or ``None``) passwords are rejected too.

    Raises:
        PasswordTooShortError: if the password has fewer than
            ``MIN_PASSWORD_LENGTH`` characters.
    """
    long_enough = bool(password) and len(password) >= MIN_PASSWORD_LENGTH
    if long_enough:
        return
    raise PasswordTooShortError(
        f"Password must be at least {MIN_PASSWORD_LENGTH} characters"
    )
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Return a salted bcrypt hash of *password* produced by ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches *hashed_password*.

    The check is delegated to ``pwd_context.verify``, which uses the scheme
    encoded in the stored hash.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
|
|
|
|
|
def hash_token(token: str) -> str:
    """Return the hex SHA-256 digest of *token* (UTF-8 encoded).

    Session tokens and API keys are persisted only in this hashed form.
    """
    digest = hashlib.sha256()
    digest.update(token.encode())
    return digest.hexdigest()
|
|
|
|
|
|
def generate_session_token() -> str:
    """Return a fresh URL-safe session token derived from 32 random bytes."""
    entropy_bytes = 32
    return secrets.token_urlsafe(entropy_bytes)
|
|
|
|
|
|
def generate_api_key() -> str:
    """Create a new plaintext API key.

    The key is ``API_KEY_PREFIX`` followed by 64 hex characters
    (32 random bytes), i.e. ``orch_<64 hex chars>``.
    """
    return API_KEY_PREFIX + secrets.token_hex(32)
|
|
|
|
|
|
class AuthService:
    """User, session and API-key operations on top of a SQLAlchemy session.

    Methods that modify data commit the session passed to the constructor.
    """

    # Valid-format bcrypt hash verified against when the user is unknown or has
    # no password, so the response time does not reveal whether the account exists.
    _DUMMY_PASSWORD_HASH = "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYA1vQ9S9sXa"

    def __init__(self, db: Session):
        self.db = db

    # --- Internal helpers ---

    @staticmethod
    def _as_aware_utc(value: datetime) -> datetime:
        """Return *value* unchanged if timezone-aware, otherwise tag it as UTC.

        Comparing a naive datetime with ``datetime.now(timezone.utc)`` raises
        ``TypeError``; this keeps expiry checks working if the database hands
        back naive timestamps.
        """
        # NOTE(review): assumes naive timestamps from the DB are stored in UTC —
        # confirm against the column definitions in .models.
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value

    def _user_sessions_query(self, user: User):
        """Return a query matching every session row owned by *user*."""
        return self.db.query(UserSession).filter(UserSession.user_id == user.id)

    # --- User Operations ---

    def create_user(
        self,
        username: str,
        password: Optional[str] = None,
        email: Optional[str] = None,
        is_admin: bool = False,
        must_change_password: bool = False,
    ) -> User:
        """Create, commit and return a new user.

        A falsy *password* stores no password hash (e.g. SSO-only accounts).
        """
        user = User(
            username=username,
            password_hash=hash_password(password) if password else None,
            email=email,
            is_admin=is_admin,
            must_change_password=must_change_password,
        )
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)
        return user

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the user with *username*, or None."""
        return self.db.query(User).filter(User.username == username).first()

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Return the user with primary key *user_id*, or None."""
        return self.db.query(User).filter(User.id == user_id).first()

    def authenticate_user(self, username: str, password: str) -> Optional[User]:
        """Authenticate a user with username and password.

        Returns the user if authentication succeeds, None otherwise. A bcrypt
        verification is always performed (against a dummy hash when needed)
        so unknown usernames take about as long as wrong passwords.
        """
        user = self.get_user_by_username(username)

        password_hash = (
            user.password_hash
            if user and user.password_hash
            else self._DUMMY_PASSWORD_HASH
        )
        # Run the verification before any early return to keep timing uniform.
        password_valid = verify_password(password, password_hash)

        if not user:
            return None
        if not user.password_hash:
            return None  # OIDC-only user
        if not user.is_active:
            return None
        if not password_valid:
            return None
        return user

    def change_password(self, user: User, new_password: str) -> None:
        """Set a user-chosen password and revoke all of the user's sessions.

        The password update and the session deletion are committed in a single
        transaction, so a failure cannot leave the new password active while
        old sessions survive.

        Raises:
            PasswordTooShortError: if *new_password* fails validation.
        """
        validate_password_strength(new_password)
        user.password_hash = hash_password(new_password)
        user.must_change_password = False
        self._user_sessions_query(user).delete()
        self.db.commit()

    def update_last_login(self, user: User) -> None:
        """Stamp the user's last login time with the current UTC time."""
        user.last_login = datetime.now(timezone.utc)
        self.db.commit()

    def list_users(self, include_inactive: bool = False) -> list[User]:
        """Return users ordered by username (active only unless requested)."""
        query = self.db.query(User)
        if not include_inactive:
            query = query.filter(User.is_active.is_(True))
        return query.order_by(User.username).all()

    def set_user_active(self, user: User, is_active: bool) -> None:
        """Enable or disable a user account."""
        user.is_active = is_active
        self.db.commit()

    def set_user_admin(self, user: User, is_admin: bool) -> None:
        """Grant or revoke admin privileges."""
        user.is_admin = is_admin
        self.db.commit()

    def reset_user_password(self, user: User, new_password: str) -> None:
        """Admin password reset: force a change at next login and revoke sessions.

        The password update and the session deletion are committed together.

        Raises:
            PasswordTooShortError: if *new_password* fails validation.
        """
        validate_password_strength(new_password)
        user.password_hash = hash_password(new_password)
        user.must_change_password = True
        self._user_sessions_query(user).delete()
        self.db.commit()

    # --- Session Operations ---

    def create_session(
        self,
        user: User,
        user_agent: Optional[str] = None,
        ip_address: Optional[str] = None,
    ) -> tuple[UserSession, str]:
        """Create a new session for a user.

        Returns ``(session, token)``. Only the SHA-256 hash of ``token`` is
        stored; the plaintext token is returned once and must be handed to the
        client.
        """
        token = generate_session_token()
        session = UserSession(
            user_id=user.id,
            token_hash=hash_token(token),
            expires_at=datetime.now(timezone.utc)
            + timedelta(hours=SESSION_DURATION_HOURS),
            user_agent=user_agent,
            ip_address=ip_address,
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session, token

    def get_session_by_token(self, token: str) -> Optional[UserSession]:
        """Return the live session for a plaintext *token*, or None.

        Expired sessions (including rows with no expiry set) are deleted and
        None is returned. A live session has ``last_accessed`` refreshed.
        """
        session = (
            self.db.query(UserSession)
            .filter(UserSession.token_hash == hash_token(token))
            .first()
        )
        if not session:
            return None

        now = datetime.now(timezone.utc)
        expires_at = session.expires_at
        # Fail closed: a session without an expiry is treated as expired.
        if expires_at is None or self._as_aware_utc(expires_at) < now:
            self.db.delete(session)
            self.db.commit()
            return None

        session.last_accessed = now
        self.db.commit()
        return session

    def delete_session(self, session: UserSession) -> None:
        """Delete a session (logout)."""
        self.db.delete(session)
        self.db.commit()

    def delete_user_sessions(self, user: User) -> int:
        """Delete all sessions for a user. Returns count of deleted sessions."""
        count = self._user_sessions_query(user).delete()
        self.db.commit()
        return count

    def cleanup_expired_sessions(self) -> int:
        """Delete all expired sessions. Returns count of deleted sessions."""
        count = (
            self.db.query(UserSession)
            .filter(UserSession.expires_at < datetime.now(timezone.utc))
            .delete()
        )
        self.db.commit()
        return count

    # --- API Key Operations ---

    def create_api_key(
        self,
        user: User,
        name: str,
        description: Optional[str] = None,
        scopes: Optional[list[str]] = None,
        expires_at: Optional[datetime] = None,
    ) -> tuple[APIKey, str]:
        """Create a new API key for a user.

        Returns ``(api_key, key)`` where ``key`` is the plaintext key; only its
        hash is stored, so it is returned exactly once.

        Args:
            scopes: Scopes to grant. ``None`` means the default
                ``["read", "write"]``; an explicit empty list is kept as-is
                rather than being widened to the defaults.
        """
        key = generate_api_key()
        if scopes is None:
            scopes = ["read", "write"]

        api_key = APIKey(
            key_hash=hash_token(key),
            name=name,
            user_id=user.username,  # Legacy field
            owner_id=user.id,
            description=description,
            scopes=list(scopes),
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)
        return api_key, key

    def get_api_key_by_key(self, key: str) -> Optional[APIKey]:
        """Return the API key row for a plaintext *key*, or None.

        None is returned for keys without the Orchard prefix, unknown keys and
        expired keys. A valid key has ``last_used`` refreshed.
        """
        if not key.startswith(API_KEY_PREFIX):
            return None

        api_key = (
            self.db.query(APIKey).filter(APIKey.key_hash == hash_token(key)).first()
        )
        if not api_key:
            return None

        now = datetime.now(timezone.utc)
        if api_key.expires_at and self._as_aware_utc(api_key.expires_at) < now:
            return None

        api_key.last_used = now
        self.db.commit()
        return api_key

    def get_api_key_by_id(self, key_id: str) -> Optional[APIKey]:
        """Return the API key with primary key *key_id*, or None."""
        return self.db.query(APIKey).filter(APIKey.id == key_id).first()

    def list_user_api_keys(self, user: User) -> list[APIKey]:
        """Return the user's API keys, newest first."""
        return (
            self.db.query(APIKey)
            .filter(APIKey.owner_id == user.id)
            .order_by(APIKey.created_at.desc())
            .all()
        )

    def delete_api_key(self, api_key: APIKey) -> None:
        """Delete an API key."""
        self.db.delete(api_key)
        self.db.commit()

    def get_user_from_api_key(self, key: str) -> Optional[User]:
        """Return the active owner of a valid API key, or None.

        None is returned when the key is invalid/expired, has no owner, or the
        owner is missing or inactive.
        """
        api_key = self.get_api_key_by_key(key)
        if not api_key or not api_key.owner_id:
            return None

        user = self.get_user_by_id(api_key.owner_id)
        if not user or not user.is_active:
            return None
        return user
|
|
|
|
|
|
def create_default_admin(
    db: Session,
    username: str = "admin",
    password: str = "changeme123",
) -> Optional[User]:
    """Create the bootstrap admin account when no users exist yet.

    The account is flagged with ``must_change_password=True`` because the
    default password is well known.

    Args:
        db: Database session used for the lookup and the insert.
        username: Login name for the bootstrap admin (default ``"admin"``).
        password: Initial password (default ``"changeme123"``). Callers should
            pass a deployment-specific value where possible.

    Returns:
        The newly created admin user, or None if any user already exists.
    """
    # Only bootstrap an empty users table; an existence check is cheaper than
    # counting every row.
    if db.query(User).first() is not None:
        return None

    admin = AuthService(db).create_user(
        username=username,
        password=password,
        is_admin=True,
        must_change_password=True,
    )
    logger.warning(
        "Created default admin user %r; its initial password must be changed",
        username,
    )
    return admin
|
|
|
|
|
|
# --- JWT Validation ---
|
|
|
|
|
|
# Fetched JWKS documents keyed by URL: url -> (time.monotonic() at fetch, jwks).
_JWKS_CACHE: dict[str, tuple[float, dict]] = {}
# How long a fetched JWKS document is reused before it is refetched.
_JWKS_CACHE_TTL_SECONDS = 300.0
# Minimum age before an unknown-kid lookup may force a refetch. This keeps
# tokens with bogus kids from making every request hit the identity provider.
_JWKS_MIN_REFRESH_SECONDS = 30.0
# Algorithm families verified with public keys from a JWKS (RSA, RSA-PSS, EC).
_ASYMMETRIC_JWT_ALG_PREFIXES = ("RS", "PS", "ES")


def _fetch_jwks(url: str, force_refresh: bool = False) -> dict:
    """Return the JWKS document at *url*, reusing a recently fetched copy.

    With *force_refresh* the cached copy is reused only if it is younger than
    ``_JWKS_MIN_REFRESH_SECONDS``. Network/HTTP errors from ``httpx`` propagate.
    """
    import time

    now = time.monotonic()
    cached = _JWKS_CACHE.get(url)
    if cached is not None:
        fetched_at, jwks = cached
        max_age = _JWKS_MIN_REFRESH_SECONDS if force_refresh else _JWKS_CACHE_TTL_SECONDS
        if now - fetched_at < max_age:
            return jwks

    import httpx

    response = httpx.get(url, timeout=10.0)
    response.raise_for_status()
    jwks = response.json()
    _JWKS_CACHE[url] = (now, jwks)
    return jwks


def _find_jwk(jwks: dict, kid: Optional[str]) -> Optional[dict]:
    """Return the key in *jwks* whose ``kid`` equals *kid*, or None."""
    for candidate in jwks.get("keys", []):
        if candidate.get("kid") == kid:
            return candidate
    return None


def validate_jwt_token(token: str) -> Optional[dict]:
    """Validate a JWT and return its decoded payload.

    Returns None if JWT auth is disabled, python-jose is missing, the key
    material is not configured, or the token fails validation.

    Asymmetric algorithms (RS*/PS*/ES*) are verified with the key from
    ``settings.jwt_jwks_url`` that matches the token's ``kid``. The JWKS is
    cached and refetched once (rate-limited) when the kid is unknown, to pick
    up key rotation. HMAC algorithms use ``settings.jwt_secret``. Issuer and
    audience are enforced when configured.
    """
    settings = get_settings()
    if not settings.jwt_enabled:
        return None

    try:
        from jose import jwt, JWTError, ExpiredSignatureError
        from jose.exceptions import JWTClaimsError
    except ImportError:
        logger.warning("python-jose not installed, JWT authentication disabled")
        return None

    try:
        algorithm = settings.jwt_algorithm

        if algorithm.upper().startswith(_ASYMMETRIC_JWT_ALG_PREFIXES):
            if not settings.jwt_jwks_url:
                logger.error(
                    "JWT JWKS URL not configured for asymmetric algorithm %s",
                    algorithm,
                )
                return None

            try:
                kid = jwt.get_unverified_header(token).get("kid")
                signing_key = _find_jwk(_fetch_jwks(settings.jwt_jwks_url), kid)
                if signing_key is None:
                    # Unknown kid: the provider may have rotated keys.
                    signing_key = _find_jwk(
                        _fetch_jwks(settings.jwt_jwks_url, force_refresh=True), kid
                    )
                if signing_key is None:
                    logger.error("No matching key found in JWKS for kid: %s", kid)
                    return None
            except Exception as e:
                logger.error("Failed to get signing key from JWKS: %s", e)
                return None
        else:
            # HS256/HS384/HS512 - shared secret
            if not settings.jwt_secret:
                logger.error("JWT secret not configured for HMAC algorithm")
                return None
            signing_key = settings.jwt_secret

        decode_kwargs = {
            "algorithms": [algorithm],
            "options": {},
        }
        if settings.jwt_issuer:
            decode_kwargs["issuer"] = settings.jwt_issuer
        if settings.jwt_audience:
            decode_kwargs["audience"] = settings.jwt_audience

        return jwt.decode(token, signing_key, **decode_kwargs)

    except ExpiredSignatureError:
        logger.debug("JWT token expired")
        return None
    except JWTClaimsError as e:
        logger.debug("JWT claims error: %s", e)
        return None
    except JWTError as e:
        logger.debug("Invalid JWT token: %s", e)
        return None
    except Exception as e:
        logger.error("JWT validation error: %s", e)
        return None
|
|
|
|
|
|
def get_or_create_user_from_jwt(db: Session, payload: dict) -> Optional[User]:
    """Resolve, or auto-provision, the local user for a validated JWT payload.

    The username is read from the claim named by ``settings.jwt_username_claim``
    and used verbatim (an email claim keeps the full address). Unknown users
    are created without a password (SSO auto-provisioning).

    Returns:
        The matching active user, or None when the claim is missing, the user
        is inactive, or provisioning fails.
    """
    settings = get_settings()
    claim_name = settings.jwt_username_claim
    raw_username = payload.get(claim_name)
    if not raw_username:
        logger.warning("JWT missing username claim: %s", claim_name)
        return None
    # Claims such as numeric subject IDs are not necessarily strings.
    username = str(raw_username)

    auth_service = AuthService(db)
    user = auth_service.get_user_by_username(username)

    if user is None:
        logger.info("Auto-provisioning user from JWT: %s", username)
        try:
            user = auth_service.create_user(
                username=username,
                password=None,  # No password for SSO users
                email=payload.get("email"),
                is_admin=False,
                must_change_password=False,
            )
        except Exception as e:
            # A failed commit leaves the session unusable until rolled back.
            # A concurrent request may also have created the same user, so
            # look it up again before giving up.
            db.rollback()
            user = auth_service.get_user_by_username(username)
            if user is None:
                logger.error("Failed to auto-provision JWT user: %s", e)
                return None

    if not user.is_active:
        logger.debug("JWT user %s is inactive", username)
        return None
    return user
|
|
|
|
|
|
# --- FastAPI Dependencies ---
|
|
|
|
from fastapi import Depends, HTTPException, status, Cookie, Header
|
|
from .database import get_db
|
|
|
|
# Cookie name for session token
|
|
SESSION_COOKIE_NAME = "orchard_session"  # cookie alias read by get_current_user_optional
|
|
|
|
|
|
def _extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the credentials of a ``Bearer`` Authorization header, or None.

    The scheme name is matched case-insensitively (RFC 7235) and surrounding
    whitespace is ignored.
    """
    if not authorization:
        return None
    scheme, _, credentials = authorization.strip().partition(" ")
    if scheme.lower() != "bearer":
        return None
    return credentials.strip() or None


def get_current_user_optional(
    db: Session = Depends(get_db),
    session_token: Optional[str] = Cookie(None, alias=SESSION_COOKIE_NAME),
    authorization: Optional[str] = Header(None),
) -> Optional[User]:
    """Get the current user from session cookie, API key, or JWT token.

    Returns None if no valid authentication is provided; never raises for
    unauthenticated requests.

    Authentication methods are tried in order:
    1. Session cookie (web UI)
    2. API key (Bearer token starting with ``API_KEY_PREFIX``)
    3. JWT (any other Bearer token, when JWT auth is enabled)
    """
    auth_service = AuthService(db)

    if session_token:
        session = auth_service.get_session_by_token(session_token)
        if session:
            user = auth_service.get_user_by_id(str(session.user_id))
            if user and user.is_active:
                return user

    token = _extract_bearer_token(authorization)
    if not token:
        return None

    if token.startswith(API_KEY_PREFIX):
        return auth_service.get_user_from_api_key(token)

    if not get_settings().jwt_enabled:
        return None
    payload = validate_jwt_token(token)
    if not payload:
        return None
    return get_or_create_user_from_jwt(db, payload)
|
|
|
|
|
|
def get_current_user(
    user: Optional[User] = Depends(get_current_user_optional),
) -> User:
    """Return the authenticated user or reject the request.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when no
            valid credentials were supplied.
    """
    if user:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
def require_admin(
    user: User = Depends(get_current_user),
) -> User:
    """Return the current user only if they hold system-admin rights.

    Raises:
        HTTPException: 403 when the authenticated user is not an admin
            (unauthenticated requests are rejected by get_current_user).
    """
    if user.is_admin:
        return user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin privileges required",
    )
|
|
|
|
|
|
def get_auth_service(db: Session = Depends(get_db)) -> AuthService:
    """FastAPI dependency: an AuthService bound to the session from get_db."""
    service = AuthService(db)
    return service
|
|
|
|
|
|
# --- Authorization ---
|
|
|
|
# Access levels in order of increasing privilege
ACCESS_LEVELS = ["read", "write", "admin"]


def get_access_level_rank(level: str) -> int:
    """Return the index of *level* in ACCESS_LEVELS, or -1 if it is unknown."""
    try:
        return ACCESS_LEVELS.index(level)
    except ValueError:
        return -1


def has_sufficient_access(user_level: str, required_level: str) -> bool:
    """Return True if *user_level* satisfies *required_level*.

    Levels are hierarchical: admin > write > read. Unknown levels on either
    side fail closed. An unrecognised required level (e.g. a typo) is never
    satisfied, instead of being satisfied by every level as a raw rank
    comparison (-1 >= -1) would allow.
    """
    user_rank = get_access_level_rank(user_level)
    required_rank = get_access_level_rank(required_level)
    if user_rank < 0 or required_rank < 0:
        return False
    return user_rank >= required_rank
|
|
|
|
|
|
class AuthorizationService:
    """Project-level authorization checks and permission management."""

    def __init__(self, db: Session):
        self.db = db

    @staticmethod
    def _has_expired(expires_at: Optional[datetime]) -> bool:
        """Return True if *expires_at* is set and already in the past.

        Naive datetimes are treated as UTC so comparing them with an aware
        ``now`` cannot raise ``TypeError``.
        """
        # NOTE(review): assumes naive DB timestamps are UTC — confirm in .models.
        if expires_at is None:
            return False
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return expires_at < datetime.now(timezone.utc)

    def get_user_access_level(
        self, project_id: str, user: Optional[User]
    ) -> Optional[str]:
        """Get the user's access level for a project.

        Returns the highest access level the user has, or None if no access.
        Checks in order:
        1. Anonymous users - "read" on public projects only
        2. System admin - "admin" on every project
        3. Project owner (created_by) - "admin"
        4. Unexpired explicit permission in access_permissions
        5. Fallback - "read" on public projects
        """
        from .models import Project, AccessPermission

        project = self.db.query(Project).filter(Project.id == project_id).first()
        if not project:
            return None

        public_level = "read" if project.is_public else None

        if not user:
            return public_level
        if user.is_admin:
            return "admin"
        if project.created_by == user.username:
            return "admin"

        permission = (
            self.db.query(AccessPermission)
            .filter(
                AccessPermission.project_id == project_id,
                AccessPermission.user_id == user.username,
            )
            .first()
        )
        if permission and not self._has_expired(permission.expires_at):
            return permission.level

        return public_level

    def check_access(
        self,
        project_id: str,
        user: Optional[User],
        required_level: str,
    ) -> bool:
        """Return True if *user* has at least *required_level* on the project."""
        user_level = self.get_user_access_level(project_id, user)
        if not user_level:
            return False
        return has_sufficient_access(user_level, required_level)

    def grant_access(
        self,
        project_id: str,
        username: str,
        level: str,
        expires_at: Optional[datetime] = None,
    ) -> "AccessPermission":
        """Create or update *username*'s permission on a project and commit it."""
        from .models import AccessPermission

        existing = (
            self.db.query(AccessPermission)
            .filter(
                AccessPermission.project_id == project_id,
                AccessPermission.user_id == username,
            )
            .first()
        )
        if existing:
            existing.level = level
            existing.expires_at = expires_at
            self.db.commit()
            return existing

        permission = AccessPermission(
            project_id=project_id,
            user_id=username,
            level=level,
            expires_at=expires_at,
        )
        self.db.add(permission)
        self.db.commit()
        self.db.refresh(permission)
        return permission

    def revoke_access(self, project_id: str, username: str) -> bool:
        """Revoke a user's access to a project. Returns True if a row was deleted."""
        from .models import AccessPermission

        count = (
            self.db.query(AccessPermission)
            .filter(
                AccessPermission.project_id == project_id,
                AccessPermission.user_id == username,
            )
            .delete()
        )
        self.db.commit()
        return count > 0

    def list_project_permissions(self, project_id: str) -> list:
        """Return every AccessPermission row for the project."""
        from .models import AccessPermission

        return (
            self.db.query(AccessPermission)
            .filter(AccessPermission.project_id == project_id)
            .all()
        )
|
|
|
|
|
|
def get_authorization_service(db: Session = Depends(get_db)) -> AuthorizationService:
    """FastAPI dependency: an AuthorizationService bound to the session from get_db."""
    service = AuthorizationService(db)
    return service
|
|
|
|
|
|
class ProjectAccessChecker:
    """Callable FastAPI dependency enforcing a minimum project access level."""

    def __init__(self, required_level: str = "read"):
        self.required_level = required_level

    def __call__(
        self,
        project: str,
        db: Session = Depends(get_db),
        current_user: Optional[User] = Depends(get_current_user_optional),
    ) -> Optional[User]:
        """Check that the caller has ``required_level`` on the named project.

        Raises:
            HTTPException: 404 if the project does not exist, 401 if an
                anonymous caller lacks access, 403 if an authenticated caller
                lacks access.

        Returns:
            The current user, or None for anonymous access to a public
            project at "read" level.
        """
        # Delegate to check_project_access so both entry points share the
        # same lookup, authorization and error responses.
        check_project_access(db, project, current_user, self.required_level)
        return current_user
|
|
|
|
|
|
# Pre-configured access checkers for common use cases
|
|
require_project_read = ProjectAccessChecker(required_level="read")
require_project_write = ProjectAccessChecker(required_level="write")
require_project_admin = ProjectAccessChecker(required_level="admin")
|
|
|
|
|
|
def check_project_access(
    db: Session,
    project_name: str,
    user: Optional[User],
    required_level: str = "read",
) -> "Project":
    """Look up a project by name and enforce *required_level* on it.

    Helper for use directly inside route handlers.

    Args:
        db: Database session
        project_name: Name of the project
        user: Current user (None for anonymous callers)
        required_level: Required access level (read, write, admin)

    Returns:
        The Project row when access is granted.

    Raises:
        HTTPException 404: no project has that name
        HTTPException 401: anonymous caller lacks access (private project)
        HTTPException 403: authenticated caller lacks the required level
    """
    from .models import Project

    found = db.query(Project).filter(Project.name == project_name).first()
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Project '{project_name}' not found",
        )

    authorizer = AuthorizationService(db)
    if authorizer.check_access(str(found.id), user, required_level):
        return found

    if user:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Insufficient permissions. Required: {required_level}",
        )
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication required for private project",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
# --- OIDC Configuration Service ---
|
|
|
|
|
|
class OIDCConfig:
    """Plain container for the OIDC provider settings."""

    # Scopes requested when none are configured.
    DEFAULT_SCOPES = ("openid", "profile", "email")

    def __init__(
        self,
        enabled: bool = False,
        issuer_url: str = "",
        client_id: str = "",
        client_secret: str = "",
        scopes: Optional[list[str]] = None,
        auto_create_users: bool = True,
        admin_group: str = "",  # Group/role that grants admin access
    ):
        self.enabled = enabled
        # Trailing slashes are dropped so discovery_url joins cleanly.
        self.issuer_url = issuer_url.rstrip("/") if issuer_url else ""
        self.client_id = client_id
        self.client_secret = client_secret
        self.scopes = self._normalize_scopes(scopes)
        self.auto_create_users = auto_create_users
        self.admin_group = admin_group

    @classmethod
    def _normalize_scopes(cls, scopes) -> list[str]:
        """Return a private list of scopes, falling back to DEFAULT_SCOPES.

        A space-separated string (as used in the OAuth ``scope`` parameter) is
        split into individual scopes. Lists are copied so the caller's list
        and this config never share state.
        """
        if isinstance(scopes, str):
            scopes = scopes.split()
        if not scopes:
            return list(cls.DEFAULT_SCOPES)
        return list(scopes)

    @property
    def discovery_url(self) -> str:
        """Return the OIDC discovery URL, or "" if no issuer is configured."""
        if not self.issuer_url:
            return ""
        return f"{self.issuer_url}/.well-known/openid-configuration"

    def to_dict(self) -> dict:
        """Convert to a JSON-serialisable dictionary for storage."""
        # NOTE(review): client_secret is included verbatim; it is stored
        # wherever this dict is persisted.
        return {
            "enabled": self.enabled,
            "issuer_url": self.issuer_url,
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scopes": list(self.scopes),
            "auto_create_users": self.auto_create_users,
            "admin_group": self.admin_group,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "OIDCConfig":
        """Create from a dictionary produced by to_dict (missing/null keys use defaults)."""
        return cls(
            enabled=data.get("enabled", False),
            issuer_url=data.get("issuer_url") or "",
            client_id=data.get("client_id") or "",
            client_secret=data.get("client_secret") or "",
            scopes=data.get("scopes"),
            auto_create_users=data.get("auto_create_users", True),
            admin_group=data.get("admin_group") or "",
        )

    def __repr__(self) -> str:
        # client_secret is deliberately omitted so it never ends up in logs.
        return (
            f"{type(self).__name__}(enabled={self.enabled!r}, "
            f"issuer_url={self.issuer_url!r}, client_id={self.client_id!r}, "
            f"scopes={self.scopes!r}, auto_create_users={self.auto_create_users!r}, "
            f"admin_group={self.admin_group!r})"
        )
|
|
|
|
|
|
class OIDCConfigService:
    """Loads and saves the OIDC configuration as JSON in the AuthSettings table."""

    OIDC_CONFIG_KEY = "oidc_config"

    def __init__(self, db: Session):
        self.db = db

    def _get_setting(self):
        """Return the AuthSettings row holding the OIDC config, or None."""
        from .models import AuthSettings

        return (
            self.db.query(AuthSettings)
            .filter(AuthSettings.key == self.OIDC_CONFIG_KEY)
            .first()
        )

    def get_config(self) -> OIDCConfig:
        """Return the stored OIDC configuration.

        A missing row, or a value that is not a JSON object (malformed JSON,
        NULL, a list, ...), yields a default (disabled) OIDCConfig.
        """
        import json

        setting = self._get_setting()
        if not setting:
            return OIDCConfig()

        try:
            data = json.loads(setting.value)
            if not isinstance(data, dict):
                raise ValueError("stored OIDC configuration is not a JSON object")
            return OIDCConfig.from_dict(data)
        except (ValueError, TypeError, KeyError) as e:
            # json.JSONDecodeError is a ValueError; TypeError covers a NULL value.
            logger.warning("Ignoring unreadable OIDC configuration: %s", e)
            return OIDCConfig()

    def save_config(self, config: OIDCConfig) -> None:
        """Insert or update the stored OIDC configuration and commit."""
        from .models import AuthSettings
        import json

        serialized = json.dumps(config.to_dict())
        setting = self._get_setting()
        if setting:
            setting.value = serialized
            setting.updated_at = datetime.now(timezone.utc)
        else:
            self.db.add(AuthSettings(key=self.OIDC_CONFIG_KEY, value=serialized))
        self.db.commit()

    def is_enabled(self) -> bool:
        """Return True if OIDC is enabled and has an issuer URL and client ID."""
        config = self.get_config()
        return config.enabled and bool(config.issuer_url) and bool(config.client_id)
|
|
|
|
|
|
def get_oidc_config_service(db: Session = Depends(get_db)) -> OIDCConfigService:
    """FastAPI dependency: an OIDCConfigService bound to the session from get_db."""
    service = OIDCConfigService(db)
    return service
|
|
|
|
|
|
# --- OIDC Authentication Flow ---
|
|
|
|
|
|
class OIDCService:
    """Authorization-code flow helper for one configured OIDC provider."""

    def __init__(self, db: Session, config: OIDCConfig):
        self.db = db
        self.config = config
        # Discovery document, fetched lazily and cached for this instance.
        self._discovery_doc: Optional[dict] = None

    def get_discovery_document(self) -> Optional[dict]:
        """Fetch (once per instance) and return the provider's discovery document.

        Returns None if no issuer is configured or the fetch fails.
        """
        if self._discovery_doc:
            return self._discovery_doc
        if not self.config.discovery_url:
            return None

        try:
            import httpx

            response = httpx.get(self.config.discovery_url, timeout=10.0)
            response.raise_for_status()
            self._discovery_doc = response.json()
            return self._discovery_doc
        except Exception as e:
            logger.error("Failed to fetch OIDC discovery document: %s", e)
            return None

    def get_authorization_url(self, redirect_uri: str, state: str) -> Optional[str]:
        """Build the provider authorization URL for the code flow, or None."""
        discovery = self.get_discovery_document()
        if not discovery:
            return None

        auth_endpoint = discovery.get("authorization_endpoint")
        if not auth_endpoint:
            logger.error("No authorization_endpoint in discovery document")
            return None

        import urllib.parse

        params = {
            "client_id": self.config.client_id,
            "response_type": "code",
            "scope": " ".join(self.config.scopes),
            "redirect_uri": redirect_uri,
            "state": state,
        }
        return f"{auth_endpoint}?{urllib.parse.urlencode(params)}"

    def exchange_code_for_tokens(
        self, code: str, redirect_uri: str
    ) -> Optional[dict]:
        """POST the authorization code to the token endpoint.

        Returns the parsed JSON token response, or None on any failure.
        """
        discovery = self.get_discovery_document()
        if not discovery:
            return None

        token_endpoint = discovery.get("token_endpoint")
        if not token_endpoint:
            logger.error("No token_endpoint in discovery document")
            return None

        try:
            import httpx

            response = httpx.post(
                token_endpoint,
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": redirect_uri,
                    "client_id": self.config.client_id,
                    "client_secret": self.config.client_secret,
                },
                timeout=10.0,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error("Failed to exchange code for tokens: %s", e)
            return None

    def validate_id_token(
        self, id_token: str, access_token: Optional[str] = None
    ) -> Optional[dict]:
        """Verify the ID token (RS256 signature, audience, issuer) and return its claims.

        Args:
            id_token: The ``id_token`` from the token endpoint response.
            access_token: The access token issued alongside it, if available.
                When given, the ``at_hash`` claim is verified against it. When
                omitted, ``at_hash`` verification is skipped; otherwise
                python-jose rejects every ID token that carries ``at_hash``.

        Returns:
            The decoded claims, or None if validation fails.
        """
        discovery = self.get_discovery_document()
        if not discovery:
            return None

        try:
            from jose import jwt, JWTError
            import httpx

            jwks_uri = discovery.get("jwks_uri")
            if not jwks_uri:
                logger.error("No jwks_uri in discovery document")
                return None

            response = httpx.get(jwks_uri, timeout=10.0)
            response.raise_for_status()
            jwks = response.json()

            kid = jwt.get_unverified_header(id_token).get("kid")
            signing_key = next(
                (k for k in jwks.get("keys", []) if k.get("kid") == kid), None
            )
            if not signing_key:
                logger.error("No matching key found in JWKS for kid: %s", kid)
                return None

            decode_options = {} if access_token else {"verify_at_hash": False}
            return jwt.decode(
                id_token,
                signing_key,
                algorithms=["RS256"],
                audience=self.config.client_id,
                issuer=self.config.issuer_url,
                access_token=access_token,
                options=decode_options,
            )
        except JWTError as e:
            logger.error("ID token validation failed: %s", e)
            return None
        except Exception as e:
            logger.error("Error validating ID token: %s", e)
            return None

    @staticmethod
    def _claim_as_list(claims: dict, name: str) -> list:
        """Return claim *name* as a list; a lone string becomes a one-item list.

        This prevents ``"admin" in "superadmins"`` substring matches when a
        provider sends a single group/role as a plain string.
        """
        value = claims.get(name)
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple, set)):
            return list(value)
        return []

    def _username_taken(self, username: str) -> bool:
        """Return True if a user with *username* already exists."""
        return (
            self.db.query(User).filter(User.username == username).first() is not None
        )

    def _unique_username(self, base: str, subject: str) -> str:
        """Return *base*, or a variant suffixed with part of *subject* (and a counter) that is free."""
        if not self._username_taken(base):
            return base
        candidate = f"{base}_{subject[:8]}"
        counter = 2
        while self._username_taken(candidate):
            candidate = f"{base}_{subject[:8]}_{counter}"
            counter += 1
        return candidate

    def get_or_create_user(self, id_token_claims: dict) -> Optional[User]:
        """Return the local user for validated ID-token claims, creating one if allowed.

        Resolution order:
        1. A user already linked to this (issuer, subject).
        2. An existing user with the same email, which gets linked. Linking is
           refused when the provider says the email is unverified, or when that
           user is already linked to a different subject at this issuer.
        3. A new user, when ``auto_create_users`` is enabled.

        Returns None when the claims are unusable or the user cannot be
        resolved or created.
        """
        subject = id_token_claims.get("sub")
        email = id_token_claims.get("email")

        if not subject:
            logger.error("No 'sub' claim in ID token")
            return None

        user = (
            self.db.query(User)
            .filter(
                User.oidc_subject == subject,
                User.oidc_issuer == self.config.issuer_url,
            )
            .first()
        )
        if user:
            user.last_login = datetime.now(timezone.utc)
            self.db.commit()
            return user

        if email:
            user = self.db.query(User).filter(User.email == email).first()
            if user:
                verified = id_token_claims.get("email_verified")
                if verified is False or (
                    isinstance(verified, str) and verified.lower() == "false"
                ):
                    logger.warning(
                        "Refusing to link OIDC subject %s to user %s: email not verified",
                        subject,
                        user.username,
                    )
                    return None
                if (
                    user.oidc_subject
                    and user.oidc_issuer == self.config.issuer_url
                    and user.oidc_subject != subject
                ):
                    logger.warning(
                        "Refusing to relink user %s to a different OIDC subject",
                        user.username,
                    )
                    return None

                user.oidc_subject = subject
                user.oidc_issuer = self.config.issuer_url
                user.last_login = datetime.now(timezone.utc)
                self.db.commit()
                logger.info("Linked OIDC identity to existing user: %s", user.username)
                return user

        if not self.config.auto_create_users:
            logger.warning("Auto-creation disabled, rejecting new OIDC user: %s", subject)
            return None

        base_username = email.split("@")[0] if email else subject
        username = self._unique_username(base_username, subject)

        is_admin = False
        if self.config.admin_group:
            groups = self._claim_as_list(id_token_claims, "groups")
            roles = self._claim_as_list(id_token_claims, "roles")
            is_admin = self.config.admin_group in groups or self.config.admin_group in roles

        user = User(
            username=username,
            email=email,
            password_hash=None,  # OIDC users don't have passwords
            oidc_subject=subject,
            oidc_issuer=self.config.issuer_url,
            is_admin=is_admin,
            must_change_password=False,
        )
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)

        logger.info("Created new OIDC user: %s (admin=%s)", username, is_admin)
        return user
|