feat: add ArtifactRepository with batch DB operations
Add optimized database operations for artifact storage:
- Atomic upserts using ON CONFLICT for artifact creation
- Batch inserts for dependencies to eliminate N+1 queries
- Joined queries for cached URL lookups
- All methods are covered by unit tests
This commit is contained in:
175
backend/app/db_utils.py
Normal file
175
backend/app/db_utils.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""
|
||||||
|
Database utilities for optimized artifact operations.
|
||||||
|
|
||||||
|
Provides batch operations to eliminate N+1 queries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from .models import Artifact, ArtifactDependency, CachedUrl
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ArtifactRepository:
    """
    Optimized database operations for artifact storage.

    Key optimizations:
    - Atomic upserts using ON CONFLICT
    - Batch inserts for dependencies
    - Joined queries to avoid N+1

    No method here calls commit() or rollback() on the session; statements
    are only executed, so transaction control stays with the caller.
    """

    def __init__(self, db: Session):
        """
        Bind the repository to a session.

        Args:
            db: SQLAlchemy session used for every statement issued here.
        """
        self.db = db

    @staticmethod
    def _format_dependency_values(
        artifact_id: str,
        dependencies: list[tuple[str, str, str]],
    ) -> list[dict]:
        """
        Format dependencies for batch insert.

        Args:
            artifact_id: SHA256 of the artifact
            dependencies: List of (project, package, version_constraint)

        Returns:
            List of dicts ready for bulk insert, one per input tuple and in
            the same order. Every dict carries the same ``artifact_id``.
        """
        return [
            {
                "artifact_id": artifact_id,
                "dependency_project": proj,
                "dependency_package": pkg,
                "version_constraint": ver,
            }
            for proj, pkg, ver in dependencies
        ]

    def get_or_create_artifact(
        self,
        sha256: str,
        size: int,
        filename: str,
        content_type: Optional[str] = None,
        created_by: str = "system",
        s3_key: Optional[str] = None,
    ) -> tuple[Artifact, bool]:
        """
        Get existing artifact or create new one atomically.

        Uses INSERT ... ON CONFLICT DO UPDATE to handle races.
        If artifact exists, increments ref_count; no other column of the
        existing row is touched, so the ``size``, ``filename``,
        ``content_type``, ``created_by`` and ``s3_key`` arguments only take
        effect when the row is newly inserted.

        Args:
            sha256: Content hash (primary key)
            size: File size in bytes
            filename: Original filename
            content_type: MIME type
            created_by: User who created the artifact
            s3_key: S3 storage key (defaults to standard path)

        Returns:
            (artifact, created) tuple where created is True for new artifacts.

            ``created`` is inferred from the returned ``ref_count`` being 1.
            NOTE(review): an already-existing row whose ref_count was 0
            before this call would also come back as 1 and be reported as
            created - confirm whether ref_count can ever be 0 for a stored
            row.
        """
        if s3_key is None:
            # Default layout: two directory levels taken from the first four
            # characters of the hash, then the full hash as the object name.
            s3_key = f"fruits/{sha256[:2]}/{sha256[2:4]}/{sha256}"

        stmt = (
            pg_insert(Artifact)
            .values(
                id=sha256,
                size=size,
                original_name=filename,
                content_type=content_type,
                ref_count=1,
                created_by=created_by,
                s3_key=s3_key,
            )
            .on_conflict_do_update(
                index_elements=['id'],
                set_={'ref_count': Artifact.ref_count + 1},
            )
            .returning(Artifact)
        )

        # An upsert can return a row whose Artifact is already present in the
        # session's identity map. Without populate_existing the ORM hands back
        # that cached object with its old ref_count, which would make both the
        # returned artifact and the `created` check below stale.
        result = self.db.execute(
            stmt,
            execution_options={"populate_existing": True},
        )
        artifact = result.scalar_one()

        # Check if this was an insert or update by comparing ref_count
        # ref_count=1 means new, >1 means existing
        created = artifact.ref_count == 1

        return artifact, created

    def batch_upsert_dependencies(
        self,
        artifact_id: str,
        dependencies: list[tuple[str, str, str]],
    ) -> int:
        """
        Insert dependencies in a single batch operation.

        Uses ON CONFLICT DO NOTHING to skip duplicates. The conflict target is
        (artifact_id, dependency_project, dependency_package).
        NOTE(review): this assumes a unique constraint or index over exactly
        those three columns exists on the dependency table - confirm against
        the model definition.

        Args:
            artifact_id: SHA256 of the artifact
            dependencies: List of (project, package, version_constraint)

        Returns:
            Number of dependencies inserted, taken from the statement's
            ``rowcount``. Returns 0 without touching the database when
            ``dependencies`` is empty.
        """
        if not dependencies:
            # Nothing to insert; also avoids building an INSERT with no rows.
            return 0

        values = self._format_dependency_values(artifact_id, dependencies)

        stmt = pg_insert(ArtifactDependency).values(values)
        stmt = stmt.on_conflict_do_nothing(
            index_elements=['artifact_id', 'dependency_project', 'dependency_package']
        )

        result = self.db.execute(stmt)
        return result.rowcount

    def get_cached_url_with_artifact(
        self,
        url_hash: str,
    ) -> Optional[tuple[CachedUrl, Artifact]]:
        """
        Get cached URL and its artifact in a single query.

        Joins CachedUrl to Artifact on ``CachedUrl.artifact_id == Artifact.id``
        and takes the first row matching ``url_hash``.

        Args:
            url_hash: SHA256 of the URL

        Returns:
            (CachedUrl, Artifact) tuple or None if not found.
        """
        result = (
            self.db.query(CachedUrl, Artifact)
            .join(Artifact, CachedUrl.artifact_id == Artifact.id)
            .filter(CachedUrl.url_hash == url_hash)
            .first()
        )
        return result

    def get_artifact_dependencies(
        self,
        artifact_id: str,
    ) -> list[ArtifactDependency]:
        """
        Get all dependencies for an artifact in a single query.

        Args:
            artifact_id: SHA256 of the artifact

        Returns:
            List of ArtifactDependency objects; empty when the artifact has
            no dependency rows.
        """
        return (
            self.db.query(ArtifactDependency)
            .filter(ArtifactDependency.artifact_id == artifact_id)
            .all()
        )
