From 865812af98a4df2192bacd90ec74815360d146cf Mon Sep 17 00:00:00 2001 From: Mondo Diaz Date: Mon, 5 Jan 2026 10:04:59 -0600 Subject: [PATCH] Add ref_count management for deletions with atomic operations and error handling - Add DELETE endpoints for tags, packages, and projects with proper ref_count decrements for all affected artifacts - Implement atomic ref_count operations using SELECT FOR UPDATE row-level locking to prevent race conditions - Add custom storage exceptions (HashComputationError, S3ExistenceCheckError, S3UploadError) with retry logic for S3 existence checks - Handle race conditions in upload by locking artifact row before modification - Add comprehensive logging for all ref_count changes and deduplication events - Include ref_count in upload response schema --- backend/app/routes.py | 1200 +++++++++++++++++++++++++++++++--------- backend/app/schemas.py | 43 +- backend/app/storage.py | 246 ++++++-- 3 files changed, 1175 insertions(+), 314 deletions(-) diff --git a/backend/app/routes.py b/backend/app/routes.py index 73f2cf3..6f99451 100644 --- a/backend/app/routes.py +++ b/backend/app/routes.py @@ -1,5 +1,16 @@ from datetime import datetime, timedelta, timezone -from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Request, Query, Header, Response +from fastapi import ( + APIRouter, + Depends, + HTTPException, + UploadFile, + File, + Form, + Request, + Query, + Header, + Response, +) from fastapi.responses import StreamingResponse, RedirectResponse from sqlalchemy.orm import Session from sqlalchemy import or_, func @@ -10,25 +21,57 @@ import io import hashlib from .database import get_db -from .storage import get_storage, S3Storage, MULTIPART_CHUNK_SIZE -from .models import Project, Package, Artifact, Tag, TagHistory, Upload, Consumer +from .storage import ( + get_storage, + S3Storage, + MULTIPART_CHUNK_SIZE, + StorageError, + HashComputationError, + S3ExistenceCheckError, + S3UploadError, +) +from .models import ( + Project, + Package, + Artifact, + Tag, + TagHistory, + Upload, + Consumer, + AuditLog, +) from .schemas import ( - ProjectCreate, ProjectResponse, - PackageCreate, PackageResponse, PackageDetailResponse, TagSummary, - PACKAGE_FORMATS, PACKAGE_PLATFORMS, - ArtifactResponse, ArtifactDetailResponse, ArtifactTagInfo, PackageArtifactResponse, - TagCreate, TagResponse, TagDetailResponse, TagHistoryResponse, + ProjectCreate, + ProjectResponse, + PackageCreate, + PackageResponse, + PackageDetailResponse, + TagSummary, + PACKAGE_FORMATS, + PACKAGE_PLATFORMS, + ArtifactResponse, + ArtifactDetailResponse, + ArtifactTagInfo, + PackageArtifactResponse, + TagCreate, + TagResponse, + TagDetailResponse, + TagHistoryResponse, UploadResponse, ConsumerResponse, HealthResponse, - PaginatedResponse, PaginationMeta, + PaginatedResponse, + PaginationMeta, ResumableUploadInitRequest, ResumableUploadInitResponse, ResumableUploadPartResponse, ResumableUploadCompleteRequest, ResumableUploadCompleteResponse, ResumableUploadStatusResponse, - GlobalSearchResponse, SearchResultProject, SearchResultPackage, SearchResultArtifact, + GlobalSearchResponse, + SearchResultProject, + SearchResultPackage, + SearchResultArtifact, PresignedUrlResponse, ) from .metadata import extract_metadata @@ -48,6 +91,159 @@ def get_user_id(request: Request) -> str: return "anonymous" +import logging + +logger = logging.getLogger(__name__) + + +def _increment_ref_count(db: Session, artifact_id: str) -> int: + """ + Atomically increment ref_count for an artifact using row-level locking. 
+ Returns the new ref_count value. + + Uses SELECT FOR UPDATE to prevent race conditions when multiple + requests try to modify the same artifact's ref_count simultaneously. + """ + # Lock the row to prevent concurrent modifications + artifact = ( + db.query(Artifact).filter(Artifact.id == artifact_id).with_for_update().first() + ) + if not artifact: + logger.warning( + f"Attempted to increment ref_count for non-existent artifact: {artifact_id[:12]}..." + ) + return 0 + + artifact.ref_count += 1 + db.flush() # Ensure the update is written but don't commit yet + return artifact.ref_count + + +def _decrement_ref_count(db: Session, artifact_id: str) -> int: + """ + Atomically decrement ref_count for an artifact using row-level locking. + Returns the new ref_count value. + + Uses SELECT FOR UPDATE to prevent race conditions when multiple + requests try to modify the same artifact's ref_count simultaneously. + Will not decrement below 0. + """ + # Lock the row to prevent concurrent modifications + artifact = ( + db.query(Artifact).filter(Artifact.id == artifact_id).with_for_update().first() + ) + if not artifact: + logger.warning( + f"Attempted to decrement ref_count for non-existent artifact: {artifact_id[:12]}..." + ) + return 0 + + # Prevent going below 0 + if artifact.ref_count > 0: + artifact.ref_count -= 1 + else: + logger.warning( + f"Attempted to decrement ref_count below 0 for artifact: {artifact_id[:12]}... " + f"(current: {artifact.ref_count})" + ) + + db.flush() # Ensure the update is written but don't commit yet + return artifact.ref_count + + +def _create_or_update_tag( + db: Session, + package_id: str, + tag_name: str, + new_artifact_id: str, + user_id: str, +) -> tuple[Tag, bool, Optional[str]]: + """ + Create or update a tag, handling ref_count and history. + + Returns: + tuple of (tag, is_new, old_artifact_id) + - tag: The created/updated Tag object + - is_new: True if tag was created, False if updated + - old_artifact_id: Previous artifact_id if tag was updated, None otherwise + """ + existing_tag = ( + db.query(Tag).filter(Tag.package_id == package_id, Tag.name == tag_name).first() + ) + + if existing_tag: + old_artifact_id = existing_tag.artifact_id + + # Only process if artifact actually changed + if old_artifact_id != new_artifact_id: + # Record history + history = TagHistory( + tag_id=existing_tag.id, + old_artifact_id=old_artifact_id, + new_artifact_id=new_artifact_id, + change_type="update", + changed_by=user_id, + ) + db.add(history) + + # Decrement ref_count on old artifact + old_ref_count = _decrement_ref_count(db, old_artifact_id) + logger.info( + f"Tag '{tag_name}' updated: decremented ref_count on artifact " + f"{old_artifact_id[:12]}... 
to {old_ref_count}" + ) + + # Update tag to point to new artifact + existing_tag.artifact_id = new_artifact_id + existing_tag.created_by = user_id + + return existing_tag, False, old_artifact_id + else: + # Same artifact, no change needed + return existing_tag, False, None + else: + # Create new tag + new_tag = Tag( + package_id=package_id, + name=tag_name, + artifact_id=new_artifact_id, + created_by=user_id, + ) + db.add(new_tag) + db.flush() # Get the tag ID + + # Record history for creation + history = TagHistory( + tag_id=new_tag.id, + old_artifact_id=None, + new_artifact_id=new_artifact_id, + change_type="create", + changed_by=user_id, + ) + db.add(history) + + return new_tag, True, None + + +def _log_audit( + db: Session, + action: str, + resource: str, + user_id: str, + source_ip: Optional[str] = None, + details: Optional[dict] = None, +): + """Log an action to the audit_logs table.""" + audit_log = AuditLog( + action=action, + resource=resource, + user_id=user_id, + source_ip=source_ip, + details=details or {}, + ) + db.add(audit_log) + + # Health check @router.get("/health", response_model=HealthResponse) def health_check(): @@ -74,39 +270,44 @@ def global_search( or_(Project.is_public == True, Project.created_by == user_id), or_( func.lower(Project.name).contains(search_lower), - func.lower(Project.description).contains(search_lower) - ) + func.lower(Project.description).contains(search_lower), + ), ) project_count = project_query.count() projects = project_query.order_by(Project.name).limit(limit).all() # Search packages (name and description) with project name - package_query = db.query(Package, Project.name.label("project_name")).join( - Project, Package.project_id == Project.id - ).filter( - or_(Project.is_public == True, Project.created_by == user_id), - or_( - func.lower(Package.name).contains(search_lower), - func.lower(Package.description).contains(search_lower) + package_query = ( + db.query(Package, Project.name.label("project_name")) + .join(Project, Package.project_id == Project.id) + .filter( + or_(Project.is_public == True, Project.created_by == user_id), + or_( + func.lower(Package.name).contains(search_lower), + func.lower(Package.description).contains(search_lower), + ), ) ) package_count = package_query.count() package_results = package_query.order_by(Package.name).limit(limit).all() # Search tags/artifacts (tag name and original filename) - artifact_query = db.query( - Tag, Artifact, Package.name.label("package_name"), Project.name.label("project_name") - ).join( - Artifact, Tag.artifact_id == Artifact.id - ).join( - Package, Tag.package_id == Package.id - ).join( - Project, Package.project_id == Project.id - ).filter( - or_(Project.is_public == True, Project.created_by == user_id), - or_( - func.lower(Tag.name).contains(search_lower), - func.lower(Artifact.original_name).contains(search_lower) + artifact_query = ( + db.query( + Tag, + Artifact, + Package.name.label("package_name"), + Project.name.label("project_name"), + ) + .join(Artifact, Tag.artifact_id == Artifact.id) + .join(Package, Tag.package_id == Package.id) + .join(Project, Package.project_id == Project.id) + .filter( + or_(Project.is_public == True, Project.created_by == user_id), + or_( + func.lower(Tag.name).contains(search_lower), + func.lower(Artifact.original_name).contains(search_lower), + ), ) ) artifact_count = artifact_query.count() @@ -114,35 +315,41 @@ def global_search( return GlobalSearchResponse( query=q, - projects=[SearchResultProject( - id=p.id, - name=p.name, - 
description=p.description, - is_public=p.is_public - ) for p in projects], - packages=[SearchResultPackage( - id=pkg.id, - project_id=pkg.project_id, - project_name=project_name, - name=pkg.name, - description=pkg.description, - format=pkg.format - ) for pkg, project_name in package_results], - artifacts=[SearchResultArtifact( - tag_id=tag.id, - tag_name=tag.name, - artifact_id=artifact.id, - package_id=tag.package_id, - package_name=package_name, - project_name=project_name, - original_name=artifact.original_name - ) for tag, artifact, package_name, project_name in artifact_results], + projects=[ + SearchResultProject( + id=p.id, name=p.name, description=p.description, is_public=p.is_public + ) + for p in projects + ], + packages=[ + SearchResultPackage( + id=pkg.id, + project_id=pkg.project_id, + project_name=project_name, + name=pkg.name, + description=pkg.description, + format=pkg.format, + ) + for pkg, project_name in package_results + ], + artifacts=[ + SearchResultArtifact( + tag_id=tag.id, + tag_name=tag.name, + artifact_id=artifact.id, + package_id=tag.package_id, + package_name=package_name, + project_name=project_name, + original_name=artifact.original_name, + ) + for tag, artifact, package_name, project_name in artifact_results + ], counts={ "projects": project_count, "packages": package_count, "artifacts": artifact_count, - "total": project_count + package_count + artifact_count - } + "total": project_count + package_count + artifact_count, + }, ) @@ -152,22 +359,37 @@ def list_projects( request: Request, page: int = Query(default=1, ge=1, description="Page number"), limit: int = Query(default=20, ge=1, le=100, description="Items per page"), - search: Optional[str] = Query(default=None, description="Search by project name or description"), - visibility: Optional[str] = Query(default=None, description="Filter by visibility (public, private)"), - sort: str = Query(default="name", description="Sort field (name, created_at, updated_at)"), + search: Optional[str] = Query( + default=None, description="Search by project name or description" + ), + visibility: Optional[str] = Query( + default=None, description="Filter by visibility (public, private)" + ), + sort: str = Query( + default="name", description="Sort field (name, created_at, updated_at)" + ), order: str = Query(default="asc", description="Sort order (asc, desc)"), db: Session = Depends(get_db), ): user_id = get_user_id(request) # Validate sort field - valid_sort_fields = {"name": Project.name, "created_at": Project.created_at, "updated_at": Project.updated_at} + valid_sort_fields = { + "name": Project.name, + "created_at": Project.created_at, + "updated_at": Project.updated_at, + } if sort not in valid_sort_fields: - raise HTTPException(status_code=400, detail=f"Invalid sort field. Must be one of: {', '.join(valid_sort_fields.keys())}") + raise HTTPException( + status_code=400, + detail=f"Invalid sort field. Must be one of: {', '.join(valid_sort_fields.keys())}", + ) # Validate order if order not in ("asc", "desc"): - raise HTTPException(status_code=400, detail="Invalid order. Must be 'asc' or 'desc'") + raise HTTPException( + status_code=400, detail="Invalid order. 
Must be 'asc' or 'desc'" + ) # Base query - filter by access query = db.query(Project).filter( @@ -186,7 +408,7 @@ def list_projects( query = query.filter( or_( func.lower(Project.name).contains(search_lower), - func.lower(Project.description).contains(search_lower) + func.lower(Project.description).contains(search_lower), ) ) @@ -219,7 +441,9 @@ def list_projects( @router.post("/api/v1/projects", response_model=ProjectResponse) -def create_project(project: ProjectCreate, request: Request, db: Session = Depends(get_db)): +def create_project( + project: ProjectCreate, request: Request, db: Session = Depends(get_db) +): user_id = get_user_id(request) existing = db.query(Project).filter(Project.name == project.name).first() @@ -246,14 +470,80 @@ def get_project(project_name: str, db: Session = Depends(get_db)): return project +@router.delete("/api/v1/projects/{project_name}", status_code=204) +def delete_project( + project_name: str, + request: Request, + db: Session = Depends(get_db), +): + """ + Delete a project and all its packages. + + Decrements ref_count for all artifacts referenced by tags in all packages + within this project. + """ + user_id = get_user_id(request) + + project = db.query(Project).filter(Project.name == project_name).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + # Get all packages in this project + packages = db.query(Package).filter(Package.project_id == project.id).all() + + # Get all tags across all packages to decrement their artifact ref_counts + artifacts_decremented = {} + total_tags = 0 + + for package in packages: + tags = db.query(Tag).filter(Tag.package_id == package.id).all() + total_tags += len(tags) + + for tag in tags: + if tag.artifact_id not in artifacts_decremented: + new_ref_count = _decrement_ref_count(db, tag.artifact_id) + artifacts_decremented[tag.artifact_id] = new_ref_count + logger.info( + f"Project '{project_name}' deletion: decremented ref_count on artifact " + f"{tag.artifact_id[:12]}... to {new_ref_count}" + ) + + # Audit log + _log_audit( + db, + action="delete_project", + resource=f"project/{project_name}", + user_id=user_id, + source_ip=request.client.host if request.client else None, + details={ + "packages_deleted": len(packages), + "tags_deleted": total_tags, + "artifacts_affected": list(artifacts_decremented.keys()), + }, + ) + + # Delete the project (cascade will delete packages, tags, etc.) 
+ db.delete(project) + db.commit() + + return None + + # Package routes -@router.get("/api/v1/project/{project_name}/packages", response_model=PaginatedResponse[PackageDetailResponse]) +@router.get( + "/api/v1/project/{project_name}/packages", + response_model=PaginatedResponse[PackageDetailResponse], +) def list_packages( project_name: str, page: int = Query(default=1, ge=1, description="Page number"), limit: int = Query(default=20, ge=1, le=100, description="Items per page"), - search: Optional[str] = Query(default=None, description="Search by name or description"), - sort: str = Query(default="name", description="Sort field (name, created_at, updated_at)"), + search: Optional[str] = Query( + default=None, description="Search by name or description" + ), + sort: str = Query( + default="name", description="Sort field (name, created_at, updated_at)" + ), order: str = Query(default="asc", description="Sort order (asc, desc)"), format: Optional[str] = Query(default=None, description="Filter by package format"), platform: Optional[str] = Query(default=None, description="Filter by platform"), @@ -264,21 +554,36 @@ def list_packages( raise HTTPException(status_code=404, detail="Project not found") # Validate sort field - valid_sort_fields = {"name": Package.name, "created_at": Package.created_at, "updated_at": Package.updated_at} + valid_sort_fields = { + "name": Package.name, + "created_at": Package.created_at, + "updated_at": Package.updated_at, + } if sort not in valid_sort_fields: - raise HTTPException(status_code=400, detail=f"Invalid sort field. Must be one of: {', '.join(valid_sort_fields.keys())}") + raise HTTPException( + status_code=400, + detail=f"Invalid sort field. Must be one of: {', '.join(valid_sort_fields.keys())}", + ) # Validate order if order not in ("asc", "desc"): - raise HTTPException(status_code=400, detail="Invalid order. Must be 'asc' or 'desc'") + raise HTTPException( + status_code=400, detail="Invalid order. Must be 'asc' or 'desc'" + ) # Validate format filter if format and format not in PACKAGE_FORMATS: - raise HTTPException(status_code=400, detail=f"Invalid format. Must be one of: {', '.join(PACKAGE_FORMATS)}") + raise HTTPException( + status_code=400, + detail=f"Invalid format. Must be one of: {', '.join(PACKAGE_FORMATS)}", + ) # Validate platform filter if platform and platform not in PACKAGE_PLATFORMS: - raise HTTPException(status_code=400, detail=f"Invalid platform. Must be one of: {', '.join(PACKAGE_PLATFORMS)}") + raise HTTPException( + status_code=400, + detail=f"Invalid platform. 
Must be one of: {', '.join(PACKAGE_PLATFORMS)}", + ) # Base query query = db.query(Package).filter(Package.project_id == project.id) @@ -289,7 +594,7 @@ def list_packages( query = query.filter( or_( func.lower(Package.name).contains(search_lower), - func.lower(Package.description).contains(search_lower) + func.lower(Package.description).contains(search_lower), ) ) @@ -322,54 +627,70 @@ def list_packages( detailed_packages = [] for pkg in packages: # Get tag count - tag_count = db.query(func.count(Tag.id)).filter(Tag.package_id == pkg.id).scalar() or 0 + tag_count = ( + db.query(func.count(Tag.id)).filter(Tag.package_id == pkg.id).scalar() or 0 + ) # Get unique artifact count and total size via uploads - artifact_stats = db.query( - func.count(func.distinct(Upload.artifact_id)), - func.coalesce(func.sum(Artifact.size), 0) - ).join(Artifact, Upload.artifact_id == Artifact.id).filter( - Upload.package_id == pkg.id - ).first() + artifact_stats = ( + db.query( + func.count(func.distinct(Upload.artifact_id)), + func.coalesce(func.sum(Artifact.size), 0), + ) + .join(Artifact, Upload.artifact_id == Artifact.id) + .filter(Upload.package_id == pkg.id) + .first() + ) artifact_count = artifact_stats[0] if artifact_stats else 0 total_size = artifact_stats[1] if artifact_stats else 0 # Get latest tag - latest_tag_obj = db.query(Tag).filter( - Tag.package_id == pkg.id - ).order_by(Tag.created_at.desc()).first() + latest_tag_obj = ( + db.query(Tag) + .filter(Tag.package_id == pkg.id) + .order_by(Tag.created_at.desc()) + .first() + ) latest_tag = latest_tag_obj.name if latest_tag_obj else None # Get latest upload timestamp - latest_upload = db.query(func.max(Upload.uploaded_at)).filter( - Upload.package_id == pkg.id - ).scalar() + latest_upload = ( + db.query(func.max(Upload.uploaded_at)) + .filter(Upload.package_id == pkg.id) + .scalar() + ) # Get recent tags (limit 5) - recent_tags_objs = db.query(Tag).filter( - Tag.package_id == pkg.id - ).order_by(Tag.created_at.desc()).limit(5).all() + recent_tags_objs = ( + db.query(Tag) + .filter(Tag.package_id == pkg.id) + .order_by(Tag.created_at.desc()) + .limit(5) + .all() + ) recent_tags = [ TagSummary(name=t.name, artifact_id=t.artifact_id, created_at=t.created_at) for t in recent_tags_objs ] - detailed_packages.append(PackageDetailResponse( - id=pkg.id, - project_id=pkg.project_id, - name=pkg.name, - description=pkg.description, - format=pkg.format, - platform=pkg.platform, - created_at=pkg.created_at, - updated_at=pkg.updated_at, - tag_count=tag_count, - artifact_count=artifact_count, - total_size=total_size, - latest_tag=latest_tag, - latest_upload_at=latest_upload, - recent_tags=recent_tags, - )) + detailed_packages.append( + PackageDetailResponse( + id=pkg.id, + project_id=pkg.project_id, + name=pkg.name, + description=pkg.description, + format=pkg.format, + platform=pkg.platform, + created_at=pkg.created_at, + updated_at=pkg.updated_at, + tag_count=tag_count, + artifact_count=artifact_count, + total_size=total_size, + latest_tag=latest_tag, + latest_upload_at=latest_upload, + recent_tags=recent_tags, + ) + ) return PaginatedResponse( items=detailed_packages, @@ -382,11 +703,16 @@ def list_packages( ) -@router.get("/api/v1/project/{project_name}/packages/{package_name}", response_model=PackageDetailResponse) +@router.get( + "/api/v1/project/{project_name}/packages/{package_name}", + response_model=PackageDetailResponse, +) def get_package( project_name: str, package_name: str, - include_tags: bool = Query(default=False, description="Include all tags (not 
just recent 5)"), + include_tags: bool = Query( + default=False, description="Include all tags (not just recent 5)" + ), db: Session = Depends(get_db), ): """Get a single package with full metadata""" @@ -394,39 +720,52 @@ def get_package( if not project: raise HTTPException(status_code=404, detail="Project not found") - pkg = db.query(Package).filter( - Package.project_id == project.id, - Package.name == package_name - ).first() + pkg = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not pkg: raise HTTPException(status_code=404, detail="Package not found") # Get tag count - tag_count = db.query(func.count(Tag.id)).filter(Tag.package_id == pkg.id).scalar() or 0 + tag_count = ( + db.query(func.count(Tag.id)).filter(Tag.package_id == pkg.id).scalar() or 0 + ) # Get unique artifact count and total size via uploads - artifact_stats = db.query( - func.count(func.distinct(Upload.artifact_id)), - func.coalesce(func.sum(Artifact.size), 0) - ).join(Artifact, Upload.artifact_id == Artifact.id).filter( - Upload.package_id == pkg.id - ).first() + artifact_stats = ( + db.query( + func.count(func.distinct(Upload.artifact_id)), + func.coalesce(func.sum(Artifact.size), 0), + ) + .join(Artifact, Upload.artifact_id == Artifact.id) + .filter(Upload.package_id == pkg.id) + .first() + ) artifact_count = artifact_stats[0] if artifact_stats else 0 total_size = artifact_stats[1] if artifact_stats else 0 # Get latest tag - latest_tag_obj = db.query(Tag).filter( - Tag.package_id == pkg.id - ).order_by(Tag.created_at.desc()).first() + latest_tag_obj = ( + db.query(Tag) + .filter(Tag.package_id == pkg.id) + .order_by(Tag.created_at.desc()) + .first() + ) latest_tag = latest_tag_obj.name if latest_tag_obj else None # Get latest upload timestamp - latest_upload = db.query(func.max(Upload.uploaded_at)).filter( - Upload.package_id == pkg.id - ).scalar() + latest_upload = ( + db.query(func.max(Upload.uploaded_at)) + .filter(Upload.package_id == pkg.id) + .scalar() + ) # Get tags (all if include_tags=true, else limit 5) - tags_query = db.query(Tag).filter(Tag.package_id == pkg.id).order_by(Tag.created_at.desc()) + tags_query = ( + db.query(Tag).filter(Tag.package_id == pkg.id).order_by(Tag.created_at.desc()) + ) if not include_tags: tags_query = tags_query.limit(5) tags_objs = tags_query.all() @@ -454,22 +793,36 @@ def get_package( @router.post("/api/v1/project/{project_name}/packages", response_model=PackageResponse) -def create_package(project_name: str, package: PackageCreate, db: Session = Depends(get_db)): +def create_package( + project_name: str, package: PackageCreate, db: Session = Depends(get_db) +): project = db.query(Project).filter(Project.name == project_name).first() if not project: raise HTTPException(status_code=404, detail="Project not found") # Validate format if package.format not in PACKAGE_FORMATS: - raise HTTPException(status_code=400, detail=f"Invalid format. Must be one of: {', '.join(PACKAGE_FORMATS)}") + raise HTTPException( + status_code=400, + detail=f"Invalid format. Must be one of: {', '.join(PACKAGE_FORMATS)}", + ) # Validate platform if package.platform not in PACKAGE_PLATFORMS: - raise HTTPException(status_code=400, detail=f"Invalid platform. Must be one of: {', '.join(PACKAGE_PLATFORMS)}") + raise HTTPException( + status_code=400, + detail=f"Invalid platform. 
Must be one of: {', '.join(PACKAGE_PLATFORMS)}", + ) - existing = db.query(Package).filter(Package.project_id == project.id, Package.name == package.name).first() + existing = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package.name) + .first() + ) if existing: - raise HTTPException(status_code=400, detail="Package already exists in this project") + raise HTTPException( + status_code=400, detail="Package already exists in this project" + ) db_package = Package( project_id=project.id, @@ -484,8 +837,76 @@ def create_package(project_name: str, package: PackageCreate, db: Session = Depe return db_package +@router.delete( + "/api/v1/project/{project_name}/packages/{package_name}", + status_code=204, +) +def delete_package( + project_name: str, + package_name: str, + request: Request, + db: Session = Depends(get_db), +): + """ + Delete a package and all its tags. + + Decrements ref_count for all artifacts referenced by tags in this package. + The package's uploads records are preserved for audit purposes but will + have null package_id after cascade. + """ + user_id = get_user_id(request) + + project = db.query(Project).filter(Project.name == project_name).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) + if not package: + raise HTTPException(status_code=404, detail="Package not found") + + # Get all tags in this package to decrement their artifact ref_counts + tags = db.query(Tag).filter(Tag.package_id == package.id).all() + + # Decrement ref_count for each artifact referenced by tags + artifacts_decremented = {} + for tag in tags: + if tag.artifact_id not in artifacts_decremented: + new_ref_count = _decrement_ref_count(db, tag.artifact_id) + artifacts_decremented[tag.artifact_id] = new_ref_count + logger.info( + f"Package '{package_name}' deletion: decremented ref_count on artifact " + f"{tag.artifact_id[:12]}... 
to {new_ref_count}" + ) + + # Audit log + _log_audit( + db, + action="delete_package", + resource=f"project/{project_name}/{package_name}", + user_id=user_id, + source_ip=request.client.host if request.client else None, + details={ + "tags_deleted": len(tags), + "artifacts_affected": list(artifacts_decremented.keys()), + }, + ) + + # Delete the package (cascade will delete tags, uploads, consumers) + db.delete(package) + db.commit() + + return None + + # Upload artifact -@router.post("/api/v1/project/{project_name}/{package_name}/upload", response_model=UploadResponse) +@router.post( + "/api/v1/project/{project_name}/{package_name}/upload", + response_model=UploadResponse, +) def upload_artifact( project_name: str, package_name: str, @@ -503,7 +924,11 @@ def upload_artifact( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") @@ -516,22 +941,51 @@ def upload_artifact( # Extract metadata file_metadata = extract_metadata( - io.BytesIO(file_content), - file.filename, - file.content_type + io.BytesIO(file_content), file.filename, file.content_type ) - # Store file (uses multipart for large files) - storage_result = storage.store(file.file, content_length) + # Store file (uses multipart for large files) with error handling + try: + storage_result = storage.store(file.file, content_length) + except HashComputationError as e: + logger.error(f"Hash computation failed during upload: {e}") + raise HTTPException( + status_code=422, + detail=f"Failed to process file: hash computation error - {str(e)}", + ) + except S3ExistenceCheckError as e: + logger.error(f"S3 existence check failed during upload: {e}") + raise HTTPException( + status_code=503, + detail="Storage service temporarily unavailable. Please retry.", + ) + except S3UploadError as e: + logger.error(f"S3 upload failed: {e}") + raise HTTPException( + status_code=503, + detail="Storage service temporarily unavailable. 
Please retry.", + ) + except StorageError as e: + logger.error(f"Storage error during upload: {e}") + raise HTTPException(status_code=500, detail="Internal storage error") # Check if this is a deduplicated upload deduplicated = False + saved_bytes = 0 # Create or update artifact record - artifact = db.query(Artifact).filter(Artifact.id == storage_result.sha256).first() + # Use with_for_update() to lock the row and prevent race conditions + artifact = ( + db.query(Artifact) + .filter(Artifact.id == storage_result.sha256) + .with_for_update() + .first() + ) if artifact: + # Artifact exists - increment ref_count (already locked) artifact.ref_count += 1 deduplicated = True + saved_bytes = storage_result.size # Merge metadata if new metadata was extracted if file_metadata and artifact.artifact_metadata: artifact.artifact_metadata = {**artifact.artifact_metadata, **file_metadata} @@ -544,6 +998,8 @@ def upload_artifact( artifact.checksum_sha1 = storage_result.sha1 if not artifact.s3_etag and storage_result.s3_etag: artifact.s3_etag = storage_result.s3_etag + # Refresh to get updated ref_count + db.refresh(artifact) else: artifact = Artifact( id=storage_result.sha256, @@ -570,20 +1026,32 @@ def upload_artifact( ) db.add(upload) - # Create tag if provided + # Create or update tag if provided (with ref_count management and history) if tag: - existing_tag = db.query(Tag).filter(Tag.package_id == package.id, Tag.name == tag).first() - if existing_tag: - existing_tag.artifact_id = storage_result.sha256 - existing_tag.created_by = user_id - else: - new_tag = Tag( - package_id=package.id, - name=tag, - artifact_id=storage_result.sha256, - created_by=user_id, - ) - db.add(new_tag) + _create_or_update_tag(db, package.id, tag, storage_result.sha256, user_id) + + # Log deduplication event + if deduplicated: + logger.info( + f"Deduplication: artifact {storage_result.sha256[:12]}... 
" + f"ref_count={artifact.ref_count}, saved_bytes={saved_bytes}" + ) + + # Audit log + _log_audit( + db, + action="upload", + resource=f"project/{project_name}/{package_name}/artifact/{storage_result.sha256[:12]}", + user_id=user_id, + source_ip=request.client.host if request.client else None, + details={ + "artifact_id": storage_result.sha256, + "size": storage_result.size, + "deduplicated": deduplicated, + "saved_bytes": saved_bytes, + "tag": tag, + }, + ) db.commit() @@ -599,11 +1067,15 @@ def upload_artifact( s3_etag=storage_result.s3_etag, format_metadata=artifact.artifact_metadata, deduplicated=deduplicated, + ref_count=artifact.ref_count, ) # Resumable upload endpoints -@router.post("/api/v1/project/{project_name}/{package_name}/upload/init", response_model=ResumableUploadInitResponse) +@router.post( + "/api/v1/project/{project_name}/{package_name}/upload/init", + response_model=ResumableUploadInitResponse, +) def init_resumable_upload( project_name: str, package_name: str, @@ -623,15 +1095,21 @@ def init_resumable_upload( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") # Check if artifact already exists (deduplication) - existing_artifact = db.query(Artifact).filter(Artifact.id == init_request.expected_hash).first() + existing_artifact = ( + db.query(Artifact).filter(Artifact.id == init_request.expected_hash).first() + ) if existing_artifact: - # File already exists - increment ref count and return immediately - existing_artifact.ref_count += 1 + # File already exists - use atomic increment for ref count + _increment_ref_count(db, existing_artifact.id) # Record the upload upload = Upload( @@ -640,25 +1118,38 @@ def init_resumable_upload( original_name=init_request.filename, uploaded_by=user_id, source_ip=request.client.host if request.client else None, + deduplicated=True, ) db.add(upload) - # Create tag if provided + # Create or update tag if provided (with ref_count management and history) if init_request.tag: - existing_tag = db.query(Tag).filter( - Tag.package_id == package.id, Tag.name == init_request.tag - ).first() - if existing_tag: - existing_tag.artifact_id = init_request.expected_hash - existing_tag.created_by = user_id - else: - new_tag = Tag( - package_id=package.id, - name=init_request.tag, - artifact_id=init_request.expected_hash, - created_by=user_id, - ) - db.add(new_tag) + _create_or_update_tag( + db, package.id, init_request.tag, init_request.expected_hash, user_id + ) + + # Log deduplication event + logger.info( + f"Deduplication (resumable init): artifact {init_request.expected_hash[:12]}... 
" + f"saved_bytes={init_request.size}" + ) + + # Audit log + _log_audit( + db, + action="upload", + resource=f"project/{project_name}/{package_name}/artifact/{init_request.expected_hash[:12]}", + user_id=user_id, + source_ip=request.client.host if request.client else None, + details={ + "artifact_id": init_request.expected_hash, + "size": init_request.size, + "deduplicated": True, + "saved_bytes": init_request.size, + "tag": init_request.tag, + "resumable": True, + }, + ) db.commit() @@ -680,7 +1171,9 @@ def init_resumable_upload( ) -@router.put("/api/v1/project/{project_name}/{package_name}/upload/{upload_id}/part/{part_number}") +@router.put( + "/api/v1/project/{project_name}/{package_name}/upload/{upload_id}/part/{part_number}" +) def upload_part( project_name: str, package_name: str, @@ -699,7 +1192,11 @@ def upload_part( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") @@ -708,6 +1205,7 @@ def upload_part( # Read part data from request body import asyncio + loop = asyncio.new_event_loop() async def read_body(): @@ -731,7 +1229,9 @@ def upload_part( raise HTTPException(status_code=404, detail=str(e)) -@router.post("/api/v1/project/{project_name}/{package_name}/upload/{upload_id}/complete") +@router.post( + "/api/v1/project/{project_name}/{package_name}/upload/{upload_id}/complete" +) def complete_resumable_upload( project_name: str, package_name: str, @@ -749,7 +1249,11 @@ def complete_resumable_upload( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") @@ -783,9 +1287,11 @@ def complete_resumable_upload( # Create tag if provided if complete_request.tag: - existing_tag = db.query(Tag).filter( - Tag.package_id == package.id, Tag.name == complete_request.tag - ).first() + existing_tag = ( + db.query(Tag) + .filter(Tag.package_id == package.id, Tag.name == complete_request.tag) + .first() + ) if existing_tag: existing_tag.artifact_id = sha256_hash existing_tag.created_by = user_id @@ -861,12 +1367,18 @@ def _resolve_artifact_ref( artifact = db.query(Artifact).filter(Artifact.id == artifact_id).first() elif ref.startswith("tag:") or ref.startswith("version:"): tag_name = ref.split(":", 1)[1] - tag = db.query(Tag).filter(Tag.package_id == package.id, Tag.name == tag_name).first() + tag = ( + db.query(Tag) + .filter(Tag.package_id == package.id, Tag.name == tag_name) + .first() + ) if tag: artifact = db.query(Artifact).filter(Artifact.id == tag.artifact_id).first() else: # Try as tag name first - tag = db.query(Tag).filter(Tag.package_id == package.id, Tag.name == ref).first() + tag = ( + db.query(Tag).filter(Tag.package_id == package.id, Tag.name == ref).first() + ) if tag: artifact = db.query(Artifact).filter(Artifact.id == tag.artifact_id).first() else: @@ -888,7 +1400,7 @@ def download_artifact( range: Optional[str] = Header(None), mode: Optional[Literal["proxy", "redirect", "presigned"]] = Query( default=None, - 
description="Download mode: proxy (stream through backend), redirect (302 to presigned URL), presigned (return JSON with URL)" + description="Download mode: proxy (stream through backend), redirect (302 to presigned URL), presigned (return JSON with URL)", ), ): settings = get_settings() @@ -898,7 +1410,11 @@ def download_artifact( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") @@ -919,7 +1435,9 @@ def download_artifact( response_content_type=artifact.content_type, response_content_disposition=f'attachment; filename="{filename}"', ) - expires_at = datetime.now(timezone.utc) + timedelta(seconds=settings.presigned_url_expiry) + expires_at = datetime.now(timezone.utc) + timedelta( + seconds=settings.presigned_url_expiry + ) return PresignedUrlResponse( url=presigned_url, @@ -945,7 +1463,9 @@ def download_artifact( # Proxy mode (default fallback) - stream through backend # Handle range requests if range: - stream, content_length, content_range = storage.get_stream(artifact.s3_key, range) + stream, content_length, content_range = storage.get_stream( + artifact.s3_key, range + ) headers = { "Content-Disposition": f'attachment; filename="{filename}"', @@ -977,7 +1497,10 @@ def download_artifact( # Get presigned URL endpoint (explicit endpoint for getting URL without redirect) -@router.get("/api/v1/project/{project_name}/{package_name}/+/{ref}/url", response_model=PresignedUrlResponse) +@router.get( + "/api/v1/project/{project_name}/{package_name}/+/{ref}/url", + response_model=PresignedUrlResponse, +) def get_artifact_url( project_name: str, package_name: str, @@ -986,7 +1509,7 @@ def get_artifact_url( storage: S3Storage = Depends(get_storage), expiry: Optional[int] = Query( default=None, - description="Custom expiry time in seconds (defaults to server setting)" + description="Custom expiry time in seconds (defaults to server setting)", ), ): """ @@ -1000,7 +1523,11 @@ def get_artifact_url( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") @@ -1047,7 +1574,11 @@ def head_artifact( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") @@ -1081,11 +1612,16 @@ def download_artifact_compat( storage: S3Storage = Depends(get_storage), range: Optional[str] = Header(None), ): - return download_artifact(project_name, package_name, ref, request, db, storage, range) + return download_artifact( + project_name, package_name, ref, request, db, storage, range + ) # Tag routes -@router.get("/api/v1/project/{project_name}/{package_name}/tags", response_model=PaginatedResponse[TagDetailResponse]) 
+@router.get( + "/api/v1/project/{project_name}/{package_name}/tags", + response_model=PaginatedResponse[TagDetailResponse], +) def list_tags( project_name: str, package_name: str, @@ -1100,21 +1636,34 @@ def list_tags( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") # Validate sort field valid_sort_fields = {"name": Tag.name, "created_at": Tag.created_at} if sort not in valid_sort_fields: - raise HTTPException(status_code=400, detail=f"Invalid sort field. Must be one of: {', '.join(valid_sort_fields.keys())}") + raise HTTPException( + status_code=400, + detail=f"Invalid sort field. Must be one of: {', '.join(valid_sort_fields.keys())}", + ) # Validate order if order not in ("asc", "desc"): - raise HTTPException(status_code=400, detail="Invalid order. Must be 'asc' or 'desc'") + raise HTTPException( + status_code=400, detail="Invalid order. Must be 'asc' or 'desc'" + ) # Base query with JOIN to artifact for metadata - query = db.query(Tag, Artifact).join(Artifact, Tag.artifact_id == Artifact.id).filter(Tag.package_id == package.id) + query = ( + db.query(Tag, Artifact) + .join(Artifact, Tag.artifact_id == Artifact.id) + .filter(Tag.package_id == package.id) + ) # Apply search filter (case-insensitive on tag name OR artifact original filename) if search: @@ -1122,7 +1671,7 @@ def list_tags( query = query.filter( or_( func.lower(Tag.name).contains(search_lower), - func.lower(Artifact.original_name).contains(search_lower) + func.lower(Artifact.original_name).contains(search_lower), ) ) @@ -1146,19 +1695,21 @@ def list_tags( # Build detailed responses with artifact metadata detailed_tags = [] for tag, artifact in results: - detailed_tags.append(TagDetailResponse( - id=tag.id, - package_id=tag.package_id, - name=tag.name, - artifact_id=tag.artifact_id, - created_at=tag.created_at, - created_by=tag.created_by, - artifact_size=artifact.size, - artifact_content_type=artifact.content_type, - artifact_original_name=artifact.original_name, - artifact_created_at=artifact.created_at, - artifact_format_metadata=artifact.format_metadata, - )) + detailed_tags.append( + TagDetailResponse( + id=tag.id, + package_id=tag.package_id, + name=tag.name, + artifact_id=tag.artifact_id, + created_at=tag.created_at, + created_by=tag.created_by, + artifact_size=artifact.size, + artifact_content_type=artifact.content_type, + artifact_original_name=artifact.original_name, + artifact_created_at=artifact.created_at, + artifact_format_metadata=artifact.format_metadata, + ) + ) return PaginatedResponse( items=detailed_tags, @@ -1171,7 +1722,9 @@ def list_tags( ) -@router.post("/api/v1/project/{project_name}/{package_name}/tags", response_model=TagResponse) +@router.post( + "/api/v1/project/{project_name}/{package_name}/tags", response_model=TagResponse +) def create_tag( project_name: str, package_name: str, @@ -1185,7 +1738,11 @@ def create_tag( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise 
HTTPException(status_code=404, detail="Package not found") @@ -1195,7 +1752,9 @@ def create_tag( raise HTTPException(status_code=404, detail="Artifact not found") # Create or update tag - existing = db.query(Tag).filter(Tag.package_id == package.id, Tag.name == tag.name).first() + existing = ( + db.query(Tag).filter(Tag.package_id == package.id, Tag.name == tag.name).first() + ) if existing: existing.artifact_id = tag.artifact_id existing.created_by = user_id @@ -1215,7 +1774,10 @@ def create_tag( return db_tag -@router.get("/api/v1/project/{project_name}/{package_name}/tags/{tag_name}", response_model=TagDetailResponse) +@router.get( + "/api/v1/project/{project_name}/{package_name}/tags/{tag_name}", + response_model=TagDetailResponse, +) def get_tag( project_name: str, package_name: str, @@ -1227,14 +1789,20 @@ def get_tag( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") - result = db.query(Tag, Artifact).join(Artifact, Tag.artifact_id == Artifact.id).filter( - Tag.package_id == package.id, - Tag.name == tag_name - ).first() + result = ( + db.query(Tag, Artifact) + .join(Artifact, Tag.artifact_id == Artifact.id) + .filter(Tag.package_id == package.id, Tag.name == tag_name) + .first() + ) if not result: raise HTTPException(status_code=404, detail="Tag not found") @@ -1255,7 +1823,10 @@ def get_tag( ) -@router.get("/api/v1/project/{project_name}/{package_name}/tags/{tag_name}/history", response_model=List[TagHistoryResponse]) +@router.get( + "/api/v1/project/{project_name}/{package_name}/tags/{tag_name}/history", + response_model=List[TagHistoryResponse], +) def get_tag_history( project_name: str, package_name: str, @@ -1267,43 +1838,151 @@ def get_tag_history( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") - tag = db.query(Tag).filter(Tag.package_id == package.id, Tag.name == tag_name).first() + tag = ( + db.query(Tag).filter(Tag.package_id == package.id, Tag.name == tag_name).first() + ) if not tag: raise HTTPException(status_code=404, detail="Tag not found") - history = db.query(TagHistory).filter(TagHistory.tag_id == tag.id).order_by(TagHistory.changed_at.desc()).all() + history = ( + db.query(TagHistory) + .filter(TagHistory.tag_id == tag.id) + .order_by(TagHistory.changed_at.desc()) + .all() + ) return history +@router.delete( + "/api/v1/project/{project_name}/{package_name}/tags/{tag_name}", + status_code=204, +) +def delete_tag( + project_name: str, + package_name: str, + tag_name: str, + request: Request, + db: Session = Depends(get_db), +): + """ + Delete a tag and decrement the artifact's ref_count. + + Records the deletion in tag history before removing the tag. 
+ """ + user_id = get_user_id(request) + + project = db.query(Project).filter(Project.name == project_name).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) + if not package: + raise HTTPException(status_code=404, detail="Package not found") + + tag = ( + db.query(Tag).filter(Tag.package_id == package.id, Tag.name == tag_name).first() + ) + if not tag: + raise HTTPException(status_code=404, detail="Tag not found") + + artifact_id = tag.artifact_id + + # Record deletion in history + history = TagHistory( + tag_id=tag.id, + old_artifact_id=artifact_id, + new_artifact_id=artifact_id, # Same artifact for delete record + change_type="delete", + changed_by=user_id, + ) + db.add(history) + db.flush() # Flush history before deleting tag (cascade will delete history) + + # Decrement ref_count on artifact + new_ref_count = _decrement_ref_count(db, artifact_id) + logger.info( + f"Tag '{tag_name}' deleted: decremented ref_count on artifact " + f"{artifact_id[:12]}... to {new_ref_count}" + ) + + # Audit log + _log_audit( + db, + action="delete_tag", + resource=f"project/{project_name}/{package_name}/tag/{tag_name}", + user_id=user_id, + source_ip=request.client.host if request.client else None, + details={ + "artifact_id": artifact_id, + "new_ref_count": new_ref_count, + }, + ) + + # Delete the tag + db.delete(tag) + db.commit() + + return None + + # Consumer routes -@router.get("/api/v1/project/{project_name}/{package_name}/consumers", response_model=List[ConsumerResponse]) +@router.get( + "/api/v1/project/{project_name}/{package_name}/consumers", + response_model=List[ConsumerResponse], +) def get_consumers(project_name: str, package_name: str, db: Session = Depends(get_db)): project = db.query(Project).filter(Project.name == project_name).first() if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") - consumers = db.query(Consumer).filter(Consumer.package_id == package.id).order_by(Consumer.last_access.desc()).all() + consumers = ( + db.query(Consumer) + .filter(Consumer.package_id == package.id) + .order_by(Consumer.last_access.desc()) + .all() + ) return consumers # Package artifacts -@router.get("/api/v1/project/{project_name}/{package_name}/artifacts", response_model=PaginatedResponse[PackageArtifactResponse]) +@router.get( + "/api/v1/project/{project_name}/{package_name}/artifacts", + response_model=PaginatedResponse[PackageArtifactResponse], +) def list_package_artifacts( project_name: str, package_name: str, page: int = Query(default=1, ge=1, description="Page number"), limit: int = Query(default=20, ge=1, le=100, description="Items per page"), - content_type: Optional[str] = Query(default=None, description="Filter by content type"), - created_after: Optional[datetime] = Query(default=None, description="Filter artifacts created after this date"), - created_before: Optional[datetime] = Query(default=None, description="Filter artifacts created before this date"), + content_type: Optional[str] = Query( + default=None, description="Filter by content type" + ), + created_after: Optional[datetime] = 
Query( + default=None, description="Filter artifacts created after this date" + ), + created_before: Optional[datetime] = Query( + default=None, description="Filter artifacts created before this date" + ), db: Session = Depends(get_db), ): """List all unique artifacts uploaded to a package""" @@ -1311,14 +1990,20 @@ def list_package_artifacts( if not project: raise HTTPException(status_code=404, detail="Project not found") - package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first() + package = ( + db.query(Package) + .filter(Package.project_id == project.id, Package.name == package_name) + .first() + ) if not package: raise HTTPException(status_code=404, detail="Package not found") # Get distinct artifacts uploaded to this package via uploads table - artifact_ids_subquery = db.query(func.distinct(Upload.artifact_id)).filter( - Upload.package_id == package.id - ).subquery() + artifact_ids_subquery = ( + db.query(func.distinct(Upload.artifact_id)) + .filter(Upload.package_id == package.id) + .subquery() + ) query = db.query(Artifact).filter(Artifact.id.in_(artifact_ids_subquery)) @@ -1337,7 +2022,9 @@ def list_package_artifacts( # Apply pagination offset = (page - 1) * limit - artifacts = query.order_by(Artifact.created_at.desc()).offset(offset).limit(limit).all() + artifacts = ( + query.order_by(Artifact.created_at.desc()).offset(offset).limit(limit).all() + ) # Calculate total pages total_pages = math.ceil(total / limit) if total > 0 else 1 @@ -1346,22 +2033,25 @@ def list_package_artifacts( artifact_responses = [] for artifact in artifacts: # Get tags pointing to this artifact in this package - tags = db.query(Tag.name).filter( - Tag.package_id == package.id, - Tag.artifact_id == artifact.id - ).all() + tags = ( + db.query(Tag.name) + .filter(Tag.package_id == package.id, Tag.artifact_id == artifact.id) + .all() + ) tag_names = [t.name for t in tags] - artifact_responses.append(PackageArtifactResponse( - id=artifact.id, - size=artifact.size, - content_type=artifact.content_type, - original_name=artifact.original_name, - created_at=artifact.created_at, - created_by=artifact.created_by, - format_metadata=artifact.format_metadata, - tags=tag_names, - )) + artifact_responses.append( + PackageArtifactResponse( + id=artifact.id, + size=artifact.size, + content_type=artifact.content_type, + original_name=artifact.original_name, + created_at=artifact.created_at, + created_by=artifact.created_by, + format_metadata=artifact.format_metadata, + tags=tag_names, + ) + ) return PaginatedResponse( items=artifact_responses, @@ -1383,13 +2073,13 @@ def get_artifact(artifact_id: str, db: Session = Depends(get_db)): raise HTTPException(status_code=404, detail="Artifact not found") # Get all tags referencing this artifact with package and project info - tags_with_context = db.query(Tag, Package, Project).join( - Package, Tag.package_id == Package.id - ).join( - Project, Package.project_id == Project.id - ).filter( - Tag.artifact_id == artifact_id - ).all() + tags_with_context = ( + db.query(Tag, Package, Project) + .join(Package, Tag.package_id == Package.id) + .join(Project, Package.project_id == Project.id) + .filter(Tag.artifact_id == artifact_id) + .all() + ) tag_infos = [ ArtifactTagInfo( diff --git a/backend/app/schemas.py b/backend/app/schemas.py index dcc7470..8f54d4e 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -40,8 +40,28 @@ class ProjectResponse(BaseModel): # Package format and platform enums -PACKAGE_FORMATS = ["generic", 
"npm", "pypi", "docker", "deb", "rpm", "maven", "nuget", "helm"] -PACKAGE_PLATFORMS = ["any", "linux", "darwin", "windows", "linux-amd64", "linux-arm64", "darwin-amd64", "darwin-arm64", "windows-amd64"] +PACKAGE_FORMATS = [ + "generic", + "npm", + "pypi", + "docker", + "deb", + "rpm", + "maven", + "nuget", + "helm", +] +PACKAGE_PLATFORMS = [ + "any", + "linux", + "darwin", + "windows", + "linux-amd64", + "linux-arm64", + "darwin-amd64", + "darwin-arm64", + "windows-amd64", +] # Package schemas @@ -68,6 +88,7 @@ class PackageResponse(BaseModel): class TagSummary(BaseModel): """Lightweight tag info for embedding in package responses""" + name: str artifact_id: str created_at: datetime @@ -75,6 +96,7 @@ class TagSummary(BaseModel): class PackageDetailResponse(BaseModel): """Package with aggregated metadata""" + id: UUID project_id: UUID name: str @@ -135,6 +157,7 @@ class TagResponse(BaseModel): class TagDetailResponse(BaseModel): """Tag with embedded artifact metadata""" + id: UUID package_id: UUID name: str @@ -154,6 +177,7 @@ class TagDetailResponse(BaseModel): class TagHistoryResponse(BaseModel): """History entry for tag changes""" + id: UUID tag_id: UUID old_artifact_id: Optional[str] @@ -167,6 +191,7 @@ class TagHistoryResponse(BaseModel): class ArtifactTagInfo(BaseModel): """Tag info for embedding in artifact responses""" + id: UUID name: str package_id: UUID @@ -176,6 +201,7 @@ class ArtifactTagInfo(BaseModel): class ArtifactDetailResponse(BaseModel): """Artifact with list of tags/packages referencing it""" + id: str sha256: str # Explicit SHA256 field (same as id) size: int @@ -196,6 +222,7 @@ class ArtifactDetailResponse(BaseModel): class PackageArtifactResponse(BaseModel): """Artifact with tags for package artifact listing""" + id: str sha256: str # Explicit SHA256 field (same as id) size: int @@ -226,11 +253,13 @@ class UploadResponse(BaseModel): s3_etag: Optional[str] = None format_metadata: Optional[Dict[str, Any]] = None deduplicated: bool = False + ref_count: int = 1 # Current reference count after this upload # Resumable upload schemas class ResumableUploadInitRequest(BaseModel): """Request to initiate a resumable upload""" + expected_hash: str # SHA256 hash of the file (client must compute) filename: str content_type: Optional[str] = None @@ -240,6 +269,7 @@ class ResumableUploadInitRequest(BaseModel): class ResumableUploadInitResponse(BaseModel): """Response from initiating a resumable upload""" + upload_id: Optional[str] # None if file already exists already_exists: bool artifact_id: Optional[str] = None # Set if already_exists is True @@ -248,17 +278,20 @@ class ResumableUploadInitResponse(BaseModel): class ResumableUploadPartResponse(BaseModel): """Response from uploading a part""" + part_number: int etag: str class ResumableUploadCompleteRequest(BaseModel): """Request to complete a resumable upload""" + tag: Optional[str] = None class ResumableUploadCompleteResponse(BaseModel): """Response from completing a resumable upload""" + artifact_id: str size: int project: str @@ -268,6 +301,7 @@ class ResumableUploadCompleteResponse(BaseModel): class ResumableUploadStatusResponse(BaseModel): """Status of a resumable upload""" + upload_id: str uploaded_parts: List[int] total_uploaded_bytes: int @@ -288,6 +322,7 @@ class ConsumerResponse(BaseModel): # Global search schemas class SearchResultProject(BaseModel): """Project result for global search""" + id: UUID name: str description: Optional[str] @@ -299,6 +334,7 @@ class SearchResultProject(BaseModel): class 
     """Package result for global search"""
+
     id: UUID
     project_id: UUID
     project_name: str
@@ -312,6 +348,7 @@ class SearchResultPackage(BaseModel):
 
 class SearchResultArtifact(BaseModel):
     """Artifact/tag result for global search"""
+
     tag_id: UUID
     tag_name: str
     artifact_id: str
@@ -323,6 +360,7 @@ class SearchResultArtifact(BaseModel):
 
 class GlobalSearchResponse(BaseModel):
     """Combined search results across all entity types"""
+
     query: str
     projects: List[SearchResultProject]
     packages: List[SearchResultPackage]
@@ -333,6 +371,7 @@ class GlobalSearchResponse(BaseModel):
 # Presigned URL response
 class PresignedUrlResponse(BaseModel):
     """Response containing a presigned URL for direct S3 download"""
+
     url: str
     expires_at: datetime
     method: str = "GET"
diff --git a/backend/app/storage.py b/backend/app/storage.py
index ef0c510..9c33aca 100644
--- a/backend/app/storage.py
+++ b/backend/app/storage.py
@@ -16,10 +16,37 @@
 MULTIPART_THRESHOLD = 100 * 1024 * 1024
 MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
 # Chunk size for streaming hash computation
 HASH_CHUNK_SIZE = 8 * 1024 * 1024
+# Maximum retries for S3 existence check
+MAX_EXISTENCE_CHECK_RETRIES = 3
+
+
+class StorageError(Exception):
+    """Base exception for storage operations"""
+
+    pass
+
+
+class HashComputationError(StorageError):
+    """Raised when hash computation fails"""
+
+    pass
+
+
+class S3ExistenceCheckError(StorageError):
+    """Raised when S3 existence check fails after retries"""
+
+    pass
+
+
+class S3UploadError(StorageError):
+    """Raised when S3 upload fails"""
+
+    pass
 
 
 class StorageResult(NamedTuple):
     """Result of storing a file with all computed checksums"""
+
     sha256: str
     size: int
     s3_key: str
@@ -30,7 +57,9 @@
 
 class S3Storage:
     def __init__(self):
-        config = Config(s3={"addressing_style": "path"} if settings.s3_use_path_style else {})
+        config = Config(
+            s3={"addressing_style": "path"} if settings.s3_use_path_style else {}
+        )
 
         self.client = boto3.client(
             "s3",
@@ -44,7 +73,9 @@
         # Store active multipart uploads for resumable support
         self._active_uploads: Dict[str, Dict[str, Any]] = {}
 
-    def store(self, file: BinaryIO, content_length: Optional[int] = None) -> StorageResult:
+    def store(
+        self, file: BinaryIO, content_length: Optional[int] = None
+    ) -> StorageResult:
         """
         Store a file and return StorageResult with all checksums.
         Content-addressable: if the file already exists, just return the hash.
@@ -57,25 +88,54 @@
             return self._store_multipart(file, content_length)
 
     def _store_simple(self, file: BinaryIO) -> StorageResult:
-        """Store a small file using simple put_object"""
-        # Read file and compute all hashes
-        content = file.read()
-        sha256_hash = hashlib.sha256(content).hexdigest()
-        md5_hash = hashlib.md5(content).hexdigest()
-        sha1_hash = hashlib.sha1(content).hexdigest()
-        size = len(content)
+        """
+        Store a small file using simple put_object.
 
-        # Check if already exists
+        Raises:
+            HashComputationError: If hash computation fails
+            S3ExistenceCheckError: If S3 existence check fails after retries
+            S3UploadError: If S3 upload fails
+        """
+        # Read file and compute all hashes with error handling
+        try:
+            content = file.read()
+            if not content:
+                raise HashComputationError("Empty file content")
+
+            sha256_hash = hashlib.sha256(content).hexdigest()
+            md5_hash = hashlib.md5(content).hexdigest()
+            sha1_hash = hashlib.sha1(content).hexdigest()
+            size = len(content)
+        except HashComputationError:
+            raise
+        except Exception as e:
+            logger.error(f"Hash computation failed: {e}")
+            raise HashComputationError(f"Failed to compute hash: {e}") from e
+
+        # Check if already exists (with retry logic)
         s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"
         s3_etag = None
-        if not self._exists(s3_key):
-            response = self.client.put_object(
-                Bucket=self.bucket,
-                Key=s3_key,
-                Body=content,
-            )
-            s3_etag = response.get("ETag", "").strip('"')
+        try:
+            exists = self._exists(s3_key)
+        except S3ExistenceCheckError:
+            # Re-raise the specific error
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error during S3 existence check: {e}")
+            raise S3ExistenceCheckError(f"Failed to check S3 existence: {e}") from e
+
+        if not exists:
+            try:
+                response = self.client.put_object(
+                    Bucket=self.bucket,
+                    Key=s3_key,
+                    Body=content,
+                )
+                s3_etag = response.get("ETag", "").strip('"')
+            except ClientError as e:
+                logger.error(f"S3 upload failed: {e}")
+                raise S3UploadError(f"Failed to upload to S3: {e}") from e
         else:
            # Get existing ETag
             obj_info = self.get_object_info(s3_key)
@@ -92,30 +152,55 @@
         )
 
     def _store_multipart(self, file: BinaryIO, content_length: int) -> StorageResult:
-        """Store a large file using S3 multipart upload with streaming hash computation"""
+        """
+        Store a large file using S3 multipart upload with streaming hash computation.
+
+        Raises:
+            HashComputationError: If hash computation fails
+            S3ExistenceCheckError: If S3 existence check fails after retries
+            S3UploadError: If S3 upload fails
+        """
         # First pass: compute all hashes by streaming through file
-        sha256_hasher = hashlib.sha256()
-        md5_hasher = hashlib.md5()
-        sha1_hasher = hashlib.sha1()
-        size = 0
+        try:
+            sha256_hasher = hashlib.sha256()
+            md5_hasher = hashlib.md5()
+            sha1_hasher = hashlib.sha1()
+            size = 0
 
-        # Read file in chunks to compute hashes
-        while True:
-            chunk = file.read(HASH_CHUNK_SIZE)
-            if not chunk:
-                break
-            sha256_hasher.update(chunk)
-            md5_hasher.update(chunk)
-            sha1_hasher.update(chunk)
-            size += len(chunk)
+            # Read file in chunks to compute hashes
+            while True:
+                chunk = file.read(HASH_CHUNK_SIZE)
+                if not chunk:
+                    break
+                sha256_hasher.update(chunk)
+                md5_hasher.update(chunk)
+                sha1_hasher.update(chunk)
+                size += len(chunk)
+
+            if size == 0:
+                raise HashComputationError("Empty file content")
+
+            sha256_hash = sha256_hasher.hexdigest()
+            md5_hash = md5_hasher.hexdigest()
+            sha1_hash = sha1_hasher.hexdigest()
+        except HashComputationError:
+            raise
+        except Exception as e:
+            logger.error(f"Hash computation failed for multipart upload: {e}")
+            raise HashComputationError(f"Failed to compute hash: {e}") from e
 
-        sha256_hash = sha256_hasher.hexdigest()
-        md5_hash = md5_hasher.hexdigest()
-        sha1_hash = sha1_hasher.hexdigest()
         s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"
 
-        # Check if already exists (deduplication)
-        if self._exists(s3_key):
+        # Check if already exists (deduplication) with retry logic
+        try:
+            exists = self._exists(s3_key)
+        except S3ExistenceCheckError:
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error during S3 existence check: {e}")
+            raise S3ExistenceCheckError(f"Failed to check S3 existence: {e}") from e
+
+        if exists:
             obj_info = self.get_object_info(s3_key)
             s3_etag = obj_info.get("etag", "").strip('"') if obj_info else None
             return StorageResult(
@@ -150,10 +235,12 @@
                 PartNumber=part_number,
                 Body=chunk,
             )
-            parts.append({
-                "PartNumber": part_number,
-                "ETag": response["ETag"],
-            })
+            parts.append(
+                {
+                    "PartNumber": part_number,
+                    "ETag": response["ETag"],
+                }
+            )
             part_number += 1
 
         # Complete multipart upload
@@ -226,7 +313,9 @@
         # Upload based on size
         if size < MULTIPART_THRESHOLD:
             content = b"".join(all_chunks)
-            response = self.client.put_object(Bucket=self.bucket, Key=s3_key, Body=content)
+            response = self.client.put_object(
+                Bucket=self.bucket, Key=s3_key, Body=content
+            )
             s3_etag = response.get("ETag", "").strip('"')
         else:
             # Use multipart for large files
@@ -251,10 +340,12 @@
                     PartNumber=part_number,
                     Body=part_data,
                 )
-                parts.append({
-                    "PartNumber": part_number,
-                    "ETag": response["ETag"],
-                })
+                parts.append(
+                    {
+                        "PartNumber": part_number,
+                        "ETag": response["ETag"],
+                    }
+                )
                 part_number += 1
 
             # Upload remaining buffer
@@ -266,10 +357,12 @@
                     PartNumber=part_number,
                     Body=buffer,
                 )
-                parts.append({
-                    "PartNumber": part_number,
-                    "ETag": response["ETag"],
-                })
+                parts.append(
+                    {
+                        "PartNumber": part_number,
+                        "ETag": response["ETag"],
+                    }
+                )
 
             complete_response = self.client.complete_multipart_upload(
                 Bucket=self.bucket,
@@ -326,7 +419,9 @@
         self._active_uploads[upload_id] = session
         return session
 
-    def upload_part(self, upload_id: str, part_number: int, data: bytes) -> Dict[str, Any]:
+    def upload_part(
+        self, upload_id: str, part_number: int, data: bytes
+    ) -> Dict[str, Any]:
         """
         Upload a part for a resumable upload.
         Returns part info including ETag.
@@ -434,13 +529,50 @@
         except ClientError:
             return None
 
-    def _exists(self, s3_key: str) -> bool:
-        """Check if an object exists"""
-        try:
-            self.client.head_object(Bucket=self.bucket, Key=s3_key)
-            return True
-        except ClientError:
-            return False
+    def _exists(self, s3_key: str, retry: bool = True) -> bool:
+        """
+        Check if an object exists with optional retry logic.
+
+        Args:
+            s3_key: The S3 key to check
+            retry: Whether to retry on transient failures (default: True)
+
+        Returns:
+            True if object exists, False otherwise
+
+        Raises:
+            S3ExistenceCheckError: If all retries fail due to non-404 errors
+        """
+        import time
+
+        max_retries = MAX_EXISTENCE_CHECK_RETRIES if retry else 1
+        last_error = None
+
+        for attempt in range(max_retries):
+            try:
+                self.client.head_object(Bucket=self.bucket, Key=s3_key)
+                return True
+            except ClientError as e:
+                error_code = e.response.get("Error", {}).get("Code", "")
+                # 404 means object doesn't exist - not an error
+                if error_code in ("404", "NoSuchKey"):
+                    return False
+
+                # For other errors, retry
+                last_error = e
+                if attempt < max_retries - 1:
+                    logger.warning(
+                        f"S3 existence check failed (attempt {attempt + 1}/{max_retries}): {e}"
+                    )
+                    time.sleep(0.1 * (attempt + 1))  # Linear backoff between attempts
+
+        # All retries failed
+        logger.error(
+            f"S3 existence check failed after {max_retries} attempts: {last_error}"
+        )
+        raise S3ExistenceCheckError(
+            f"Failed to check S3 object existence after {max_retries} attempts: {last_error}"
+        )
 
     def delete(self, s3_key: str) -> bool:
         """Delete an object"""