- Use timezone-aware datetimes (datetime.now(timezone.utc)) for session expiry comparison - Add explicit bcrypt==4.0.1 dependency for passlib bcrypt backend
413 lines
12 KiB
Python
413 lines
12 KiB
Python
"""Authentication service for Orchard.
|
|
|
|
Handles password hashing, session management, and API key operations.
|
|
"""
|
|
|
|
import hashlib
|
|
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
from passlib.context import CryptContext
|
|
from sqlalchemy.orm import Session
|
|
|
|
from .models import User, Session as UserSession, APIKey
|
|
|
|
|
|
# Password hashing context: bcrypt with an explicit cost factor of 12, so the
# work factor is pinned here instead of silently following passlib's built-in
# default for whatever passlib version is installed.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
    bcrypt__default_rounds=12,
)

# Prefix prepended to every generated API key (see generate_api_key); keys
# presented without it are rejected before any database lookup.
API_KEY_PREFIX = "orch_"

# Lifetime, in hours, of a login session created by AuthService.create_session.
SESSION_DURATION_HOURS = 24
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Return a bcrypt hash of *password* using the module-level ``pwd_context``.

    Args:
        password: Plaintext password to hash.

    Returns:
        The encoded hash string suitable for storing in ``User.password_hash``.
    """
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: Password as typed by the user.
        hashed_password: Hash previously produced by ``hash_password``.

    Returns:
        ``True`` if the password matches the hash, ``False`` otherwise.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
|
|
def hash_token(token: str) -> str:
    """Return the hex-encoded SHA-256 digest of *token*'s UTF-8 bytes.

    Shared by session tokens and API keys: only these digests are persisted,
    never the plaintext values.
    """
    digest = hashlib.sha256(token.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
|
|
def generate_session_token() -> str:
    """Return a new URL-safe session token derived from 32 CSPRNG bytes.

    The plaintext is returned to the client once; only its ``hash_token``
    digest is stored.
    """
    entropy_bytes = 32
    return secrets.token_urlsafe(entropy_bytes)
|
|
|
|
|
|
def generate_api_key() -> str:
    """Return a freshly generated API key carrying ``API_KEY_PREFIX``.

    Format: orch_<32 random bytes as hex>, i.e. the prefix followed by
    64 lowercase hex characters.
    """
    return API_KEY_PREFIX + secrets.token_hex(32)
|
|
|
|
|
|
class AuthService:
    """Authentication service for user management and session handling.

    Wraps a SQLAlchemy session (``self.db``) and exposes user CRUD, password
    authentication, login-session lifecycle and API key management. Every
    mutating method commits the session before returning.
    """

    def __init__(self, db: Session):
        self.db = db

    # --- Internal helpers ---

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as a timezone-aware datetime in UTC.

        Naive values are interpreted as UTC, which matches how this class
        writes timestamps (always derived from ``datetime.now(timezone.utc)``).
        Normalizing before comparison prevents ``TypeError: can't compare
        offset-naive and offset-aware datetimes`` if a database column or a
        caller ever supplies a naive datetime.
        NOTE(review): whether the columns are tz-aware is not visible from
        this file.
        """
        if value.tzinfo is None or value.utcoffset() is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)

    # --- User Operations ---

    def create_user(
        self,
        username: str,
        password: Optional[str] = None,
        email: Optional[str] = None,
        is_admin: bool = False,
        must_change_password: bool = False,
    ) -> User:
        """Create, commit and return a new user account.

        Args:
            username: Login name for the account.
            password: Plaintext password. A falsy value (``None`` or ``""``)
                stores no hash, so the account cannot log in with a password
                (``authenticate_user`` treats it as OIDC-only).
            email: Optional e-mail address.
            is_admin: Whether to grant admin privileges.
            must_change_password: Whether the user is flagged to change the
                password (cleared again by ``change_password``).
        """
        user = User(
            username=username,
            password_hash=hash_password(password) if password else None,
            email=email,
            is_admin=is_admin,
            must_change_password=must_change_password,
        )
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)
        return user

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the user with this username, or ``None``."""
        return self.db.query(User).filter(User.username == username).first()

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Return the user with this primary key, or ``None``."""
        return self.db.query(User).filter(User.id == user_id).first()

    def authenticate_user(self, username: str, password: str) -> Optional[User]:
        """Authenticate a user with username and password.

        Returns the user if authentication succeeds, None otherwise. Unknown,
        password-less (OIDC-only) and inactive accounts are rejected after a
        dummy hash verification so that response time does not reveal which
        usernames exist.
        """
        user = self.get_user_by_username(username)
        if not user or not user.password_hash or not user.is_active:
            # Spend roughly the time a real bcrypt check would take.
            pwd_context.dummy_verify()
            return None
        if not verify_password(password, user.password_hash):
            return None
        return user

    def change_password(self, user: User, new_password: str) -> None:
        """Set a new password chosen by the user and clear the change flag."""
        user.password_hash = hash_password(new_password)
        user.must_change_password = False
        self.db.commit()

    def update_last_login(self, user: User) -> None:
        """Stamp ``user.last_login`` with the current UTC time."""
        user.last_login = datetime.now(timezone.utc)
        self.db.commit()

    def list_users(self, include_inactive: bool = False) -> list[User]:
        """List users ordered by username, optionally including inactive ones."""
        query = self.db.query(User)
        if not include_inactive:
            # `== True` (not `is True`) is required: it builds a SQL expression.
            query = query.filter(User.is_active == True)  # noqa: E712
        return query.order_by(User.username).all()

    def set_user_active(self, user: User, is_active: bool) -> None:
        """Enable or disable a user account."""
        user.is_active = is_active
        self.db.commit()

    def set_user_admin(self, user: User, is_admin: bool) -> None:
        """Grant or revoke admin privileges."""
        user.is_admin = is_admin
        self.db.commit()

    def reset_user_password(self, user: User, new_password: str) -> None:
        """Reset a user's password (admin action) and require a change."""
        user.password_hash = hash_password(new_password)
        user.must_change_password = True
        self.db.commit()

    # --- Session Operations ---

    def create_session(
        self,
        user: User,
        user_agent: Optional[str] = None,
        ip_address: Optional[str] = None,
    ) -> tuple[UserSession, str]:
        """Create a new session for a user.

        Returns a tuple of (session, token) where token is the plaintext
        token that should be sent to the client. The token is only returned
        once; the database keeps just its SHA-256 digest. The session expires
        ``SESSION_DURATION_HOURS`` after creation.
        """
        token = generate_session_token()
        token_hash = hash_token(token)

        session = UserSession(
            user_id=user.id,
            token_hash=token_hash,
            expires_at=datetime.now(timezone.utc) + timedelta(hours=SESSION_DURATION_HOURS),
            user_agent=user_agent,
            ip_address=ip_address,
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)

        return session, token

    def get_session_by_token(self, token: str) -> Optional[UserSession]:
        """Get a session by its plaintext token.

        Returns None if the session doesn't exist or has expired; expired
        sessions are deleted on sight. A valid hit updates ``last_accessed``
        and commits.
        """
        token_hash = hash_token(token)
        session = (
            self.db.query(UserSession)
            .filter(UserSession.token_hash == token_hash)
            .first()
        )
        if not session:
            return None

        now = datetime.now(timezone.utc)
        if self._as_utc(session.expires_at) < now:
            # Session has expired, delete it
            self.db.delete(session)
            self.db.commit()
            return None

        session.last_accessed = now
        self.db.commit()
        return session

    def delete_session(self, session: UserSession) -> None:
        """Delete a session (logout)."""
        self.db.delete(session)
        self.db.commit()

    def delete_user_sessions(self, user: User) -> int:
        """Delete all sessions for a user. Returns count of deleted sessions."""
        count = (
            self.db.query(UserSession).filter(UserSession.user_id == user.id).delete()
        )
        self.db.commit()
        return count

    def cleanup_expired_sessions(self) -> int:
        """Delete all expired sessions. Returns count of deleted sessions."""
        count = (
            self.db.query(UserSession)
            .filter(UserSession.expires_at < datetime.now(timezone.utc))
            .delete()
        )
        self.db.commit()
        return count

    # --- API Key Operations ---

    def create_api_key(
        self,
        user: User,
        name: str,
        description: Optional[str] = None,
        scopes: Optional[list[str]] = None,
        expires_at: Optional[datetime] = None,
    ) -> tuple[APIKey, str]:
        """Create a new API key for a user.

        Returns a tuple of (api_key, key) where key is the plaintext
        API key that should be sent to the client. The key is only returned
        once and should be stored securely by the user.

        Args:
            scopes: Defaults to ``["read", "write"]`` when omitted or empty.
            expires_at: ``None`` means the key never expires. A naive value is
                treated as UTC when checked by ``get_api_key_by_key``.
        """
        key = generate_api_key()
        key_hash = hash_token(key)

        api_key = APIKey(
            key_hash=key_hash,
            name=name,
            user_id=user.username,  # Legacy field
            owner_id=user.id,
            description=description,
            scopes=scopes or ["read", "write"],
            expires_at=expires_at,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)

        return api_key, key

    def get_api_key_by_key(self, key: str) -> Optional[APIKey]:
        """Get an API key by its plaintext key.

        Keys without ``API_KEY_PREFIX`` are rejected without a database query.
        Returns None if the key doesn't exist or has expired (expired keys are
        not deleted). A valid hit updates ``last_used`` and commits.
        """
        if not key.startswith(API_KEY_PREFIX):
            return None

        key_hash = hash_token(key)
        api_key = self.db.query(APIKey).filter(APIKey.key_hash == key_hash).first()
        if not api_key:
            return None

        now = datetime.now(timezone.utc)
        # A NULL expires_at means the key never expires.
        if api_key.expires_at is not None and self._as_utc(api_key.expires_at) < now:
            return None

        api_key.last_used = now
        self.db.commit()
        return api_key

    def get_api_key_by_id(self, key_id: str) -> Optional[APIKey]:
        """Return the API key with this primary key, or ``None``."""
        return self.db.query(APIKey).filter(APIKey.id == key_id).first()

    def list_user_api_keys(self, user: User) -> list[APIKey]:
        """List all API keys owned by a user, newest first."""
        return (
            self.db.query(APIKey)
            .filter(APIKey.owner_id == user.id)
            .order_by(APIKey.created_at.desc())
            .all()
        )

    def delete_api_key(self, api_key: APIKey) -> None:
        """Delete an API key."""
        self.db.delete(api_key)
        self.db.commit()

    def get_user_from_api_key(self, key: str) -> Optional[User]:
        """Get the user associated with an API key.

        Returns None if the key is invalid/expired, has no owner (legacy keys
        may only carry ``user_id``), or the owner is inactive.
        """
        api_key = self.get_api_key_by_key(key)
        if not api_key or not api_key.owner_id:
            return None

        user = self.get_user_by_id(api_key.owner_id)
        if not user or not user.is_active:
            return None
        return user