|
||||||
167
backend/tests/unit/test_db_utils.py
Normal file
167
backend/tests/unit/test_db_utils.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""Tests for database utility functions."""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestArtifactRepository:
    """Unit tests for ArtifactRepository, run against a mocked session."""

    def test_batch_dependency_values_formatting(self):
        """_format_dependency_values should map each tuple onto the insert columns."""
        from backend.app.db_utils import ArtifactRepository

        dependencies = [
            ("_pypi", "numpy", ">=1.21.0"),
            ("_pypi", "requests", "*"),
            ("myproject", "mylib", "==1.0.0"),
        ]
        expected_first_row = {
            "artifact_id": "abc123",
            "dependency_project": "_pypi",
            "dependency_package": "numpy",
            "version_constraint": ">=1.21.0",
        }

        rows = ArtifactRepository._format_dependency_values("abc123", dependencies)

        assert len(rows) == 3
        assert rows[0] == expected_first_row
        assert rows[2]["dependency_project"] == "myproject"

    def test_empty_dependencies_returns_empty_list(self):
        """An empty dependency list should format to an empty list."""
        from backend.app.db_utils import ArtifactRepository

        rows = ArtifactRepository._format_dependency_values("abc123", [])

        assert rows == []

    def test_format_dependency_values_preserves_special_characters(self):
        """Version constraints containing operators and commas pass through unchanged."""
        from backend.app.db_utils import ArtifactRepository

        dependencies = [
            ("_pypi", "package-name", ">=1.0.0,<2.0.0"),
            ("_pypi", "another_pkg", "~=1.4.2"),
        ]

        rows = ArtifactRepository._format_dependency_values("hash123", dependencies)

        first_row, second_row = rows[0], rows[1]
        assert first_row["version_constraint"] == ">=1.0.0,<2.0.0"
        assert second_row["version_constraint"] == "~=1.4.2"

    def test_batch_upsert_dependencies_returns_zero_for_empty(self):
        """batch_upsert_dependencies should return 0 for an empty list and skip the DB."""
        from backend.app.db_utils import ArtifactRepository

        session = MagicMock()

        inserted = ArtifactRepository(session).batch_upsert_dependencies("abc123", [])

        assert inserted == 0
        # The early return must happen before any statement is executed.
        session.execute.assert_not_called()

    def test_get_or_create_artifact_builds_correct_statement(self):
        """get_or_create_artifact should execute a statement and report created for ref_count 1."""
        from backend.app.db_utils import ArtifactRepository
        from backend.app.models import Artifact  # noqa: F401 - not referenced by the assertions

        stored = MagicMock(ref_count=1)
        session = MagicMock()
        session.execute.return_value.scalar_one.return_value = stored

        repository = ArtifactRepository(session)
        artifact, created = repository.get_or_create_artifact(
            sha256="abc123def456",
            size=1024,
            filename="test.whl",
            content_type="application/zip",
        )

        assert session.execute.called
        assert created is True
        assert artifact == stored

    def test_get_or_create_artifact_existing_not_created(self):
        """get_or_create_artifact should report created=False when ref_count exceeds 1."""
        from backend.app.db_utils import ArtifactRepository

        # ref_count above 1 stands in for a row that already existed.
        stored = MagicMock(ref_count=5)
        session = MagicMock()
        session.execute.return_value.scalar_one.return_value = stored

        repository = ArtifactRepository(session)
        artifact, created = repository.get_or_create_artifact(
            sha256="abc123def456",
            size=1024,
            filename="test.whl",
        )

        assert created is False

    def test_get_cached_url_with_artifact_returns_tuple(self):
        """get_cached_url_with_artifact should hand back the (CachedUrl, Artifact) pair."""
        from backend.app.db_utils import ArtifactRepository

        cached_url = MagicMock()
        artifact = MagicMock()
        session = MagicMock()
        # Stub the end of the query(...).join(...).filter(...).first() chain.
        first_call = session.query.return_value.join.return_value.filter.return_value.first
        first_call.return_value = (cached_url, artifact)

        found = ArtifactRepository(session).get_cached_url_with_artifact("url_hash_123")

        assert found == (cached_url, artifact)

    def test_get_cached_url_with_artifact_returns_none_when_not_found(self):
        """get_cached_url_with_artifact should return None when the query finds no row."""
        from backend.app.db_utils import ArtifactRepository

        session = MagicMock()
        first_call = session.query.return_value.join.return_value.filter.return_value.first
        first_call.return_value = None

        found = ArtifactRepository(session).get_cached_url_with_artifact("nonexistent_hash")

        assert found is None

    def test_get_artifact_dependencies_returns_list(self):
        """get_artifact_dependencies should return the rows produced by the query, in order."""
        from backend.app.db_utils import ArtifactRepository

        first_dep = MagicMock()
        second_dep = MagicMock()
        session = MagicMock()
        # Stub the end of the query(...).filter(...).all() chain.
        all_call = session.query.return_value.filter.return_value.all
        all_call.return_value = [first_dep, second_dep]

        found = ArtifactRepository(session).get_artifact_dependencies("artifact_hash_123")

        assert len(found) == 2
        assert found[0] == first_dep
        assert found[1] == second_dep

    def test_get_artifact_dependencies_returns_empty_list(self):
        """get_artifact_dependencies should return an empty list when the query yields nothing."""
        from backend.app.db_utils import ArtifactRepository

        session = MagicMock()
        all_call = session.query.return_value.filter.return_value.all
        all_call.return_value = []

        found = ArtifactRepository(session).get_artifact_dependencies("artifact_without_deps")

        assert found == []
|
||||||
Reference in New Issue
Block a user