"""Authentication service for Orchard.

Handles password hashing, session management, and API key operations.
"""

import hashlib
import secrets
from datetime import datetime, timedelta
from typing import Optional

from passlib.context import CryptContext
from sqlalchemy.orm import Session

from .models import User, Session as UserSession, APIKey

# Password hashing context (bcrypt, using passlib's default cost factor).
# deprecated="auto" marks every scheme other than the default as deprecated,
# so hashes can be upgraded if the scheme list ever grows.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Prefix carried by every generated API key; keys without it are rejected
# before any database lookup.
API_KEY_PREFIX = "orch_"

# Lifetime of a newly created session, in hours (24 hours default).
SESSION_DURATION_HOURS = 24

# NOTE(review): timestamps throughout this module are naive UTC values from
# datetime.utcnow(), which is deprecated since Python 3.12. Switching to
# datetime.now(timezone.utc) presumably requires timezone-aware DB columns
# (naive and aware datetimes cannot be compared) -- confirm before migrating.


def hash_password(password: str) -> str:
    """Hash a password using bcrypt."""
    return pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if ``plain_password`` matches ``hashed_password``."""
    return pwd_context.verify(plain_password, hashed_password)


def hash_token(token: str) -> str:
    """Hash a token (session or API key) using SHA-256.

    Only the hex digest is stored in the database; the plaintext token is
    handed to the client once and never persisted. Tokens come from
    ``secrets`` and are high-entropy, so an unsalted fast hash is enough for
    lookup.
    """
    return hashlib.sha256(token.encode()).hexdigest()


def generate_session_token() -> str:
    """Generate a cryptographically secure, URL-safe session token (32 random bytes)."""
    return secrets.token_urlsafe(32)


def generate_api_key() -> str:
    """Generate a new API key with prefix.

    Format: orch_<32 random bytes as hex> (64 hex characters after the prefix).
    """
    random_part = secrets.token_hex(32)
    return f"{API_KEY_PREFIX}{random_part}"


class AuthService:
    """Authentication service for user management and session handling.

    Every mutating method commits the supplied SQLAlchemy session immediately.
    """

    def __init__(self, db: Session):
        self.db = db

    # --- User Operations ---

    def create_user(
        self,
        username: str,
        password: Optional[str] = None,
        email: Optional[str] = None,
        is_admin: bool = False,
        must_change_password: bool = False,
    ) -> User:
        """Create, commit, and return a new user account.

        A falsy ``password`` (``None`` or ``""``) stores no password hash,
        which produces an account that cannot log in with a password (e.g.
        an OIDC-only user).
        """
        user = User(
            username=username,
            password_hash=hash_password(password) if password else None,
            email=email,
            is_admin=is_admin,
            must_change_password=must_change_password,
        )
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)
        return user

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Get a user by username, or None if no such user exists."""
        return self.db.query(User).filter(User.username == username).first()

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Get a user by ID, or None if no such user exists."""
        return self.db.query(User).filter(User.id == user_id).first()

    def authenticate_user(self, username: str, password: str) -> Optional[User]:
        """Authenticate a user with username and password.

        Returns the user if authentication succeeds, None otherwise. Failure
        happens for an unknown username, a user without a local password
        (OIDC-only), a wrong password, or an inactive account.
        """
        user = self.get_user_by_username(username)
        if not user or not user.password_hash:
            # Unknown user or OIDC-only user. Burn a comparable amount of
            # bcrypt work so response time does not reveal whether the
            # username exists.
            pwd_context.dummy_verify()
            return None
        if not verify_password(password, user.password_hash):
            return None
        # Checked after the password so inactive accounts take the same time
        # as active ones.
        if not user.is_active:
            return None
        return user

    def change_password(self, user: User, new_password: str) -> None:
        """Change a user's password and clear the must-change flag."""
        user.password_hash = hash_password(new_password)
        user.must_change_password = False
        self.db.commit()

    def update_last_login(self, user: User) -> None:
        """Update the user's last login timestamp (naive UTC)."""
        user.last_login = datetime.utcnow()
        self.db.commit()

    def list_users(self, include_inactive: bool = False) -> list[User]:
        """List users ordered by username; inactive users only if requested."""
        query = self.db.query(User)
        if not include_inactive:
            # `== True` builds a SQL comparison; `is True` would not.
            query = query.filter(User.is_active == True)  # noqa: E712
        return query.order_by(User.username).all()

    def set_user_active(self, user: User, is_active: bool) -> None:
        """Enable or disable a user account."""
        user.is_active = is_active
        self.db.commit()

    def set_user_admin(self, user: User, is_admin: bool) -> None:
        """Grant or revoke admin privileges."""
        user.is_admin = is_admin
        self.db.commit()

    def reset_user_password(self, user: User, new_password: str) -> None:
        """Reset a user's password (admin action).

        The user is forced to change it on next login.
        """
        user.password_hash = hash_password(new_password)
        user.must_change_password = True
        self.db.commit()

    # --- Session Operations ---

    def create_session(
        self,
        user: User,
        user_agent: Optional[str] = None,
        ip_address: Optional[str] = None,
    ) -> tuple[UserSession, str]:
        """Create a new session for a user.

        Returns a tuple of (session, token) where token is the plaintext token
        that should be sent to the client. Only the token's SHA-256 hash is
        stored, so the token is only returned once and should be stored
        securely. The session expires SESSION_DURATION_HOURS from now.
        """
        token = generate_session_token()
        token_hash = hash_token(token)
        session = UserSession(
            user_id=user.id,
            token_hash=token_hash,
            expires_at=datetime.utcnow() + timedelta(hours=SESSION_DURATION_HOURS),
            user_agent=user_agent,
            ip_address=ip_address,
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session, token

    def get_session_by_token(self, token: str) -> Optional[UserSession]:
        """Get a session by its plaintext token.

        Returns None if the session doesn't exist or has expired. An expired
        session is deleted on lookup. A valid session has its
        ``last_accessed`` timestamp refreshed.
        """
        token_hash = hash_token(token)
        session = (
            self.db.query(UserSession)
            .filter(UserSession.token_hash == token_hash)
            .first()
        )
        if not session:
            return None

        if session.expires_at < datetime.utcnow():
            # Session has expired; remove it lazily rather than waiting for
            # cleanup_expired_sessions().
            self.db.delete(session)
            self.db.commit()
            return None

        # Update last accessed time
        session.last_accessed = datetime.utcnow()
        self.db.commit()
        return session

    def delete_session(self, session: UserSession) -> None:
        """Delete a session (logout)."""
        self.db.delete(session)
        self.db.commit()

    def delete_user_sessions(self, user: User) -> int:
        """Delete all sessions for a user. Returns count of deleted sessions."""
        count = (
            self.db.query(UserSession).filter(UserSession.user_id == user.id).delete()
        )
        self.db.commit()
        return count

    def cleanup_expired_sessions(self) -> int:
        """Delete all expired sessions. Returns count of deleted sessions."""
        count = (
            self.db.query(UserSession)
            .filter(UserSession.expires_at < datetime.utcnow())
            .delete()
        )
        self.db.commit()
        return count

    # --- API Key Operations ---

    def create_api_key(
        self,
        user: User,
        name: str,
        description: Optional[str] = None,
        scopes: Optional[list[str]] = None,
        expires_at: Optional[datetime] = None,
    ) -> tuple[APIKey, str]:
        """Create a new API key for a user.

        Returns a tuple of (api_key, key) where key is the plaintext API key
        that should be sent to the client. Only its SHA-256 hash is stored, so
        the key is only returned once and should be stored securely by the
        user. ``scopes`` defaults to ["read", "write"].
        """
        key = generate_api_key()
        key_hash = hash_token(key)

        api_key = APIKey(
            key_hash=key_hash,
            name=name,
            user_id=user.username,  # Legacy field
            owner_id=user.id,
            description=description,
            # NOTE(review): an explicitly empty list also falls back to the
            # default scopes here -- confirm that is intended.
            scopes=scopes or ["read", "write"],
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)
        return api_key, key

    def get_api_key_by_key(self, key: str) -> Optional[APIKey]:
        """Get an API key by its plaintext key.

        Returns None if the key lacks the API_KEY_PREFIX, doesn't exist, or
        has expired. A valid key has its ``last_used`` timestamp refreshed.
        """
        # Cheap rejection of anything that cannot be one of our keys.
        if not key.startswith(API_KEY_PREFIX):
            return None

        key_hash = hash_token(key)
        api_key = self.db.query(APIKey).filter(APIKey.key_hash == key_hash).first()
        if not api_key:
            return None

        # Check expiration (keys without expires_at never expire)
        if api_key.expires_at and api_key.expires_at < datetime.utcnow():
            return None

        # Update last used time
        api_key.last_used = datetime.utcnow()
        self.db.commit()
        return api_key

    def get_api_key_by_id(self, key_id: str) -> Optional[APIKey]:
        """Get an API key by its ID."""
        return self.db.query(APIKey).filter(APIKey.id == key_id).first()

    def list_user_api_keys(self, user: User) -> list[APIKey]:
        """List all API keys owned by a user, newest first."""
        return (
            self.db.query(APIKey)
            .filter(APIKey.owner_id == user.id)
            .order_by(APIKey.created_at.desc())
            .all()
        )

    def delete_api_key(self, api_key: APIKey) -> None:
        """Delete an API key."""
        self.db.delete(api_key)
        self.db.commit()

    def get_user_from_api_key(self, key: str) -> Optional[User]:
        """Get the user associated with an API key.

        Returns None if the key is invalid or expired, has no owner, or the
        owner does not exist or is inactive.
        """
        api_key = self.get_api_key_by_key(key)
        if not api_key:
            return None
        # Keys created before owner_id existed may only carry the legacy
        # user_id field; those are not usable for authentication.
        if not api_key.owner_id:
            return None

        user = self.get_user_by_id(api_key.owner_id)
        if not user or not user.is_active:
            return None
        return user


def create_default_admin(db: Session) -> Optional[User]:
    """Create the default admin user if no users exist.

    Returns the created user, or None if users already exist. The account is
    created with username "admin" / password "admin" and is flagged so the
    password must be changed on first login.
    """
    # Check if any users exist
    user_count = db.query(User).count()
    if user_count > 0:
        return None

    # Create default admin with well-known credentials; must_change_password
    # forces them to be replaced.
    auth_service = AuthService(db)
    admin = auth_service.create_user(
        username="admin",
        password="admin",
        is_admin=True,
        must_change_password=True,
    )
    return admin


# --- FastAPI Dependencies ---
# NOTE(review): these imports live mid-module rather than at the top; if that
# is to avoid a circular import with .database, keep them here.

from fastapi import Depends, HTTPException, status, Cookie, Header

from .database import get_db

# Cookie name for session token
SESSION_COOKIE_NAME = "orchard_session"


def get_current_user_optional(
    db: Session = Depends(get_db),
    session_token: Optional[str] = Cookie(None, alias=SESSION_COOKIE_NAME),
    authorization: Optional[str] = Header(None),
) -> Optional[User]:
    """Get the current user from session cookie or API key.

    The session cookie is tried first; if it is missing or invalid, an
    ``Authorization: Bearer <api key>`` header is tried. Returns None if no
    valid authentication is provided. Does not raise an exception for
    unauthenticated requests.
    """
    auth_service = AuthService(db)

    # First try session cookie (web UI)
    if session_token:
        session = auth_service.get_session_by_token(session_token)
        if session:
            user = auth_service.get_user_by_id(str(session.user_id))
            if user and user.is_active:
                return user

    # Then try API key (CLI/programmatic access)
    if authorization:
        scheme, _, credentials = authorization.partition(" ")
        # The authentication scheme name is case-insensitive (RFC 7235).
        if scheme.lower() == "bearer":
            api_key = credentials.strip()
            user = auth_service.get_user_from_api_key(api_key)
            if user:
                return user

    return None


def get_current_user(
    user: Optional[User] = Depends(get_current_user_optional),
) -> User:
    """Get the current authenticated user.

    Raises HTTPException 401 if not authenticated.
    """
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user


def require_admin(
    user: User = Depends(get_current_user),
) -> User:
    """Require the current user to be an admin.

    Raises HTTPException 403 if user is not an admin (and 401, via
    get_current_user, if not authenticated at all).
    """
    if not user.is_admin:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin privileges required",
        )
    return user


def get_auth_service(db: Session = Depends(get_db)) -> AuthService:
    """FastAPI dependency returning an AuthService bound to the request's DB session."""
    return AuthService(db)