|
|
|
|
|
|
def create_default_admin(db: Session) -> Optional[User]:
    """Bootstrap an ``admin``/``admin`` account when the user table is empty.

    The account is created as an admin and flagged with
    ``must_change_password=True`` (the flag ``AuthService.change_password``
    clears), so the well-known default password is marked for replacement.

    NOTE(review): the emptiness check and the insert are separate statements;
    concurrent callers could both see zero users. Confirm startup is serialized.

    Returns:
        The newly created admin ``User``, or ``None`` if any user already exists.
    """
    if db.query(User).count():
        return None

    return AuthService(db).create_user(
        username="admin",
        password="admin",
        is_admin=True,
        must_change_password=True,
    )
|
|
|
|
|
|
# --- FastAPI Dependencies ---
|
|
|
|
from fastapi import Depends, HTTPException, status, Cookie, Header
|
|
from .database import get_db
|
|
|
|
# Name of the cookie carrying the plaintext session token returned by
# AuthService.create_session; read by get_current_user_optional.
SESSION_COOKIE_NAME: str = "orchard_session"
|
|
|
|
|
|
def get_current_user_optional(
    db: Session = Depends(get_db),
    session_token: Optional[str] = Cookie(None, alias=SESSION_COOKIE_NAME),
    authorization: Optional[str] = Header(None),
) -> Optional[User]:
    """Get the current user from session cookie or API key.

    The session cookie (web UI) is tried first. If it is missing, unknown,
    expired, or belongs to an inactive user, the ``Authorization`` header is
    tried next (CLI/programmatic access) as ``Bearer <api key>``. The scheme
    name is matched case-insensitively, as HTTP auth schemes are
    case-insensitive (RFC 7235 §2.1). Whitespace around the key is ignored.

    Returns:
        The active ``User``, or ``None`` if no valid authentication is
        provided. Does not raise an exception for unauthenticated requests.
    """
    auth_service = AuthService(db)

    # First try session cookie (web UI)
    if session_token:
        session = auth_service.get_session_by_token(session_token)
        if session:
            user = auth_service.get_user_by_id(str(session.user_id))
            if user and user.is_active:
                return user

    # Then try API key (CLI/programmatic access)
    if authorization:
        scheme, _, credentials = authorization.strip().partition(" ")
        api_key = credentials.strip()
        if scheme.lower() == "bearer" and api_key:
            user = auth_service.get_user_from_api_key(api_key)
            if user:
                return user

    return None
|
|
|
|
|
|
def get_current_user(
    user: Optional[User] = Depends(get_current_user_optional),
) -> User:
    """FastAPI dependency that requires an authenticated user.

    Delegates credential resolution to ``get_current_user_optional``.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when neither a
            valid session cookie nor a valid API key was supplied.
    """
    if user:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
def require_admin(
    user: User = Depends(get_current_user),
) -> User:
    """FastAPI dependency that admits only admin users.

    Raises:
        HTTPException: 401 (from ``get_current_user``) when unauthenticated,
            403 when the authenticated user's ``is_admin`` is falsy.
    """
    if user.is_admin:
        return user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin privileges required",
    )
|
|
|
|
|
|
def get_auth_service(db: Session = Depends(get_db)) -> AuthService:
    """FastAPI dependency returning an ``AuthService`` bound to the session from ``get_db``."""
    service = AuthService(db)
    return service
|