Add project-level authorization checks
Authorization: - Add AuthorizationService for checking project access - Implement get_user_access_level() with admin, owner, and permission checks - Add check_project_access() helper for route handlers - Add grant_access() and revoke_access() methods - Add ProjectAccessChecker dependency class Routes: - Add authorization checks to project CRUD (read, update, delete) - Add authorization checks to package create - Add authorization checks to upload endpoint (requires write) - Add authorization checks to download endpoint (requires read) - Add authorization checks to tag create Tests: - Fix pagination flakiness in test_list_projects - Fix pagination flakiness in test_projects_search - Add API key authentication to concurrent upload test
This commit is contained in:
@@ -448,3 +448,261 @@ def require_admin(
|
||||
def get_auth_service(db: Session = Depends(get_db)) -> AuthService:
|
||||
"""Get an AuthService instance."""
|
||||
return AuthService(db)
|
||||
|
||||
|
||||
# --- Authorization ---
|
||||
|
||||
# Access levels in order of increasing privilege
|
||||
ACCESS_LEVELS = ["read", "write", "admin"]
|
||||
|
||||
|
||||
def get_access_level_rank(level: str) -> int:
|
||||
"""Get numeric rank for access level comparison."""
|
||||
try:
|
||||
return ACCESS_LEVELS.index(level)
|
||||
except ValueError:
|
||||
return -1
|
||||
|
||||
|
||||
def has_sufficient_access(user_level: str, required_level: str) -> bool:
|
||||
"""Check if user_level is sufficient for required_level.
|
||||
|
||||
Access levels are hierarchical: admin > write > read
|
||||
"""
|
||||
return get_access_level_rank(user_level) >= get_access_level_rank(required_level)
|
||||
|
||||
|
||||
class AuthorizationService:
|
||||
"""Service for checking project-level authorization."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_user_access_level(
|
||||
self, project_id: str, user: Optional[User]
|
||||
) -> Optional[str]:
|
||||
"""Get the user's access level for a project.
|
||||
|
||||
Returns the highest access level the user has, or None if no access.
|
||||
Checks in order:
|
||||
1. System admin - gets admin access to all projects
|
||||
2. Project owner (created_by) - gets admin access
|
||||
3. Explicit permission in access_permissions table
|
||||
"""
|
||||
from .models import Project, AccessPermission
|
||||
|
||||
# Get the project
|
||||
project = self.db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
return None
|
||||
|
||||
# Anonymous users only get access to public projects
|
||||
if not user:
|
||||
return "read" if project.is_public else None
|
||||
|
||||
# System admins get admin access everywhere
|
||||
if user.is_admin:
|
||||
return "admin"
|
||||
|
||||
# Project owner gets admin access
|
||||
if project.created_by == user.username:
|
||||
return "admin"
|
||||
|
||||
# Check explicit permissions
|
||||
permission = (
|
||||
self.db.query(AccessPermission)
|
||||
.filter(
|
||||
AccessPermission.project_id == project_id,
|
||||
AccessPermission.user_id == user.username,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if permission:
|
||||
# Check expiration
|
||||
if permission.expires_at and permission.expires_at < datetime.now(timezone.utc):
|
||||
return "read" if project.is_public else None
|
||||
return permission.level
|
||||
|
||||
# Fall back to public access
|
||||
return "read" if project.is_public else None
|
||||
|
||||
def check_access(
|
||||
self,
|
||||
project_id: str,
|
||||
user: Optional[User],
|
||||
required_level: str,
|
||||
) -> bool:
|
||||
"""Check if user has required access level for project."""
|
||||
user_level = self.get_user_access_level(project_id, user)
|
||||
if not user_level:
|
||||
return False
|
||||
return has_sufficient_access(user_level, required_level)
|
||||
|
||||
def grant_access(
|
||||
self,
|
||||
project_id: str,
|
||||
username: str,
|
||||
level: str,
|
||||
expires_at: Optional[datetime] = None,
|
||||
) -> "AccessPermission":
|
||||
"""Grant access to a user for a project."""
|
||||
from .models import AccessPermission
|
||||
|
||||
# Check if permission already exists
|
||||
existing = (
|
||||
self.db.query(AccessPermission)
|
||||
.filter(
|
||||
AccessPermission.project_id == project_id,
|
||||
AccessPermission.user_id == username,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
existing.level = level
|
||||
existing.expires_at = expires_at
|
||||
self.db.commit()
|
||||
return existing
|
||||
|
||||
permission = AccessPermission(
|
||||
project_id=project_id,
|
||||
user_id=username,
|
||||
level=level,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
self.db.add(permission)
|
||||
self.db.commit()
|
||||
self.db.refresh(permission)
|
||||
return permission
|
||||
|
||||
def revoke_access(self, project_id: str, username: str) -> bool:
|
||||
"""Revoke a user's access to a project. Returns True if deleted."""
|
||||
from .models import AccessPermission
|
||||
|
||||
count = (
|
||||
self.db.query(AccessPermission)
|
||||
.filter(
|
||||
AccessPermission.project_id == project_id,
|
||||
AccessPermission.user_id == username,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
self.db.commit()
|
||||
return count > 0
|
||||
|
||||
def list_project_permissions(self, project_id: str) -> list:
|
||||
"""List all permissions for a project."""
|
||||
from .models import AccessPermission
|
||||
|
||||
return (
|
||||
self.db.query(AccessPermission)
|
||||
.filter(AccessPermission.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_authorization_service(db: Session = Depends(get_db)) -> AuthorizationService:
|
||||
"""Get an AuthorizationService instance."""
|
||||
return AuthorizationService(db)
|
||||
|
||||
|
||||
class ProjectAccessChecker:
|
||||
"""Dependency for checking project access in route handlers."""
|
||||
|
||||
def __init__(self, required_level: str = "read"):
|
||||
self.required_level = required_level
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
project: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
) -> User:
|
||||
"""Check if user has required access to project.
|
||||
|
||||
Raises 404 if project not found, 403 if insufficient access.
|
||||
Returns the current user (or None for public read access).
|
||||
"""
|
||||
from .models import Project
|
||||
|
||||
# Find project by name
|
||||
proj = db.query(Project).filter(Project.name == project).first()
|
||||
if not proj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Project '{project}' not found",
|
||||
)
|
||||
|
||||
auth_service = AuthorizationService(db)
|
||||
|
||||
if not auth_service.check_access(str(proj.id), current_user, self.required_level):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required for private project",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Insufficient permissions. Required: {self.required_level}",
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
|
||||
# Pre-configured access checkers for common use cases
|
||||
require_project_read = ProjectAccessChecker("read")
|
||||
require_project_write = ProjectAccessChecker("write")
|
||||
require_project_admin = ProjectAccessChecker("admin")
|
||||
|
||||
|
||||
def check_project_access(
|
||||
db: Session,
|
||||
project_name: str,
|
||||
user: Optional[User],
|
||||
required_level: str = "read",
|
||||
) -> "Project":
|
||||
"""Check if user has required access to project.
|
||||
|
||||
This is a helper function for use in route handlers.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
project_name: Name of the project
|
||||
user: Current user (can be None for anonymous)
|
||||
required_level: Required access level (read, write, admin)
|
||||
|
||||
Returns:
|
||||
The Project object if access is granted
|
||||
|
||||
Raises:
|
||||
HTTPException 404: Project not found
|
||||
HTTPException 401: Authentication required for private project
|
||||
HTTPException 403: Insufficient permissions
|
||||
"""
|
||||
from .models import Project
|
||||
|
||||
# Find project by name
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Project '{project_name}' not found",
|
||||
)
|
||||
|
||||
auth_service = AuthorizationService(db)
|
||||
|
||||
if not auth_service.check_access(str(project.id), user, required_level):
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required for private project",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Insufficient permissions. Required: {required_level}",
|
||||
)
|
||||
|
||||
return project
|
||||
|
||||
@@ -371,6 +371,8 @@ from .auth import (
|
||||
validate_password_strength,
|
||||
PasswordTooShortError,
|
||||
MIN_PASSWORD_LENGTH,
|
||||
check_project_access,
|
||||
AuthorizationService,
|
||||
)
|
||||
|
||||
|
||||
@@ -1064,10 +1066,13 @@ def create_project(
|
||||
|
||||
|
||||
@router.get("/api/v1/projects/{project_name}", response_model=ProjectResponse)
|
||||
def get_project(project_name: str, db: Session = Depends(get_db)):
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
def get_project(
|
||||
project_name: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
"""Get a single project by name. Requires read access for private projects."""
|
||||
project = check_project_access(db, project_name, current_user, "read")
|
||||
return project
|
||||
|
||||
|
||||
@@ -1077,13 +1082,11 @@ def update_project(
|
||||
project_update: ProjectUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
"""Update a project's metadata."""
|
||||
user_id = get_user_id(request)
|
||||
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
"""Update a project's metadata. Requires admin access."""
|
||||
project = check_project_access(db, project_name, current_user, "admin")
|
||||
user_id = current_user.username if current_user else get_user_id(request)
|
||||
|
||||
# Track changes for audit log
|
||||
changes = {}
|
||||
@@ -1130,14 +1133,16 @@ def delete_project(
|
||||
project_name: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
"""
|
||||
Delete a project and all its packages.
|
||||
Delete a project and all its packages. Requires admin access.
|
||||
|
||||
Decrements ref_count for all artifacts referenced by tags in all packages
|
||||
within this project.
|
||||
"""
|
||||
user_id = get_user_id(request)
|
||||
check_project_access(db, project_name, current_user, "admin")
|
||||
user_id = current_user.username if current_user else get_user_id(request)
|
||||
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
@@ -1453,10 +1458,10 @@ def create_package(
|
||||
package: PackageCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
"""Create a new package in a project. Requires write access."""
|
||||
project = check_project_access(db, project_name, current_user, "write")
|
||||
|
||||
# Validate format
|
||||
if package.format not in PACKAGE_FORMATS:
|
||||
@@ -1680,14 +1685,12 @@ def upload_artifact(
|
||||
- Authorization: Bearer <api-key> for authentication
|
||||
"""
|
||||
start_time = time.time()
|
||||
user_id = get_user_id_from_request(request, db, current_user)
|
||||
settings = get_settings()
|
||||
storage_result = None
|
||||
|
||||
# Get project and package
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
# Check authorization (write access required for uploads)
|
||||
project = check_project_access(db, project_name, current_user, "write")
|
||||
user_id = current_user.username if current_user else get_user_id_from_request(request, db, current_user)
|
||||
|
||||
package = (
|
||||
db.query(Package)
|
||||
@@ -2312,6 +2315,7 @@ def download_artifact(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
storage: S3Storage = Depends(get_storage),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
range: Optional[str] = Header(None),
|
||||
mode: Optional[Literal["proxy", "redirect", "presigned"]] = Query(
|
||||
default=None,
|
||||
@@ -2347,10 +2351,8 @@ def download_artifact(
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
# Get project and package
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
# Check authorization (read access required for downloads)
|
||||
project = check_project_access(db, project_name, current_user, "read")
|
||||
|
||||
package = (
|
||||
db.query(Package)
|
||||
@@ -2568,10 +2570,8 @@ def get_artifact_url(
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
# Get project and package
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
# Check authorization (read access required for downloads)
|
||||
project = check_project_access(db, project_name, current_user, "read")
|
||||
|
||||
package = (
|
||||
db.query(Package)
|
||||
@@ -2826,12 +2826,11 @@ def create_tag(
|
||||
tag: TagCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
user_id = get_user_id(request)
|
||||
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
"""Create or update a tag. Requires write access."""
|
||||
project = check_project_access(db, project_name, current_user, "write")
|
||||
user_id = current_user.username if current_user else get_user_id(request)
|
||||
|
||||
package = (
|
||||
db.query(Package)
|
||||
|
||||
Reference in New Issue
Block a user