Add frontend access control enhancements and JWT support
- Hide New Project button for unauthenticated users, show login link - Add lock icon for private projects on home page - Show access level badges on project cards (Owner, Admin, Write, Read) - Add permission expiration date field to AccessManagement component - Add query timeout configuration for database (ORCHARD_DATABASE_QUERY_TIMEOUT) - Add JWT token validation support for external identity providers - Configurable via ORCHARD_JWT_* environment variables - Supports HS256 with secret or RS256 with JWKS - Auto-provisions users from JWT claims
This commit is contained in:
@@ -1,16 +1,20 @@
|
||||
"""Authentication service for Orchard.
|
||||
|
||||
Handles password hashing, session management, and API key operations.
|
||||
Handles password hashing, session management, API key operations, and JWT validation.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .models import User, Session as UserSession, APIKey
|
||||
from .config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Password hashing context (bcrypt with cost factor 12)
|
||||
@@ -374,6 +378,147 @@ def create_default_admin(db: Session) -> Optional[User]:
|
||||
return admin
|
||||
|
||||
|
||||
# --- JWT Validation ---
|
||||
|
||||
|
||||
def validate_jwt_token(token: str) -> Optional[dict]:
|
||||
"""Validate a JWT token and return the decoded payload.
|
||||
|
||||
Returns None if validation fails or JWT is not configured.
|
||||
Uses python-jose for JWT operations.
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
if not settings.jwt_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
from jose import jwt, JWTError, ExpiredSignatureError
|
||||
from jose.exceptions import JWTClaimsError
|
||||
except ImportError:
|
||||
logger.warning("python-jose not installed, JWT authentication disabled")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Build decode options
|
||||
decode_options = {}
|
||||
|
||||
# Set up key for validation
|
||||
if settings.jwt_algorithm.startswith("RS"):
|
||||
# RS256/RS384/RS512 - use JWKS
|
||||
if not settings.jwt_jwks_url:
|
||||
logger.error("JWT JWKS URL not configured for RSA algorithm")
|
||||
return None
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
# Fetch JWKS from the URL
|
||||
response = httpx.get(settings.jwt_jwks_url, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
jwks = response.json()
|
||||
|
||||
# Get the key ID from the token header
|
||||
unverified_header = jwt.get_unverified_header(token)
|
||||
kid = unverified_header.get("kid")
|
||||
|
||||
# Find the matching key
|
||||
rsa_key = None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
rsa_key = key
|
||||
break
|
||||
|
||||
if not rsa_key:
|
||||
logger.error(f"No matching key found in JWKS for kid: {kid}")
|
||||
return None
|
||||
|
||||
key = rsa_key
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get signing key from JWKS: {e}")
|
||||
return None
|
||||
else:
|
||||
# HS256/HS384/HS512 - use secret
|
||||
if not settings.jwt_secret:
|
||||
logger.error("JWT secret not configured for HMAC algorithm")
|
||||
return None
|
||||
key = settings.jwt_secret
|
||||
|
||||
# Build decode kwargs
|
||||
decode_kwargs = {
|
||||
"algorithms": [settings.jwt_algorithm],
|
||||
"options": decode_options,
|
||||
}
|
||||
|
||||
# Add issuer validation if configured
|
||||
if settings.jwt_issuer:
|
||||
decode_kwargs["issuer"] = settings.jwt_issuer
|
||||
|
||||
# Add audience validation if configured
|
||||
if settings.jwt_audience:
|
||||
decode_kwargs["audience"] = settings.jwt_audience
|
||||
|
||||
# Decode and validate the token
|
||||
payload = jwt.decode(token, key, **decode_kwargs)
|
||||
return payload
|
||||
|
||||
except ExpiredSignatureError:
|
||||
logger.debug("JWT token expired")
|
||||
return None
|
||||
except JWTClaimsError as e:
|
||||
logger.debug(f"JWT claims error: {e}")
|
||||
return None
|
||||
except JWTError as e:
|
||||
logger.debug(f"Invalid JWT token: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"JWT validation error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_or_create_user_from_jwt(db: Session, payload: dict) -> Optional[User]:
|
||||
"""Get or create a user from JWT payload.
|
||||
|
||||
Uses the configured username claim to extract the username.
|
||||
Creates a new user if one doesn't exist (for SSO auto-provisioning).
|
||||
"""
|
||||
settings = get_settings()
|
||||
username = payload.get(settings.jwt_username_claim)
|
||||
|
||||
if not username:
|
||||
logger.warning(f"JWT missing username claim: {settings.jwt_username_claim}")
|
||||
return None
|
||||
|
||||
# Sanitize username (remove domain from email if needed)
|
||||
if "@" in username and settings.jwt_username_claim == "email":
|
||||
# Keep full email as username for email-based auth
|
||||
pass
|
||||
|
||||
auth_service = AuthService(db)
|
||||
user = auth_service.get_user_by_username(username)
|
||||
|
||||
if user:
|
||||
if not user.is_active:
|
||||
logger.debug(f"JWT user {username} is inactive")
|
||||
return None
|
||||
return user
|
||||
|
||||
# Auto-provision user from JWT
|
||||
logger.info(f"Auto-provisioning user from JWT: {username}")
|
||||
try:
|
||||
user = auth_service.create_user(
|
||||
username=username,
|
||||
password=None, # No password for SSO users
|
||||
email=payload.get("email"),
|
||||
is_admin=False,
|
||||
must_change_password=False,
|
||||
)
|
||||
return user
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to auto-provision JWT user: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# --- FastAPI Dependencies ---
|
||||
|
||||
from fastapi import Depends, HTTPException, status, Cookie, Header
|
||||
@@ -388,10 +533,15 @@ def get_current_user_optional(
|
||||
session_token: Optional[str] = Cookie(None, alias=SESSION_COOKIE_NAME),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Optional[User]:
|
||||
"""Get the current user from session cookie or API key.
|
||||
"""Get the current user from session cookie, API key, or JWT token.
|
||||
|
||||
Returns None if no valid authentication is provided.
|
||||
Does not raise an exception for unauthenticated requests.
|
||||
|
||||
Authentication methods are tried in order:
|
||||
1. Session cookie (web UI)
|
||||
2. API key (Bearer token starting with 'orch_')
|
||||
3. JWT token (Bearer token that's a valid JWT)
|
||||
"""
|
||||
auth_service = AuthService(db)
|
||||
|
||||
@@ -403,13 +553,24 @@ def get_current_user_optional(
|
||||
if user and user.is_active:
|
||||
return user
|
||||
|
||||
# Then try API key (CLI/programmatic access)
|
||||
if authorization:
|
||||
if authorization.startswith("Bearer "):
|
||||
api_key = authorization[7:] # Remove "Bearer " prefix
|
||||
user = auth_service.get_user_from_api_key(api_key)
|
||||
# Then try Bearer token (API key or JWT)
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
token = authorization[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Check if it's an API key (starts with orch_)
|
||||
if token.startswith(API_KEY_PREFIX):
|
||||
user = auth_service.get_user_from_api_key(token)
|
||||
if user:
|
||||
return user
|
||||
else:
|
||||
# Try JWT validation
|
||||
settings = get_settings()
|
||||
if settings.jwt_enabled:
|
||||
payload = validate_jwt_token(token)
|
||||
if payload:
|
||||
user = get_or_create_user_from_jwt(db, payload)
|
||||
if user:
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class Settings(BaseSettings):
|
||||
database_pool_recycle: int = (
|
||||
1800 # Recycle connections after this many seconds (30 min)
|
||||
)
|
||||
database_query_timeout: int = 30 # Query timeout in seconds (0 = no timeout)
|
||||
|
||||
# S3
|
||||
s3_endpoint: str = ""
|
||||
@@ -52,6 +53,17 @@ class Settings(BaseSettings):
|
||||
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
log_format: str = "auto" # "json", "standard", or "auto" (json in production)
|
||||
|
||||
# JWT Authentication settings (optional, for external identity providers)
|
||||
jwt_enabled: bool = False # Enable JWT token validation
|
||||
jwt_secret: str = "" # Secret key for HS256, or leave empty for RS256 with JWKS
|
||||
jwt_algorithm: str = "HS256" # HS256 or RS256
|
||||
jwt_issuer: str = "" # Expected issuer (iss claim), leave empty to skip validation
|
||||
jwt_audience: str = "" # Expected audience (aud claim), leave empty to skip validation
|
||||
jwt_jwks_url: str = "" # JWKS URL for RS256 (e.g., https://auth.example.com/.well-known/jwks.json)
|
||||
jwt_username_claim: str = (
|
||||
"sub" # JWT claim to use as username (sub, email, preferred_username, etc.)
|
||||
)
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
sslmode = f"?sslmode={self.database_sslmode}" if self.database_sslmode else ""
|
||||
|
||||
@@ -12,6 +12,12 @@ from .models import Base
|
||||
settings = get_settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Build connect_args with query timeout if configured
|
||||
connect_args = {}
|
||||
if settings.database_query_timeout > 0:
|
||||
# PostgreSQL statement_timeout is in milliseconds
|
||||
connect_args["options"] = f"-c statement_timeout={settings.database_query_timeout * 1000}"
|
||||
|
||||
# Create engine with connection pool configuration
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
@@ -21,6 +27,7 @@ engine = create_engine(
|
||||
max_overflow=settings.database_max_overflow,
|
||||
pool_timeout=settings.database_pool_timeout,
|
||||
pool_recycle=settings.database_pool_recycle,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
@@ -45,11 +45,13 @@ from .models import (
|
||||
Consumer,
|
||||
AuditLog,
|
||||
User,
|
||||
AccessPermission,
|
||||
)
|
||||
from .schemas import (
|
||||
ProjectCreate,
|
||||
ProjectUpdate,
|
||||
ProjectResponse,
|
||||
ProjectWithAccessResponse,
|
||||
PackageCreate,
|
||||
PackageUpdate,
|
||||
PackageResponse,
|
||||
@@ -947,7 +949,7 @@ def global_search(
|
||||
|
||||
|
||||
# Project routes
|
||||
@router.get("/api/v1/projects", response_model=PaginatedResponse[ProjectResponse])
|
||||
@router.get("/api/v1/projects", response_model=PaginatedResponse[ProjectWithAccessResponse])
|
||||
def list_projects(
|
||||
request: Request,
|
||||
page: int = Query(default=1, ge=1, description="Page number"),
|
||||
@@ -963,8 +965,9 @@ def list_projects(
|
||||
),
|
||||
order: str = Query(default="asc", description="Sort order (asc, desc)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
user_id = get_user_id(request)
|
||||
user_id = current_user.username if current_user else get_user_id(request)
|
||||
|
||||
# Validate sort field
|
||||
valid_sort_fields = {
|
||||
@@ -1022,8 +1025,51 @@ def list_projects(
|
||||
# Calculate total pages
|
||||
total_pages = math.ceil(total / limit) if total > 0 else 1
|
||||
|
||||
# Build access level info for each project
|
||||
project_ids = [p.id for p in projects]
|
||||
access_map = {}
|
||||
|
||||
if current_user and project_ids:
|
||||
# Get access permissions for this user across these projects
|
||||
permissions = (
|
||||
db.query(AccessPermission)
|
||||
.filter(
|
||||
AccessPermission.project_id.in_(project_ids),
|
||||
AccessPermission.user_id == current_user.username,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
access_map = {p.project_id: p.level for p in permissions}
|
||||
|
||||
# Build response with access levels
|
||||
items = []
|
||||
for p in projects:
|
||||
is_owner = p.created_by == user_id
|
||||
access_level = None
|
||||
|
||||
if is_owner:
|
||||
access_level = "admin"
|
||||
elif p.id in access_map:
|
||||
access_level = access_map[p.id]
|
||||
elif p.is_public:
|
||||
access_level = "read"
|
||||
|
||||
items.append(
|
||||
ProjectWithAccessResponse(
|
||||
id=p.id,
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
is_public=p.is_public,
|
||||
created_at=p.created_at,
|
||||
updated_at=p.updated_at,
|
||||
created_by=p.created_by,
|
||||
access_level=access_level,
|
||||
is_owner=is_owner,
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
items=projects,
|
||||
items=items,
|
||||
pagination=PaginationMeta(
|
||||
page=page,
|
||||
limit=limit,
|
||||
|
||||
@@ -47,6 +47,13 @@ class ProjectUpdate(BaseModel):
|
||||
is_public: Optional[bool] = None
|
||||
|
||||
|
||||
class ProjectWithAccessResponse(ProjectResponse):
|
||||
"""Project response with user's access level included"""
|
||||
|
||||
access_level: Optional[str] = None # 'read', 'write', 'admin', or None
|
||||
is_owner: bool = False
|
||||
|
||||
|
||||
# Package format and platform enums
|
||||
PACKAGE_FORMATS = [
|
||||
"generic",
|
||||
|
||||
Reference in New Issue
Block a user