Add project-level authorization checks
Authorization: - Add AuthorizationService for checking project access - Implement get_user_access_level() with admin, owner, and permission checks - Add check_project_access() helper for route handlers - Add grant_access() and revoke_access() methods - Add ProjectAccessChecker dependency class Routes: - Add authorization checks to project CRUD (read, update, delete) - Add authorization checks to package create - Add authorization checks to upload endpoint (requires write) - Add authorization checks to download endpoint (requires read) - Add authorization checks to tag create Tests: - Fix pagination flakiness in test_list_projects - Fix pagination flakiness in test_projects_search - Add API key authentication to concurrent upload test
This commit is contained in:
@@ -448,3 +448,261 @@ def require_admin(
|
||||
def get_auth_service(db: Session = Depends(get_db)) -> AuthService:
|
||||
"""Get an AuthService instance."""
|
||||
return AuthService(db)
|
||||
|
||||
|
||||
# --- Authorization ---
|
||||
|
||||
# Access levels in order of increasing privilege
|
||||
ACCESS_LEVELS = ["read", "write", "admin"]
|
||||
|
||||
|
||||
def get_access_level_rank(level: str) -> int:
|
||||
"""Get numeric rank for access level comparison."""
|
||||
try:
|
||||
return ACCESS_LEVELS.index(level)
|
||||
except ValueError:
|
||||
return -1
|
||||
|
||||
|
||||
def has_sufficient_access(user_level: str, required_level: str) -> bool:
|
||||
"""Check if user_level is sufficient for required_level.
|
||||
|
||||
Access levels are hierarchical: admin > write > read
|
||||
"""
|
||||
return get_access_level_rank(user_level) >= get_access_level_rank(required_level)
|
||||
|
||||
|
||||
class AuthorizationService:
|
||||
"""Service for checking project-level authorization."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_user_access_level(
|
||||
self, project_id: str, user: Optional[User]
|
||||
) -> Optional[str]:
|
||||
"""Get the user's access level for a project.
|
||||
|
||||
Returns the highest access level the user has, or None if no access.
|
||||
Checks in order:
|
||||
1. System admin - gets admin access to all projects
|
||||
2. Project owner (created_by) - gets admin access
|
||||
3. Explicit permission in access_permissions table
|
||||
"""
|
||||
from .models import Project, AccessPermission
|
||||
|
||||
# Get the project
|
||||
project = self.db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
return None
|
||||
|
||||
# Anonymous users only get access to public projects
|
||||
if not user:
|
||||
return "read" if project.is_public else None
|
||||
|
||||
# System admins get admin access everywhere
|
||||
if user.is_admin:
|
||||
return "admin"
|
||||
|
||||
# Project owner gets admin access
|
||||
if project.created_by == user.username:
|
||||
return "admin"
|
||||
|
||||
# Check explicit permissions
|
||||
permission = (
|
||||
self.db.query(AccessPermission)
|
||||
.filter(
|
||||
AccessPermission.project_id == project_id,
|
||||
AccessPermission.user_id == user.username,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if permission:
|
||||
# Check expiration
|
||||
if permission.expires_at and permission.expires_at < datetime.now(timezone.utc):
|
||||
return "read" if project.is_public else None
|
||||
return permission.level
|
||||
|
||||
# Fall back to public access
|
||||
return "read" if project.is_public else None
|
||||
|
||||
def check_access(
|
||||
self,
|
||||
project_id: str,
|
||||
user: Optional[User],
|
||||
required_level: str,
|
||||
) -> bool:
|
||||
"""Check if user has required access level for project."""
|
||||
user_level = self.get_user_access_level(project_id, user)
|
||||
if not user_level:
|
||||
return False
|
||||
return has_sufficient_access(user_level, required_level)
|
||||
|
||||
def grant_access(
|
||||
self,
|
||||
project_id: str,
|
||||
username: str,
|
||||
level: str,
|
||||
expires_at: Optional[datetime] = None,
|
||||
) -> "AccessPermission":
|
||||
"""Grant access to a user for a project."""
|
||||
from .models import AccessPermission
|
||||
|
||||
# Check if permission already exists
|
||||
existing = (
|
||||
self.db.query(AccessPermission)
|
||||
.filter(
|
||||
AccessPermission.project_id == project_id,
|
||||
AccessPermission.user_id == username,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
existing.level = level
|
||||
existing.expires_at = expires_at
|
||||
self.db.commit()
|
||||
return existing
|
||||
|
||||
permission = AccessPermission(
|
||||
project_id=project_id,
|
||||
user_id=username,
|
||||
level=level,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
self.db.add(permission)
|
||||
self.db.commit()
|
||||
self.db.refresh(permission)
|
||||
return permission
|
||||
|
||||
def revoke_access(self, project_id: str, username: str) -> bool:
|
||||
"""Revoke a user's access to a project. Returns True if deleted."""
|
||||
from .models import AccessPermission
|
||||
|
||||
count = (
|
||||
self.db.query(AccessPermission)
|
||||
.filter(
|
||||
AccessPermission.project_id == project_id,
|
||||
AccessPermission.user_id == username,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
self.db.commit()
|
||||
return count > 0
|
||||
|
||||
def list_project_permissions(self, project_id: str) -> list:
|
||||
"""List all permissions for a project."""
|
||||
from .models import AccessPermission
|
||||
|
||||
return (
|
||||
self.db.query(AccessPermission)
|
||||
.filter(AccessPermission.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_authorization_service(db: Session = Depends(get_db)) -> AuthorizationService:
|
||||
"""Get an AuthorizationService instance."""
|
||||
return AuthorizationService(db)
|
||||
|
||||
|
||||
class ProjectAccessChecker:
|
||||
"""Dependency for checking project access in route handlers."""
|
||||
|
||||
def __init__(self, required_level: str = "read"):
|
||||
self.required_level = required_level
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
project: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
) -> User:
|
||||
"""Check if user has required access to project.
|
||||
|
||||
Raises 404 if project not found, 403 if insufficient access.
|
||||
Returns the current user (or None for public read access).
|
||||
"""
|
||||
from .models import Project
|
||||
|
||||
# Find project by name
|
||||
proj = db.query(Project).filter(Project.name == project).first()
|
||||
if not proj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Project '{project}' not found",
|
||||
)
|
||||
|
||||
auth_service = AuthorizationService(db)
|
||||
|
||||
if not auth_service.check_access(str(proj.id), current_user, self.required_level):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required for private project",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Insufficient permissions. Required: {self.required_level}",
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
|
||||
# Pre-configured access checkers for common use cases
|
||||
require_project_read = ProjectAccessChecker("read")
|
||||
require_project_write = ProjectAccessChecker("write")
|
||||
require_project_admin = ProjectAccessChecker("admin")
|
||||
|
||||
|
||||
def check_project_access(
|
||||
db: Session,
|
||||
project_name: str,
|
||||
user: Optional[User],
|
||||
required_level: str = "read",
|
||||
) -> "Project":
|
||||
"""Check if user has required access to project.
|
||||
|
||||
This is a helper function for use in route handlers.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
project_name: Name of the project
|
||||
user: Current user (can be None for anonymous)
|
||||
required_level: Required access level (read, write, admin)
|
||||
|
||||
Returns:
|
||||
The Project object if access is granted
|
||||
|
||||
Raises:
|
||||
HTTPException 404: Project not found
|
||||
HTTPException 401: Authentication required for private project
|
||||
HTTPException 403: Insufficient permissions
|
||||
"""
|
||||
from .models import Project
|
||||
|
||||
# Find project by name
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Project '{project_name}' not found",
|
||||
)
|
||||
|
||||
auth_service = AuthorizationService(db)
|
||||
|
||||
if not auth_service.check_access(str(project.id), user, required_level):
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required for private project",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Insufficient permissions. Required: {required_level}",
|
||||
)
|
||||
|
||||
return project
|
||||
|
||||
@@ -371,6 +371,8 @@ from .auth import (
|
||||
validate_password_strength,
|
||||
PasswordTooShortError,
|
||||
MIN_PASSWORD_LENGTH,
|
||||
check_project_access,
|
||||
AuthorizationService,
|
||||
)
|
||||
|
||||
|
||||
@@ -1064,10 +1066,13 @@ def create_project(
|
||||
|
||||
|
||||
@router.get("/api/v1/projects/{project_name}", response_model=ProjectResponse)
|
||||
def get_project(project_name: str, db: Session = Depends(get_db)):
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
def get_project(
|
||||
project_name: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
"""Get a single project by name. Requires read access for private projects."""
|
||||
project = check_project_access(db, project_name, current_user, "read")
|
||||
return project
|
||||
|
||||
|
||||
@@ -1077,13 +1082,11 @@ def update_project(
|
||||
project_update: ProjectUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
"""Update a project's metadata."""
|
||||
user_id = get_user_id(request)
|
||||
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
"""Update a project's metadata. Requires admin access."""
|
||||
project = check_project_access(db, project_name, current_user, "admin")
|
||||
user_id = current_user.username if current_user else get_user_id(request)
|
||||
|
||||
# Track changes for audit log
|
||||
changes = {}
|
||||
@@ -1130,14 +1133,16 @@ def delete_project(
|
||||
project_name: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
"""
|
||||
Delete a project and all its packages.
|
||||
Delete a project and all its packages. Requires admin access.
|
||||
|
||||
Decrements ref_count for all artifacts referenced by tags in all packages
|
||||
within this project.
|
||||
"""
|
||||
user_id = get_user_id(request)
|
||||
check_project_access(db, project_name, current_user, "admin")
|
||||
user_id = current_user.username if current_user else get_user_id(request)
|
||||
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
@@ -1453,10 +1458,10 @@ def create_package(
|
||||
package: PackageCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
"""Create a new package in a project. Requires write access."""
|
||||
project = check_project_access(db, project_name, current_user, "write")
|
||||
|
||||
# Validate format
|
||||
if package.format not in PACKAGE_FORMATS:
|
||||
@@ -1680,14 +1685,12 @@ def upload_artifact(
|
||||
- Authorization: Bearer <api-key> for authentication
|
||||
"""
|
||||
start_time = time.time()
|
||||
user_id = get_user_id_from_request(request, db, current_user)
|
||||
settings = get_settings()
|
||||
storage_result = None
|
||||
|
||||
# Get project and package
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
# Check authorization (write access required for uploads)
|
||||
project = check_project_access(db, project_name, current_user, "write")
|
||||
user_id = current_user.username if current_user else get_user_id_from_request(request, db, current_user)
|
||||
|
||||
package = (
|
||||
db.query(Package)
|
||||
@@ -2312,6 +2315,7 @@ def download_artifact(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
storage: S3Storage = Depends(get_storage),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
range: Optional[str] = Header(None),
|
||||
mode: Optional[Literal["proxy", "redirect", "presigned"]] = Query(
|
||||
default=None,
|
||||
@@ -2347,10 +2351,8 @@ def download_artifact(
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
# Get project and package
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
# Check authorization (read access required for downloads)
|
||||
project = check_project_access(db, project_name, current_user, "read")
|
||||
|
||||
package = (
|
||||
db.query(Package)
|
||||
@@ -2568,10 +2570,8 @@ def get_artifact_url(
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
# Get project and package
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
# Check authorization (read access required for downloads)
|
||||
project = check_project_access(db, project_name, current_user, "read")
|
||||
|
||||
package = (
|
||||
db.query(Package)
|
||||
@@ -2826,12 +2826,11 @@ def create_tag(
|
||||
tag: TagCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
):
|
||||
user_id = get_user_id(request)
|
||||
|
||||
project = db.query(Project).filter(Project.name == project_name).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
"""Create or update a tag. Requires write access."""
|
||||
project = check_project_access(db, project_name, current_user, "write")
|
||||
user_id = current_user.username if current_user else get_user_id(request)
|
||||
|
||||
package = (
|
||||
db.query(Package)
|
||||
|
||||
@@ -182,9 +182,10 @@ def test_app():
|
||||
@pytest.fixture
|
||||
def integration_client():
|
||||
"""
|
||||
Create a test client for integration tests.
|
||||
Create an authenticated test client for integration tests.
|
||||
|
||||
Uses the real database and MinIO from docker-compose.local.yml.
|
||||
Authenticates as admin for write operations.
|
||||
"""
|
||||
from httpx import Client
|
||||
|
||||
@@ -192,6 +193,15 @@ def integration_client():
|
||||
base_url = os.environ.get("ORCHARD_TEST_URL", "http://localhost:8080")
|
||||
|
||||
with Client(base_url=base_url, timeout=30.0) as client:
|
||||
# Login as admin to enable write operations
|
||||
login_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"username": "admin", "password": "changeme123"},
|
||||
)
|
||||
# If login fails, tests will fail - that's expected if auth is broken
|
||||
if login_response.status_code != 200:
|
||||
# Try to continue without auth for backward compatibility
|
||||
pass
|
||||
yield client
|
||||
|
||||
|
||||
|
||||
@@ -59,7 +59,8 @@ class TestProjectCRUD:
|
||||
@pytest.mark.integration
|
||||
def test_list_projects(self, integration_client, test_project):
|
||||
"""Test listing projects includes created project."""
|
||||
response = integration_client.get("/api/v1/projects")
|
||||
# Search specifically for our test project to avoid pagination issues
|
||||
response = integration_client.get(f"/api/v1/projects?search={test_project}")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
@@ -107,9 +108,11 @@ class TestProjectListingFilters:
|
||||
@pytest.mark.integration
|
||||
def test_projects_search(self, integration_client, test_project):
|
||||
"""Test project search by name."""
|
||||
# Search for our test project
|
||||
# Search using the unique portion of our test project name
|
||||
# test_project format is "test-project-test-{uuid[:8]}"
|
||||
unique_part = test_project.split("-")[-1] # Get the UUID portion
|
||||
response = integration_client.get(
|
||||
f"/api/v1/projects?search={test_project[:10]}"
|
||||
f"/api/v1/projects?search={unique_part}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@@ -286,6 +286,14 @@ class TestConcurrentUploads:
|
||||
expected_hash = compute_sha256(content)
|
||||
num_concurrent = 5
|
||||
|
||||
# Create an API key for worker threads
|
||||
api_key_response = integration_client.post(
|
||||
"/api/v1/auth/keys",
|
||||
json={"name": "concurrent-test-key"},
|
||||
)
|
||||
assert api_key_response.status_code == 200, f"Failed to create API key: {api_key_response.text}"
|
||||
api_key = api_key_response.json()["key"]
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
@@ -306,6 +314,7 @@ class TestConcurrentUploads:
|
||||
f"/api/v1/project/{project}/{package}/upload",
|
||||
files=files,
|
||||
data={"tag": f"concurrent-{tag_suffix}"},
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
results.append(response.json())
|
||||
|
||||
Reference in New Issue
Block a user