"""
Database utilities for optimized artifact operations.

Provides batch operations to eliminate N+1 queries.
"""

import logging
from typing import Optional

from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session

from .models import Artifact, ArtifactDependency, CachedUrl

logger = logging.getLogger(__name__)


class ArtifactRepository:
    """
    Optimized database operations for artifact storage.

    Key optimizations:
    - Atomic upserts using ON CONFLICT
    - Batch inserts for dependencies
    - Joined queries to avoid N+1

    None of the methods commit or flush explicitly; transaction control
    stays with whoever owns the injected Session.
    """

    def __init__(self, db: Session):
        self.db = db

    @staticmethod
    def _format_dependency_values(
        artifact_id: str,
        dependencies: list[tuple[str, str, str]],
    ) -> list[dict]:
        """
        Format dependencies for batch insert.

        Args:
            artifact_id: SHA256 of the artifact
            dependencies: List of (project, package, version_constraint)

        Returns:
            List of dicts (one per dependency, in input order) whose keys
            match the ArtifactDependency columns used by the bulk insert.
        """
        return [
            {
                "artifact_id": artifact_id,
                "dependency_project": proj,
                "dependency_package": pkg,
                "version_constraint": ver,
            }
            for proj, pkg, ver in dependencies
        ]

    def get_or_create_artifact(
        self,
        sha256: str,
        size: int,
        filename: str,
        content_type: Optional[str] = None,
        created_by: str = "system",
        s3_key: Optional[str] = None,
    ) -> tuple[Artifact, bool]:
        """
        Get existing artifact or create new one atomically.

        Uses INSERT ... ON CONFLICT DO UPDATE to handle races. If the
        artifact already exists, only its ref_count is incremented; the
        other supplied fields are ignored for existing rows.

        Args:
            sha256: Content hash (primary key)
            size: File size in bytes
            filename: Original filename
            content_type: MIME type
            created_by: User who created the artifact
            s3_key: S3 storage key (defaults to
                ``fruits/<sha[0:2]>/<sha[2:4]>/<sha>``)

        Returns:
            (artifact, created) tuple where created is True for new
            artifacts. The returned artifact reflects the row as written
            by this statement (including the incremented ref_count).
        """
        if s3_key is None:
            s3_key = f"fruits/{sha256[:2]}/{sha256[2:4]}/{sha256}"

        stmt = (
            pg_insert(Artifact)
            .values(
                id=sha256,
                size=size,
                original_name=filename,
                content_type=content_type,
                ref_count=1,
                created_by=created_by,
                s3_key=s3_key,
            )
            .on_conflict_do_update(
                index_elements=["id"],
                set_={"ref_count": Artifact.ref_count + 1},
            )
            .returning(Artifact)
        )

        # populate_existing: if an Artifact with this id is already loaded in
        # the Session's identity map, RETURNING would otherwise hand back that
        # cached instance with its stale attributes. That would leak an
        # outdated ref_count to callers and break the created-check below.
        result = self.db.execute(
            stmt, execution_options={"populate_existing": True}
        )
        artifact = result.scalar_one()

        # The INSERT branch writes ref_count=1; the conflict branch always
        # increments the stored count, so ref_count == 1 identifies a new row.
        # NOTE(review): this assumes existing rows never sit at ref_count 0
        # (e.g. while awaiting garbage collection). If they can, a
        # re-referenced row would be misreported as created; confirm.
        created = artifact.ref_count == 1
        return artifact, created

    def batch_upsert_dependencies(
        self,
        artifact_id: str,
        dependencies: list[tuple[str, str, str]],
    ) -> int:
        """
        Insert dependencies in a single batch operation.

        Uses ON CONFLICT DO NOTHING on (artifact_id, dependency_project,
        dependency_package) to skip duplicates, so an existing row's
        version_constraint is never overwritten.

        Args:
            artifact_id: SHA256 of the artifact
            dependencies: List of (project, package, version_constraint)

        Returns:
            Number of dependencies inserted, as reported by the driver's
            rowcount; rows skipped because of a conflict are not counted.
            Returns 0 without touching the database when ``dependencies``
            is empty.
        """
        if not dependencies:
            return 0

        values = self._format_dependency_values(artifact_id, dependencies)

        stmt = pg_insert(ArtifactDependency).values(values)
        stmt = stmt.on_conflict_do_nothing(
            index_elements=[
                "artifact_id",
                "dependency_project",
                "dependency_package",
            ]
        )

        result = self.db.execute(stmt)
        return result.rowcount

    def get_cached_url_with_artifact(
        self,
        url_hash: str,
    ) -> Optional[tuple[CachedUrl, Artifact]]:
        """
        Get cached URL and its artifact in a single query.

        Performs an inner join on CachedUrl.artifact_id == Artifact.id, so
        a CachedUrl whose artifact row is missing is treated as not found.

        Args:
            url_hash: SHA256 of the URL

        Returns:
            (CachedUrl, Artifact) row or None if not found.
        """
        result = (
            self.db.query(CachedUrl, Artifact)
            .join(Artifact, CachedUrl.artifact_id == Artifact.id)
            .filter(CachedUrl.url_hash == url_hash)
            .first()
        )
        return result

    def get_artifact_dependencies(
        self,
        artifact_id: str,
    ) -> list[ArtifactDependency]:
        """
        Get all dependencies for an artifact in a single query.

        Args:
            artifact_id: SHA256 of the artifact

        Returns:
            List of ArtifactDependency objects (empty if none exist).
        """
        return (
            self.db.query(ArtifactDependency)
            .filter(ArtifactDependency.artifact_id == artifact_id)
            .all()
        )