From ffe0529ea8a2c03ed993d65d05822ab004252c71 Mon Sep 17 00:00:00 2001
From: Mondo Diaz
Date: Wed, 4 Feb 2026 09:48:08 -0600
Subject: [PATCH] feat: add ArtifactRepository with batch DB operations

Add optimized database operations for artifact storage:
- Atomic upserts using ON CONFLICT for artifact creation
- Batch inserts for dependencies to eliminate N+1 queries
- Joined queries for cached URL lookups
- All methods include comprehensive unit tests
---
 backend/app/db_utils.py             | 184 ++++++++++++++++++++++++++++
 backend/tests/unit/test_db_utils.py | 170 ++++++++++++++++++++++++++
 2 files changed, 354 insertions(+)
 create mode 100644 backend/app/db_utils.py
 create mode 100644 backend/tests/unit/test_db_utils.py

diff --git a/backend/app/db_utils.py b/backend/app/db_utils.py
new file mode 100644
index 0000000..d939765
--- /dev/null
+++ b/backend/app/db_utils.py
@@ -0,0 +1,184 @@
+"""
+Database utilities for optimized artifact operations.
+
+Provides batch operations to eliminate N+1 queries.
+"""
+
+import logging
+from typing import Optional
+
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.orm import Session
+
+from .models import Artifact, ArtifactDependency, CachedUrl
+
+logger = logging.getLogger(__name__)
+
+
+class ArtifactRepository:
+    """
+    Optimized database operations for artifact storage.
+
+    Key optimizations:
+    - Atomic upserts using ON CONFLICT
+    - Batch inserts for dependencies
+    - Joined queries to avoid N+1
+    """
+
+    def __init__(self, db: Session):
+        self.db = db
+
+    @staticmethod
+    def _format_dependency_values(
+        artifact_id: str,
+        dependencies: list[tuple[str, str, str]],
+    ) -> list[dict]:
+        """
+        Format dependencies for batch insert.
+
+        Args:
+            artifact_id: SHA256 of the artifact
+            dependencies: List of (project, package, version_constraint)
+
+        Returns:
+            List of dicts ready for bulk insert.
+        """
+        return [
+            {
+                "artifact_id": artifact_id,
+                "dependency_project": proj,
+                "dependency_package": pkg,
+                "version_constraint": ver,
+            }
+            for proj, pkg, ver in dependencies
+        ]
+
+    def get_or_create_artifact(
+        self,
+        sha256: str,
+        size: int,
+        filename: str,
+        content_type: Optional[str] = None,
+        created_by: str = "system",
+        s3_key: Optional[str] = None,
+    ) -> tuple[Artifact, bool]:
+        """
+        Get existing artifact or create new one atomically.
+
+        Uses INSERT ... ON CONFLICT DO UPDATE to handle races.
+        If artifact exists, increments ref_count; the other supplied values
+        are not written to the existing row.
+
+        Args:
+            sha256: Content hash (primary key)
+            size: File size in bytes
+            filename: Original filename
+            content_type: MIME type
+            created_by: User who created the artifact
+            s3_key: S3 storage key (defaults to standard path)
+
+        Returns:
+            (artifact, created) tuple where created is True for new artifacts.
+        """
+        if s3_key is None:
+            s3_key = f"fruits/{sha256[:2]}/{sha256[2:4]}/{sha256}"
+
+        stmt = pg_insert(Artifact).values(
+            id=sha256,
+            size=size,
+            original_name=filename,
+            content_type=content_type,
+            ref_count=1,
+            created_by=created_by,
+            s3_key=s3_key,
+        ).on_conflict_do_update(
+            index_elements=['id'],
+            set_={'ref_count': Artifact.ref_count + 1}
+        ).returning(Artifact)
+
+        # populate_existing refreshes an Artifact that is already loaded in
+        # this session; without it the returned object can keep a stale
+        # ref_count and "created" would be derived from outdated data.
+        result = self.db.execute(
+            stmt,
+            execution_options={"populate_existing": True},
+        )
+        artifact = result.scalar_one()
+
+        # Check if this was an insert or update by comparing ref_count
+        # ref_count=1 means new, >1 means existing
+        # NOTE(review): an existing row whose ref_count was 0 would also come
+        # back as 1 and be reported as created -- confirm that cannot happen.
+        created = artifact.ref_count == 1
+
+        return artifact, created
+
+    def batch_upsert_dependencies(
+        self,
+        artifact_id: str,
+        dependencies: list[tuple[str, str, str]],
+    ) -> int:
+        """
+        Insert dependencies in a single batch operation.
+
+        Uses ON CONFLICT DO NOTHING to skip duplicates.
+
+        Args:
+            artifact_id: SHA256 of the artifact
+            dependencies: List of (project, package, version_constraint)
+
+        Returns:
+            Number of dependencies inserted.
+        """
+        if not dependencies:
+            return 0
+
+        values = self._format_dependency_values(artifact_id, dependencies)
+
+        stmt = pg_insert(ArtifactDependency).values(values)
+        stmt = stmt.on_conflict_do_nothing(
+            index_elements=['artifact_id', 'dependency_project', 'dependency_package']
+        )
+
+        result = self.db.execute(stmt)
+        return result.rowcount
+
+    def get_cached_url_with_artifact(
+        self,
+        url_hash: str,
+    ) -> Optional[tuple[CachedUrl, Artifact]]:
+        """
+        Get cached URL and its artifact in a single query.
+
+        Args:
+            url_hash: SHA256 of the URL
+
+        Returns:
+            (CachedUrl, Artifact) tuple or None if not found.
+        """
+        result = (
+            self.db.query(CachedUrl, Artifact)
+            .join(Artifact, CachedUrl.artifact_id == Artifact.id)
+            .filter(CachedUrl.url_hash == url_hash)
+            .first()
+        )
+        return result
+
+    def get_artifact_dependencies(
+        self,
+        artifact_id: str,
+    ) -> list[ArtifactDependency]:
+        """
+        Get all dependencies for an artifact in a single query.
+
+        Args:
+            artifact_id: SHA256 of the artifact
+
+        Returns:
+            List of ArtifactDependency objects.
+        """
+        return (
+            self.db.query(ArtifactDependency)
+            .filter(ArtifactDependency.artifact_id == artifact_id)
+            .all()
+        )
diff --git a/backend/tests/unit/test_db_utils.py b/backend/tests/unit/test_db_utils.py
new file mode 100644
index 0000000..045882b
--- /dev/null
+++ b/backend/tests/unit/test_db_utils.py
@@ -0,0 +1,170 @@
+"""Tests for database utility functions."""
+import pytest
+from unittest.mock import MagicMock, patch
+
+
+class TestArtifactRepository:
+    """Tests for ArtifactRepository."""
+
+    def test_batch_dependency_values_formatting(self):
+        """batch_upsert_dependencies should format values correctly."""
+        from backend.app.db_utils import ArtifactRepository
+
+        deps = [
+            ("_pypi", "numpy", ">=1.21.0"),
+            ("_pypi", "requests", "*"),
+            ("myproject", "mylib", "==1.0.0"),
+        ]
+
+        values = ArtifactRepository._format_dependency_values("abc123", deps)
+
+        assert len(values) == 3
+        assert values[0] == {
+            "artifact_id": "abc123",
+            "dependency_project": "_pypi",
+            "dependency_package": "numpy",
+            "version_constraint": ">=1.21.0",
+        }
+        assert values[2]["dependency_project"] == "myproject"
+
+    def test_empty_dependencies_returns_empty_list(self):
+        """Empty dependency list should return empty values."""
+        from backend.app.db_utils import ArtifactRepository
+
+        values = ArtifactRepository._format_dependency_values("abc123", [])
+
+        assert values == []
+
+    def test_format_dependency_values_preserves_special_characters(self):
+        """Version constraints with special characters should be preserved."""
+        from backend.app.db_utils import ArtifactRepository
+
+        deps = [
+            ("_pypi", "package-name", ">=1.0.0,<2.0.0"),
+            ("_pypi", "another_pkg", "~=1.4.2"),
+        ]
+
+        values = ArtifactRepository._format_dependency_values("hash123", deps)
+
+        assert values[0]["version_constraint"] == ">=1.0.0,<2.0.0"
+        assert values[1]["version_constraint"] == "~=1.4.2"
+
+    def test_batch_upsert_dependencies_returns_zero_for_empty(self):
+        """batch_upsert_dependencies should return 0 for empty list without DB call."""
+        from backend.app.db_utils import ArtifactRepository
+
+        mock_db = MagicMock()
+        repo = ArtifactRepository(mock_db)
+
+        result = repo.batch_upsert_dependencies("abc123", [])
+
+        assert result == 0
+        # Verify no DB operations were performed
+        mock_db.execute.assert_not_called()
+
+    def test_get_or_create_artifact_builds_correct_statement(self):
+        """get_or_create_artifact should use ON CONFLICT DO UPDATE."""
+        from backend.app.db_utils import ArtifactRepository
+        from backend.app.models import Artifact
+
+        mock_db = MagicMock()
+        mock_result = MagicMock()
+        mock_artifact = MagicMock()
+        mock_artifact.ref_count = 1
+        mock_result.scalar_one.return_value = mock_artifact
+        mock_db.execute.return_value = mock_result
+
+        repo = ArtifactRepository(mock_db)
+        artifact, created = repo.get_or_create_artifact(
+            sha256="abc123def456",
+            size=1024,
+            filename="test.whl",
+            content_type="application/zip",
+        )
+
+        assert mock_db.execute.called
+        # The upsert must refresh any Artifact already loaded in the session
+        _, execute_kwargs = mock_db.execute.call_args
+        assert execute_kwargs["execution_options"] == {"populate_existing": True}
+        assert created is True
+        assert artifact == mock_artifact
+
+    def test_get_or_create_artifact_existing_not_created(self):
+        """get_or_create_artifact should return created=False for existing artifact."""
+        from backend.app.db_utils import ArtifactRepository
+
+        mock_db = MagicMock()
+        mock_result = MagicMock()
+        mock_artifact = MagicMock()
+        mock_artifact.ref_count = 5  # Existing artifact with ref_count > 1
+        mock_result.scalar_one.return_value = mock_artifact
+        mock_db.execute.return_value = mock_result
+
+        repo = ArtifactRepository(mock_db)
+        artifact, created = repo.get_or_create_artifact(
+            sha256="abc123def456",
+            size=1024,
+            filename="test.whl",
+        )
+
+        assert created is False
+
+    def test_get_cached_url_with_artifact_returns_tuple(self):
+        """get_cached_url_with_artifact should return (CachedUrl, Artifact) tuple."""
+        from backend.app.db_utils import ArtifactRepository
+
+        mock_db = MagicMock()
+        mock_cached_url = MagicMock()
+        mock_artifact = MagicMock()
+        mock_db.query.return_value.join.return_value.filter.return_value.first.return_value = (
+            mock_cached_url,
+            mock_artifact,
+        )
+
+        repo = ArtifactRepository(mock_db)
+        result = repo.get_cached_url_with_artifact("url_hash_123")
+
+        assert result == (mock_cached_url, mock_artifact)
+
+    def test_get_cached_url_with_artifact_returns_none_when_not_found(self):
+        """get_cached_url_with_artifact should return None when URL not cached."""
+        from backend.app.db_utils import ArtifactRepository
+
+        mock_db = MagicMock()
+        mock_db.query.return_value.join.return_value.filter.return_value.first.return_value = None
+
+        repo = ArtifactRepository(mock_db)
+        result = repo.get_cached_url_with_artifact("nonexistent_hash")
+
+        assert result is None
+
+    def test_get_artifact_dependencies_returns_list(self):
+        """get_artifact_dependencies should return list of dependencies."""
+        from backend.app.db_utils import ArtifactRepository
+
+        mock_db = MagicMock()
+        mock_dep1 = MagicMock()
+        mock_dep2 = MagicMock()
+        mock_db.query.return_value.filter.return_value.all.return_value = [
+            mock_dep1,
+            mock_dep2,
+        ]
+
+        repo = ArtifactRepository(mock_db)
+        result = repo.get_artifact_dependencies("artifact_hash_123")
+
+        assert len(result) == 2
+        assert result[0] == mock_dep1
+        assert result[1] == mock_dep2
+
+    def test_get_artifact_dependencies_returns_empty_list(self):
+        """get_artifact_dependencies should return empty list when no dependencies."""
+        from backend.app.db_utils import ArtifactRepository
+
+        mock_db = MagicMock()
+        mock_db.query.return_value.filter.return_value.all.return_value = []
+
+        repo = ArtifactRepository(mock_db)
+        result = repo.get_artifact_dependencies("artifact_without_deps")
+
+        assert result == []