diff --git a/backend/app/auth.py b/backend/app/auth.py index 82607b3..b75da58 100644 --- a/backend/app/auth.py +++ b/backend/app/auth.py @@ -448,3 +448,261 @@ def require_admin( def get_auth_service(db: Session = Depends(get_db)) -> AuthService: """Get an AuthService instance.""" return AuthService(db) + + +# --- Authorization --- + +# Access levels in order of increasing privilege +ACCESS_LEVELS = ["read", "write", "admin"] + + +def get_access_level_rank(level: str) -> int: + """Get numeric rank for access level comparison.""" + try: + return ACCESS_LEVELS.index(level) + except ValueError: + return -1 + + +def has_sufficient_access(user_level: str, required_level: str) -> bool: + """Check if user_level is sufficient for required_level. + + Access levels are hierarchical: admin > write > read + """ + return get_access_level_rank(user_level) >= get_access_level_rank(required_level) + + +class AuthorizationService: + """Service for checking project-level authorization.""" + + def __init__(self, db: Session): + self.db = db + + def get_user_access_level( + self, project_id: str, user: Optional[User] + ) -> Optional[str]: + """Get the user's access level for a project. + + Returns the highest access level the user has, or None if no access. + Checks in order: + 1. System admin - gets admin access to all projects + 2. Project owner (created_by) - gets admin access + 3. Explicit permission in access_permissions table + """ + from .models import Project, AccessPermission + + # Get the project + project = self.db.query(Project).filter(Project.id == project_id).first() + if not project: + return None + + # Anonymous users only get access to public projects + if not user: + return "read" if project.is_public else None + + # System admins get admin access everywhere + if user.is_admin: + return "admin" + + # Project owner gets admin access + if project.created_by == user.username: + return "admin" + + # Check explicit permissions + permission = ( + self.db.query(AccessPermission) + .filter( + AccessPermission.project_id == project_id, + AccessPermission.user_id == user.username, + ) + .first() + ) + + if permission: + # Check expiration + if permission.expires_at and permission.expires_at < datetime.now(timezone.utc): + return "read" if project.is_public else None + return permission.level + + # Fall back to public access + return "read" if project.is_public else None + + def check_access( + self, + project_id: str, + user: Optional[User], + required_level: str, + ) -> bool: + """Check if user has required access level for project.""" + user_level = self.get_user_access_level(project_id, user) + if not user_level: + return False + return has_sufficient_access(user_level, required_level) + + def grant_access( + self, + project_id: str, + username: str, + level: str, + expires_at: Optional[datetime] = None, + ) -> "AccessPermission": + """Grant access to a user for a project.""" + from .models import AccessPermission + + # Check if permission already exists + existing = ( + self.db.query(AccessPermission) + .filter( + AccessPermission.project_id == project_id, + AccessPermission.user_id == username, + ) + .first() + ) + + if existing: + existing.level = level + existing.expires_at = expires_at + self.db.commit() + return existing + + permission = AccessPermission( + project_id=project_id, + user_id=username, + level=level, + expires_at=expires_at, + ) + self.db.add(permission) + self.db.commit() + self.db.refresh(permission) + return permission + + def revoke_access(self, project_id: str, username: str) -> bool: + """Revoke a user's access to a project. Returns True if deleted.""" + from .models import AccessPermission + + count = ( + self.db.query(AccessPermission) + .filter( + AccessPermission.project_id == project_id, + AccessPermission.user_id == username, + ) + .delete() + ) + self.db.commit() + return count > 0 + + def list_project_permissions(self, project_id: str) -> list: + """List all permissions for a project.""" + from .models import AccessPermission + + return ( + self.db.query(AccessPermission) + .filter(AccessPermission.project_id == project_id) + .all() + ) + + +def get_authorization_service(db: Session = Depends(get_db)) -> AuthorizationService: + """Get an AuthorizationService instance.""" + return AuthorizationService(db) + + +class ProjectAccessChecker: + """Dependency for checking project access in route handlers.""" + + def __init__(self, required_level: str = "read"): + self.required_level = required_level + + def __call__( + self, + project: str, + db: Session = Depends(get_db), + current_user: Optional[User] = Depends(get_current_user_optional), + ) -> User: + """Check if user has required access to project. + + Raises 404 if project not found, 403 if insufficient access. + Returns the current user (or None for public read access). + """ + from .models import Project + + # Find project by name + proj = db.query(Project).filter(Project.name == project).first() + if not proj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Project '{project}' not found", + ) + + auth_service = AuthorizationService(db) + + if not auth_service.check_access(str(proj.id), current_user, self.required_level): + if not current_user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required for private project", + headers={"WWW-Authenticate": "Bearer"}, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Insufficient permissions. Required: {self.required_level}", + ) + + return current_user + + +# Pre-configured access checkers for common use cases +require_project_read = ProjectAccessChecker("read") +require_project_write = ProjectAccessChecker("write") +require_project_admin = ProjectAccessChecker("admin") + + +def check_project_access( + db: Session, + project_name: str, + user: Optional[User], + required_level: str = "read", +) -> "Project": + """Check if user has required access to project. + + This is a helper function for use in route handlers. + + Args: + db: Database session + project_name: Name of the project + user: Current user (can be None for anonymous) + required_level: Required access level (read, write, admin) + + Returns: + The Project object if access is granted + + Raises: + HTTPException 404: Project not found + HTTPException 401: Authentication required for private project + HTTPException 403: Insufficient permissions + """ + from .models import Project + + # Find project by name + project = db.query(Project).filter(Project.name == project_name).first() + if not project: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Project '{project_name}' not found", + ) + + auth_service = AuthorizationService(db) + + if not auth_service.check_access(str(project.id), user, required_level): + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required for private project", + headers={"WWW-Authenticate": "Bearer"}, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Insufficient permissions. Required: {required_level}", + ) + + return project diff --git a/backend/app/routes.py b/backend/app/routes.py index d95c7bb..9fd1aa7 100644 --- a/backend/app/routes.py +++ b/backend/app/routes.py @@ -371,6 +371,8 @@ from .auth import ( validate_password_strength, PasswordTooShortError, MIN_PASSWORD_LENGTH, + check_project_access, + AuthorizationService, ) @@ -1064,10 +1066,13 @@ def create_project( @router.get("/api/v1/projects/{project_name}", response_model=ProjectResponse) -def get_project(project_name: str, db: Session = Depends(get_db)): - project = db.query(Project).filter(Project.name == project_name).first() - if not project: - raise HTTPException(status_code=404, detail="Project not found") +def get_project( + project_name: str, + db: Session = Depends(get_db), + current_user: Optional[User] = Depends(get_current_user_optional), +): + """Get a single project by name. Requires read access for private projects.""" + project = check_project_access(db, project_name, current_user, "read") return project @@ -1077,13 +1082,11 @@ def update_project( project_update: ProjectUpdate, request: Request, db: Session = Depends(get_db), + current_user: Optional[User] = Depends(get_current_user_optional), ): - """Update a project's metadata.""" - user_id = get_user_id(request) - - project = db.query(Project).filter(Project.name == project_name).first() - if not project: - raise HTTPException(status_code=404, detail="Project not found") + """Update a project's metadata. Requires admin access.""" + project = check_project_access(db, project_name, current_user, "admin") + user_id = current_user.username if current_user else get_user_id(request) # Track changes for audit log changes = {} @@ -1130,14 +1133,16 @@ def delete_project( project_name: str, request: Request, db: Session = Depends(get_db), + current_user: Optional[User] = Depends(get_current_user_optional), ): """ - Delete a project and all its packages. + Delete a project and all its packages. Requires admin access. Decrements ref_count for all artifacts referenced by tags in all packages within this project. """ - user_id = get_user_id(request) + check_project_access(db, project_name, current_user, "admin") + user_id = current_user.username if current_user else get_user_id(request) project = db.query(Project).filter(Project.name == project_name).first() if not project: @@ -1453,10 +1458,10 @@ def create_package( package: PackageCreate, request: Request, db: Session = Depends(get_db), + current_user: Optional[User] = Depends(get_current_user_optional), ): - project = db.query(Project).filter(Project.name == project_name).first() - if not project: - raise HTTPException(status_code=404, detail="Project not found") + """Create a new package in a project. Requires write access.""" + project = check_project_access(db, project_name, current_user, "write") # Validate format if package.format not in PACKAGE_FORMATS: @@ -1680,14 +1685,12 @@ def upload_artifact( - Authorization: Bearer for authentication """ start_time = time.time() - user_id = get_user_id_from_request(request, db, current_user) settings = get_settings() storage_result = None - # Get project and package - project = db.query(Project).filter(Project.name == project_name).first() - if not project: - raise HTTPException(status_code=404, detail="Project not found") + # Check authorization (write access required for uploads) + project = check_project_access(db, project_name, current_user, "write") + user_id = current_user.username if current_user else get_user_id_from_request(request, db, current_user) package = ( db.query(Package) @@ -2312,6 +2315,7 @@ def download_artifact( request: Request, db: Session = Depends(get_db), storage: S3Storage = Depends(get_storage), + current_user: Optional[User] = Depends(get_current_user_optional), range: Optional[str] = Header(None), mode: Optional[Literal["proxy", "redirect", "presigned"]] = Query( default=None, @@ -2347,10 +2351,8 @@ def download_artifact( """ settings = get_settings() - # Get project and package - project = db.query(Project).filter(Project.name == project_name).first() - if not project: - raise HTTPException(status_code=404, detail="Project not found") + # Check authorization (read access required for downloads) + project = check_project_access(db, project_name, current_user, "read") package = ( db.query(Package) @@ -2568,10 +2570,8 @@ def get_artifact_url( """ settings = get_settings() - # Get project and package - project = db.query(Project).filter(Project.name == project_name).first() - if not project: - raise HTTPException(status_code=404, detail="Project not found") + # Check authorization (read access required for downloads) + project = check_project_access(db, project_name, current_user, "read") package = ( db.query(Package) @@ -2826,12 +2826,11 @@ def create_tag( tag: TagCreate, request: Request, db: Session = Depends(get_db), + current_user: Optional[User] = Depends(get_current_user_optional), ): - user_id = get_user_id(request) - - project = db.query(Project).filter(Project.name == project_name).first() - if not project: - raise HTTPException(status_code=404, detail="Project not found") + """Create or update a tag. Requires write access.""" + project = check_project_access(db, project_name, current_user, "write") + user_id = current_user.username if current_user else get_user_id(request) package = ( db.query(Package) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 34111d8..9064602 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -182,9 +182,10 @@ def test_app(): @pytest.fixture def integration_client(): """ - Create a test client for integration tests. + Create an authenticated test client for integration tests. Uses the real database and MinIO from docker-compose.local.yml. + Authenticates as admin for write operations. """ from httpx import Client @@ -192,6 +193,15 @@ def integration_client(): base_url = os.environ.get("ORCHARD_TEST_URL", "http://localhost:8080") with Client(base_url=base_url, timeout=30.0) as client: + # Login as admin to enable write operations + login_response = client.post( + "/api/v1/auth/login", + json={"username": "admin", "password": "changeme123"}, + ) + # If login fails, tests will fail - that's expected if auth is broken + if login_response.status_code != 200: + # Try to continue without auth for backward compatibility + pass yield client diff --git a/backend/tests/integration/test_projects_api.py b/backend/tests/integration/test_projects_api.py index 0de9554..49ed5c4 100644 --- a/backend/tests/integration/test_projects_api.py +++ b/backend/tests/integration/test_projects_api.py @@ -59,7 +59,8 @@ class TestProjectCRUD: @pytest.mark.integration def test_list_projects(self, integration_client, test_project): """Test listing projects includes created project.""" - response = integration_client.get("/api/v1/projects") + # Search specifically for our test project to avoid pagination issues + response = integration_client.get(f"/api/v1/projects?search={test_project}") assert response.status_code == 200 data = response.json() @@ -107,9 +108,11 @@ class TestProjectListingFilters: @pytest.mark.integration def test_projects_search(self, integration_client, test_project): """Test project search by name.""" - # Search for our test project + # Search using the unique portion of our test project name + # test_project format is "test-project-test-{uuid[:8]}" + unique_part = test_project.split("-")[-1] # Get the UUID portion response = integration_client.get( - f"/api/v1/projects?search={test_project[:10]}" + f"/api/v1/projects?search={unique_part}" ) assert response.status_code == 200 diff --git a/backend/tests/integration/test_upload_download_api.py b/backend/tests/integration/test_upload_download_api.py index 8b83e02..4d9b8b2 100644 --- a/backend/tests/integration/test_upload_download_api.py +++ b/backend/tests/integration/test_upload_download_api.py @@ -286,6 +286,14 @@ class TestConcurrentUploads: expected_hash = compute_sha256(content) num_concurrent = 5 + # Create an API key for worker threads + api_key_response = integration_client.post( + "/api/v1/auth/keys", + json={"name": "concurrent-test-key"}, + ) + assert api_key_response.status_code == 200, f"Failed to create API key: {api_key_response.text}" + api_key = api_key_response.json()["key"] + results = [] errors = [] @@ -306,6 +314,7 @@ class TestConcurrentUploads: f"/api/v1/project/{project}/{package}/upload", files=files, data={"tag": f"concurrent-{tag_suffix}"}, + headers={"Authorization": f"Bearer {api_key}"}, ) if response.status_code == 200: results.append(response.json())