"""
Service for artifact reference counting and cleanup.
"""

import logging
from typing import List, Optional, Tuple

from sqlalchemy import func
from sqlalchemy.orm import Session

from ..models import Artifact, Tag, Upload, Package
from ..repositories.artifact import ArtifactRepository
from ..repositories.tag import TagRepository
from ..storage import S3Storage

logger = logging.getLogger(__name__)


class ArtifactCleanupService:
    """
    Service for managing artifact reference counts and cleaning up orphaned artifacts.

    Reference counting rules:
    - ref_count starts at 1 when artifact is first uploaded
    - ref_count increments when the same artifact is uploaded again (deduplication)
    - ref_count decrements when a tag is deleted or updated to point elsewhere
    - ref_count decrements when a package is deleted (for each tag pointing to artifact)
    - When ref_count reaches 0, artifact is a candidate for deletion from S3
    """

    def __init__(self, db: Session, storage: Optional[S3Storage] = None):
        """
        Bind the service to a database session and an optional storage backend.

        Args:
            db: Session used for queries, repository access and commits.
            storage: Storage backend used to delete artifact objects. When
                omitted, cleanup only removes database rows.
        """
        self.db = db
        self.storage = storage
        self.artifact_repo = ArtifactRepository(db)
        self.tag_repo = TagRepository(db)

    def on_tag_deleted(self, artifact_id: str) -> Optional[Artifact]:
        """
        Called when a tag is deleted.

        Decrements ref_count for the artifact the tag was pointing to.

        Args:
            artifact_id: Identifier passed to ``get_by_sha256`` to look up the
                artifact.

        Returns:
            The artifact as returned by the repository after the decrement, or
            None when no artifact was found for ``artifact_id``.
        """
        artifact = self.artifact_repo.get_by_sha256(artifact_id)
        if artifact:
            artifact = self.artifact_repo.decrement_ref_count(artifact)
            logger.info(
                "Decremented ref_count for artifact %s: now %s",
                artifact_id,
                artifact.ref_count,
            )
        return artifact

    def on_tag_updated(
        self, old_artifact_id: str, new_artifact_id: str
    ) -> Tuple[Optional[Artifact], Optional[Artifact]]:
        """
        Called when a tag is updated to point to a different artifact.

        Decrements ref_count for old artifact, increments for new (if different).

        Args:
            old_artifact_id: Identifier of the artifact the tag pointed to.
            new_artifact_id: Identifier of the artifact the tag points to now.

        Returns:
            (old_artifact, new_artifact) tuple. Both entries are None when the
            two identifiers are equal; an entry is also None when its artifact
            could not be found.
        """
        # Re-pointing a tag at the same artifact changes no reference counts.
        if old_artifact_id == new_artifact_id:
            return None, None

        # Decrement old artifact ref_count
        old_artifact = self.artifact_repo.get_by_sha256(old_artifact_id)
        if old_artifact:
            old_artifact = self.artifact_repo.decrement_ref_count(old_artifact)
            logger.info(
                "Decremented ref_count for old artifact %s: now %s",
                old_artifact_id,
                old_artifact.ref_count,
            )

        # Increment new artifact ref_count
        new_artifact = self.artifact_repo.get_by_sha256(new_artifact_id)
        if new_artifact:
            new_artifact = self.artifact_repo.increment_ref_count(new_artifact)
            logger.info(
                "Incremented ref_count for new artifact %s: now %s",
                new_artifact_id,
                new_artifact.ref_count,
            )

        return old_artifact, new_artifact

    def on_package_deleted(self, package_id) -> List[str]:
        """
        Called when a package is deleted.

        Decrements ref_count for all artifacts that had tags in the package.
        The tags are queried here, so this must run while they still exist.

        Args:
            package_id: Value matched against ``Tag.package_id``.

        Returns:
            List of artifact IDs that were affected, one entry per tag whose
            artifact was found (an artifact referenced by several tags appears
            several times).
        """
        # Get all tags in the package before deletion
        tags = self.db.query(Tag).filter(Tag.package_id == package_id).all()

        affected_artifacts = []
        for tag in tags:
            artifact = self.artifact_repo.get_by_sha256(tag.artifact_id)
            if not artifact:
                continue
            self.artifact_repo.decrement_ref_count(artifact)
            affected_artifacts.append(tag.artifact_id)
            logger.info(
                "Decremented ref_count for artifact %s (package delete)",
                tag.artifact_id,
            )

        return affected_artifacts

    def cleanup_orphaned_artifacts(
        self, batch_size: int = 100, dry_run: bool = False
    ) -> List[str]:
        """
        Find and delete artifacts with ref_count = 0.

        Each artifact is removed from storage first (when a storage backend was
        supplied) and then from the database. A failure on one artifact is
        logged and does not stop the rest of the batch. The session is
        committed once at the end if at least one artifact was deleted.

        Args:
            batch_size: Maximum number of artifacts to process
            dry_run: If True, only report what would be deleted without actually deleting

        Returns:
            List of artifact IDs that were (or would be) deleted
        """
        orphaned = self.artifact_repo.get_orphaned_artifacts(limit=batch_size)

        deleted_ids = []
        for artifact in orphaned:
            # Read identifiers up front so the result list and the log
            # messages do not depend on the instance's state after deletion.
            artifact_id = artifact.id

            if dry_run:
                logger.info(
                    "[DRY RUN] Would delete orphaned artifact: %s", artifact_id
                )
                deleted_ids.append(artifact_id)
                continue

            try:
                # Delete from S3 first
                if self.storage:
                    s3_key = artifact.s3_key
                    self.storage.delete(s3_key)
                    logger.info("Deleted artifact from S3: %s", s3_key)

                # Then delete from database
                self.artifact_repo.delete(artifact)
                deleted_ids.append(artifact_id)
                logger.info(
                    "Deleted orphaned artifact from database: %s", artifact_id
                )
            except Exception as e:
                # Best-effort batch: record the failure (with traceback) and
                # carry on with the remaining artifacts.
                logger.exception(
                    "Failed to delete artifact %s: %s", artifact_id, e
                )

        if not dry_run and deleted_ids:
            self.db.commit()

        return deleted_ids

    def get_orphaned_count(self) -> int:
        """Get count of artifacts with ref_count = 0."""
        orphaned_count = (
            self.db.query(func.count(Artifact.id))
            .filter(Artifact.ref_count == 0)
            .scalar()
        )
        # Normalise a None scalar result to 0.
        return orphaned_count or 0

    def verify_ref_counts(self, fix: bool = False) -> List[dict]:
        """
        Verify that ref_counts match actual tag references.

        Only under-counts are reported: an artifact is a mismatch when its
        stored ref_count is lower than the number of tags pointing at it. A
        ref_count higher than the tag count is not flagged.

        Args:
            fix: If True, fix any mismatched ref_counts

        Returns:
            List of artifacts with mismatched ref_counts. Each entry is a dict
            with keys "artifact_id", "stored_ref_count" and "actual_tag_count".
        """
        # Get actual tag counts per artifact
        tag_counts = (
            self.db.query(Tag.artifact_id, func.count(Tag.id).label("tag_count"))
            .group_by(Tag.artifact_id)
            .all()
        )
        tag_count_map = {artifact_id: count for artifact_id, count in tag_counts}

        # Check all artifacts
        artifacts = self.db.query(Artifact).all()
        mismatches = []

        for artifact in artifacts:
            actual_count = tag_count_map.get(artifact.id, 0)
            # ref_count should be at least 1 (initial upload) + additional uploads
            # But tags are the primary reference, so we check against tag count
            if artifact.ref_count >= actual_count:
                continue

            mismatch = {
                "artifact_id": artifact.id,
                "stored_ref_count": artifact.ref_count,
                "actual_tag_count": actual_count,
            }
            mismatches.append(mismatch)

            if fix:
                # Never repair to a value below 1, even when no tags exist
                # (e.g. a negative stored count with zero tags).
                artifact.ref_count = max(actual_count, 1)
                logger.warning(
                    "Fixed ref_count for artifact %s: %s -> %s",
                    artifact.id,
                    mismatch["stored_ref_count"],
                    artifact.ref_count,
                )

        if fix and mismatches:
            self.db.commit()

        return mismatches