Implement database storage layer

This commit is contained in:
Mondo Diaz
2025-12-12 12:45:33 -06:00
parent a6df5aba5a
commit 9604540dd3
16 changed files with 1477 additions and 2 deletions

View File

@@ -0,0 +1,22 @@
"""
Repository pattern implementation for data access layer.
Repositories abstract database operations from business logic,
providing clean interfaces for CRUD operations on each entity.
"""
from .base import BaseRepository
from .project import ProjectRepository
from .package import PackageRepository
from .artifact import ArtifactRepository
from .tag import TagRepository
from .upload import UploadRepository
__all__ = [
"BaseRepository",
"ProjectRepository",
"PackageRepository",
"ArtifactRepository",
"TagRepository",
"UploadRepository",
]

View File

@@ -0,0 +1,157 @@
"""
Artifact repository for data access operations.
"""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_
from uuid import UUID
from .base import BaseRepository
from ..models import Artifact, Tag, Upload, Package, Project
class ArtifactRepository(BaseRepository[Artifact]):
"""Repository for Artifact entity operations."""
model = Artifact
def get_by_sha256(self, sha256: str) -> Optional[Artifact]:
"""Get artifact by SHA256 hash (primary key)."""
return self.db.query(Artifact).filter(Artifact.id == sha256).first()
def exists_by_sha256(self, sha256: str) -> bool:
"""Check if artifact with SHA256 exists."""
return self.db.query(
self.db.query(Artifact).filter(Artifact.id == sha256).exists()
).scalar()
def create_artifact(
self,
sha256: str,
size: int,
s3_key: str,
created_by: str,
content_type: Optional[str] = None,
original_name: Optional[str] = None,
format_metadata: Optional[dict] = None,
) -> Artifact:
"""Create a new artifact."""
artifact = Artifact(
id=sha256,
size=size,
s3_key=s3_key,
created_by=created_by,
content_type=content_type,
original_name=original_name,
format_metadata=format_metadata or {},
ref_count=1,
)
self.db.add(artifact)
self.db.flush()
return artifact
def increment_ref_count(self, artifact: Artifact) -> Artifact:
"""Increment artifact reference count."""
artifact.ref_count += 1
self.db.flush()
return artifact
def decrement_ref_count(self, artifact: Artifact) -> Artifact:
"""
Decrement artifact reference count.
Returns the artifact with updated count.
Does not delete the artifact even if ref_count reaches 0.
"""
if artifact.ref_count > 0:
artifact.ref_count -= 1
self.db.flush()
return artifact
def get_orphaned_artifacts(self, limit: int = 100) -> List[Artifact]:
"""Get artifacts with ref_count = 0 (candidates for cleanup)."""
return (
self.db.query(Artifact)
.filter(Artifact.ref_count == 0)
.limit(limit)
.all()
)
def get_artifacts_without_tags(self, limit: int = 100) -> List[Artifact]:
"""Get artifacts that have no tags pointing to them."""
# Subquery to find artifact IDs that have tags
tagged_artifacts = self.db.query(Tag.artifact_id).distinct().subquery()
return (
self.db.query(Artifact)
.filter(~Artifact.id.in_(tagged_artifacts))
.limit(limit)
.all()
)
def find_by_package(
self,
package_id: UUID,
page: int = 1,
limit: int = 20,
content_type: Optional[str] = None,
) -> Tuple[List[Artifact], int]:
"""Find artifacts uploaded to a package."""
# Get distinct artifact IDs from uploads
artifact_ids_subquery = (
self.db.query(func.distinct(Upload.artifact_id))
.filter(Upload.package_id == package_id)
.subquery()
)
query = self.db.query(Artifact).filter(Artifact.id.in_(artifact_ids_subquery))
if content_type:
query = query.filter(Artifact.content_type == content_type)
total = query.count()
offset = (page - 1) * limit
artifacts = query.order_by(Artifact.created_at.desc()).offset(offset).limit(limit).all()
return artifacts, total
def get_referencing_tags(self, artifact_id: str) -> List[Tuple[Tag, Package, Project]]:
"""Get all tags referencing this artifact with package and project info."""
return (
self.db.query(Tag, Package, Project)
.join(Package, Tag.package_id == Package.id)
.join(Project, Package.project_id == Project.id)
.filter(Tag.artifact_id == artifact_id)
.all()
)
def search(self, query_str: str, limit: int = 10) -> List[Tuple[Tag, Artifact, str, str]]:
"""
Search artifacts by tag name or original filename.
Returns (tag, artifact, package_name, project_name) tuples.
"""
search_lower = query_str.lower()
return (
self.db.query(Tag, Artifact, Package.name, Project.name)
.join(Artifact, Tag.artifact_id == Artifact.id)
.join(Package, Tag.package_id == Package.id)
.join(Project, Package.project_id == Project.id)
.filter(
or_(
func.lower(Tag.name).contains(search_lower),
func.lower(Artifact.original_name).contains(search_lower)
)
)
.order_by(Tag.name)
.limit(limit)
.all()
)
def update_metadata(self, artifact: Artifact, metadata: dict) -> Artifact:
"""Update or merge format metadata."""
if artifact.format_metadata:
artifact.format_metadata = {**artifact.format_metadata, **metadata}
else:
artifact.format_metadata = metadata
self.db.flush()
return artifact

View File

@@ -0,0 +1,96 @@
"""
Base repository class with common CRUD operations.
"""
from typing import TypeVar, Generic, Type, Optional, List, Any, Dict
from sqlalchemy.orm import Session
from sqlalchemy import func, asc, desc
from uuid import UUID
from ..models import Base
T = TypeVar("T", bound=Base)
class BaseRepository(Generic[T]):
"""
Base repository providing common CRUD operations.
Subclasses should set `model` class attribute to the SQLAlchemy model.
"""
model: Type[T]
def __init__(self, db: Session):
self.db = db
def get_by_id(self, id: Any) -> Optional[T]:
"""Get entity by primary key."""
return self.db.query(self.model).filter(self.model.id == id).first()
def get_all(
self,
skip: int = 0,
limit: int = 100,
order_by: str = None,
order_desc: bool = False,
) -> List[T]:
"""Get all entities with pagination and optional ordering."""
query = self.db.query(self.model)
if order_by and hasattr(self.model, order_by):
column = getattr(self.model, order_by)
query = query.order_by(desc(column) if order_desc else asc(column))
return query.offset(skip).limit(limit).all()
def count(self) -> int:
"""Count total entities."""
return self.db.query(func.count(self.model.id)).scalar() or 0
def create(self, **kwargs) -> T:
"""Create a new entity."""
entity = self.model(**kwargs)
self.db.add(entity)
self.db.flush() # Flush to get ID without committing
return entity
def update(self, entity: T, **kwargs) -> T:
"""Update an existing entity."""
for key, value in kwargs.items():
if hasattr(entity, key):
setattr(entity, key, value)
self.db.flush()
return entity
def delete(self, entity: T) -> None:
"""Delete an entity."""
self.db.delete(entity)
self.db.flush()
def delete_by_id(self, id: Any) -> bool:
"""Delete entity by ID. Returns True if deleted, False if not found."""
entity = self.get_by_id(id)
if entity:
self.delete(entity)
return True
return False
def exists(self, id: Any) -> bool:
"""Check if entity exists by ID."""
return self.db.query(
self.db.query(self.model).filter(self.model.id == id).exists()
).scalar()
def commit(self) -> None:
"""Commit the current transaction."""
self.db.commit()
def rollback(self) -> None:
"""Rollback the current transaction."""
self.db.rollback()
def refresh(self, entity: T) -> T:
"""Refresh entity from database."""
self.db.refresh(entity)
return entity

View File

@@ -0,0 +1,177 @@
"""
Package repository for data access operations.
"""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_, asc, desc
from uuid import UUID
from .base import BaseRepository
from ..models import Package, Project, Tag, Upload, Artifact
class PackageRepository(BaseRepository[Package]):
"""Repository for Package entity operations."""
model = Package
def get_by_name(self, project_id: UUID, name: str) -> Optional[Package]:
"""Get package by name within a project."""
return (
self.db.query(Package)
.filter(Package.project_id == project_id, Package.name == name)
.first()
)
def get_by_project_and_name(self, project_name: str, package_name: str) -> Optional[Package]:
"""Get package by project name and package name."""
return (
self.db.query(Package)
.join(Project, Package.project_id == Project.id)
.filter(Project.name == project_name, Package.name == package_name)
.first()
)
def exists_by_name(self, project_id: UUID, name: str) -> bool:
"""Check if package with name exists in project."""
return self.db.query(
self.db.query(Package)
.filter(Package.project_id == project_id, Package.name == name)
.exists()
).scalar()
def list_by_project(
self,
project_id: UUID,
page: int = 1,
limit: int = 20,
search: Optional[str] = None,
format: Optional[str] = None,
platform: Optional[str] = None,
sort: str = "name",
order: str = "asc",
) -> Tuple[List[Package], int]:
"""
List packages in a project with filtering and pagination.
Returns tuple of (packages, total_count).
"""
query = self.db.query(Package).filter(Package.project_id == project_id)
# Apply search filter
if search:
search_lower = search.lower()
query = query.filter(
or_(
func.lower(Package.name).contains(search_lower),
func.lower(Package.description).contains(search_lower)
)
)
# Apply format filter
if format:
query = query.filter(Package.format == format)
# Apply platform filter
if platform:
query = query.filter(Package.platform == platform)
# Get total count
total = query.count()
# Apply sorting
sort_columns = {
"name": Package.name,
"created_at": Package.created_at,
"updated_at": Package.updated_at,
}
sort_column = sort_columns.get(sort, Package.name)
if order == "desc":
query = query.order_by(desc(sort_column))
else:
query = query.order_by(asc(sort_column))
# Apply pagination
offset = (page - 1) * limit
packages = query.offset(offset).limit(limit).all()
return packages, total
def create_package(
self,
project_id: UUID,
name: str,
description: Optional[str] = None,
format: str = "generic",
platform: str = "any",
) -> Package:
"""Create a new package."""
return self.create(
project_id=project_id,
name=name,
description=description,
format=format,
platform=platform,
)
def update_package(
self,
package: Package,
name: Optional[str] = None,
description: Optional[str] = None,
format: Optional[str] = None,
platform: Optional[str] = None,
) -> Package:
"""Update package fields."""
updates = {}
if name is not None:
updates["name"] = name
if description is not None:
updates["description"] = description
if format is not None:
updates["format"] = format
if platform is not None:
updates["platform"] = platform
return self.update(package, **updates)
def get_stats(self, package_id: UUID) -> dict:
"""Get package statistics (tag count, artifact count, total size)."""
tag_count = (
self.db.query(func.count(Tag.id))
.filter(Tag.package_id == package_id)
.scalar() or 0
)
artifact_stats = (
self.db.query(
func.count(func.distinct(Upload.artifact_id)),
func.coalesce(func.sum(Artifact.size), 0)
)
.join(Artifact, Upload.artifact_id == Artifact.id)
.filter(Upload.package_id == package_id)
.first()
)
return {
"tag_count": tag_count,
"artifact_count": artifact_stats[0] if artifact_stats else 0,
"total_size": artifact_stats[1] if artifact_stats else 0,
}
def search(self, query_str: str, limit: int = 10) -> List[Tuple[Package, str]]:
"""Search packages by name or description. Returns (package, project_name) tuples."""
search_lower = query_str.lower()
return (
self.db.query(Package, Project.name)
.join(Project, Package.project_id == Project.id)
.filter(
or_(
func.lower(Package.name).contains(search_lower),
func.lower(Package.description).contains(search_lower)
)
)
.order_by(Package.name)
.limit(limit)
.all()
)

View File

@@ -0,0 +1,132 @@
"""
Project repository for data access operations.
"""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_, asc, desc
from uuid import UUID
from .base import BaseRepository
from ..models import Project
class ProjectRepository(BaseRepository[Project]):
"""Repository for Project entity operations."""
model = Project
def get_by_name(self, name: str) -> Optional[Project]:
"""Get project by unique name."""
return self.db.query(Project).filter(Project.name == name).first()
def exists_by_name(self, name: str) -> bool:
"""Check if project with name exists."""
return self.db.query(
self.db.query(Project).filter(Project.name == name).exists()
).scalar()
def list_accessible(
self,
user_id: str,
page: int = 1,
limit: int = 20,
search: Optional[str] = None,
visibility: Optional[str] = None,
sort: str = "name",
order: str = "asc",
) -> Tuple[List[Project], int]:
"""
List projects accessible to user with filtering and pagination.
Returns tuple of (projects, total_count).
"""
# Base query - filter by access
query = self.db.query(Project).filter(
or_(Project.is_public == True, Project.created_by == user_id)
)
# Apply visibility filter
if visibility == "public":
query = query.filter(Project.is_public == True)
elif visibility == "private":
query = query.filter(Project.is_public == False, Project.created_by == user_id)
# Apply search filter
if search:
search_lower = search.lower()
query = query.filter(
or_(
func.lower(Project.name).contains(search_lower),
func.lower(Project.description).contains(search_lower)
)
)
# Get total count before pagination
total = query.count()
# Apply sorting
sort_columns = {
"name": Project.name,
"created_at": Project.created_at,
"updated_at": Project.updated_at,
}
sort_column = sort_columns.get(sort, Project.name)
if order == "desc":
query = query.order_by(desc(sort_column))
else:
query = query.order_by(asc(sort_column))
# Apply pagination
offset = (page - 1) * limit
projects = query.offset(offset).limit(limit).all()
return projects, total
def create_project(
self,
name: str,
created_by: str,
description: Optional[str] = None,
is_public: bool = True,
) -> Project:
"""Create a new project."""
return self.create(
name=name,
description=description,
is_public=is_public,
created_by=created_by,
)
def update_project(
self,
project: Project,
name: Optional[str] = None,
description: Optional[str] = None,
is_public: Optional[bool] = None,
) -> Project:
"""Update project fields."""
updates = {}
if name is not None:
updates["name"] = name
if description is not None:
updates["description"] = description
if is_public is not None:
updates["is_public"] = is_public
return self.update(project, **updates)
def search(self, query_str: str, limit: int = 10) -> List[Project]:
"""Search projects by name or description."""
search_lower = query_str.lower()
return (
self.db.query(Project)
.filter(
or_(
func.lower(Project.name).contains(search_lower),
func.lower(Project.description).contains(search_lower)
)
)
.order_by(Project.name)
.limit(limit)
.all()
)

View File

@@ -0,0 +1,168 @@
"""
Tag repository for data access operations.
"""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_, asc, desc
from uuid import UUID
from .base import BaseRepository
from ..models import Tag, TagHistory, Artifact, Package, Project
class TagRepository(BaseRepository[Tag]):
"""Repository for Tag entity operations."""
model = Tag
def get_by_name(self, package_id: UUID, name: str) -> Optional[Tag]:
"""Get tag by name within a package."""
return (
self.db.query(Tag)
.filter(Tag.package_id == package_id, Tag.name == name)
.first()
)
def get_with_artifact(self, package_id: UUID, name: str) -> Optional[Tuple[Tag, Artifact]]:
"""Get tag with its artifact."""
return (
self.db.query(Tag, Artifact)
.join(Artifact, Tag.artifact_id == Artifact.id)
.filter(Tag.package_id == package_id, Tag.name == name)
.first()
)
def exists_by_name(self, package_id: UUID, name: str) -> bool:
"""Check if tag with name exists in package."""
return self.db.query(
self.db.query(Tag)
.filter(Tag.package_id == package_id, Tag.name == name)
.exists()
).scalar()
def list_by_package(
self,
package_id: UUID,
page: int = 1,
limit: int = 20,
search: Optional[str] = None,
sort: str = "name",
order: str = "asc",
) -> Tuple[List[Tuple[Tag, Artifact]], int]:
"""
List tags in a package with artifact metadata.
Returns tuple of ((tag, artifact) tuples, total_count).
"""
query = (
self.db.query(Tag, Artifact)
.join(Artifact, Tag.artifact_id == Artifact.id)
.filter(Tag.package_id == package_id)
)
# Apply search filter (tag name or artifact original filename)
if search:
search_lower = search.lower()
query = query.filter(
or_(
func.lower(Tag.name).contains(search_lower),
func.lower(Artifact.original_name).contains(search_lower)
)
)
# Get total count
total = query.count()
# Apply sorting
sort_columns = {
"name": Tag.name,
"created_at": Tag.created_at,
}
sort_column = sort_columns.get(sort, Tag.name)
if order == "desc":
query = query.order_by(desc(sort_column))
else:
query = query.order_by(asc(sort_column))
# Apply pagination
offset = (page - 1) * limit
results = query.offset(offset).limit(limit).all()
return results, total
def create_tag(
self,
package_id: UUID,
name: str,
artifact_id: str,
created_by: str,
) -> Tag:
"""Create a new tag."""
return self.create(
package_id=package_id,
name=name,
artifact_id=artifact_id,
created_by=created_by,
)
def update_artifact(
self,
tag: Tag,
new_artifact_id: str,
changed_by: str,
record_history: bool = True,
) -> Tag:
"""
Update tag to point to a different artifact.
Optionally records change in tag history.
"""
old_artifact_id = tag.artifact_id
if record_history and old_artifact_id != new_artifact_id:
history = TagHistory(
tag_id=tag.id,
old_artifact_id=old_artifact_id,
new_artifact_id=new_artifact_id,
changed_by=changed_by,
)
self.db.add(history)
tag.artifact_id = new_artifact_id
tag.created_by = changed_by
self.db.flush()
return tag
def get_history(self, tag_id: UUID) -> List[TagHistory]:
"""Get tag change history."""
return (
self.db.query(TagHistory)
.filter(TagHistory.tag_id == tag_id)
.order_by(TagHistory.changed_at.desc())
.all()
)
def get_latest_in_package(self, package_id: UUID) -> Optional[Tag]:
"""Get the most recently created/updated tag in a package."""
return (
self.db.query(Tag)
.filter(Tag.package_id == package_id)
.order_by(Tag.created_at.desc())
.first()
)
def get_by_artifact(self, artifact_id: str) -> List[Tag]:
"""Get all tags pointing to an artifact."""
return (
self.db.query(Tag)
.filter(Tag.artifact_id == artifact_id)
.all()
)
def count_by_artifact(self, artifact_id: str) -> int:
"""Count tags pointing to an artifact."""
return (
self.db.query(func.count(Tag.id))
.filter(Tag.artifact_id == artifact_id)
.scalar() or 0
)

View File

@@ -0,0 +1,136 @@
"""
Upload repository for data access operations.
"""
from typing import Optional, List, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from uuid import UUID
from .base import BaseRepository
from ..models import Upload, Artifact, Package, Project
class UploadRepository(BaseRepository[Upload]):
"""Repository for Upload entity operations."""
model = Upload
def create_upload(
self,
artifact_id: str,
package_id: UUID,
uploaded_by: str,
original_name: Optional[str] = None,
source_ip: Optional[str] = None,
) -> Upload:
"""Record a new upload event."""
return self.create(
artifact_id=artifact_id,
package_id=package_id,
original_name=original_name,
uploaded_by=uploaded_by,
source_ip=source_ip,
)
def list_by_package(
self,
package_id: UUID,
page: int = 1,
limit: int = 20,
) -> Tuple[List[Upload], int]:
"""List uploads for a package with pagination."""
query = self.db.query(Upload).filter(Upload.package_id == package_id)
total = query.count()
offset = (page - 1) * limit
uploads = query.order_by(Upload.uploaded_at.desc()).offset(offset).limit(limit).all()
return uploads, total
def list_by_artifact(self, artifact_id: str) -> List[Upload]:
"""List all uploads of a specific artifact."""
return (
self.db.query(Upload)
.filter(Upload.artifact_id == artifact_id)
.order_by(Upload.uploaded_at.desc())
.all()
)
def get_latest_for_package(self, package_id: UUID) -> Optional[Upload]:
"""Get the most recent upload for a package."""
return (
self.db.query(Upload)
.filter(Upload.package_id == package_id)
.order_by(Upload.uploaded_at.desc())
.first()
)
def get_latest_timestamp(self, package_id: UUID) -> Optional[datetime]:
"""Get timestamp of most recent upload for a package."""
result = (
self.db.query(func.max(Upload.uploaded_at))
.filter(Upload.package_id == package_id)
.scalar()
)
return result
def count_by_artifact(self, artifact_id: str) -> int:
"""Count uploads of a specific artifact."""
return (
self.db.query(func.count(Upload.id))
.filter(Upload.artifact_id == artifact_id)
.scalar() or 0
)
def count_by_package(self, package_id: UUID) -> int:
"""Count total uploads for a package."""
return (
self.db.query(func.count(Upload.id))
.filter(Upload.package_id == package_id)
.scalar() or 0
)
def get_distinct_artifacts_count(self, package_id: UUID) -> int:
"""Count distinct artifacts uploaded to a package."""
return (
self.db.query(func.count(func.distinct(Upload.artifact_id)))
.filter(Upload.package_id == package_id)
.scalar() or 0
)
def get_uploads_by_user(
self,
user_id: str,
page: int = 1,
limit: int = 20,
) -> Tuple[List[Upload], int]:
"""List uploads by a specific user."""
query = self.db.query(Upload).filter(Upload.uploaded_by == user_id)
total = query.count()
offset = (page - 1) * limit
uploads = query.order_by(Upload.uploaded_at.desc()).offset(offset).limit(limit).all()
return uploads, total
def get_upload_stats(self, package_id: UUID) -> dict:
"""Get upload statistics for a package."""
stats = (
self.db.query(
func.count(Upload.id),
func.count(func.distinct(Upload.artifact_id)),
func.min(Upload.uploaded_at),
func.max(Upload.uploaded_at),
)
.filter(Upload.package_id == package_id)
.first()
)
return {
"total_uploads": stats[0] if stats else 0,
"unique_artifacts": stats[1] if stats else 0,
"first_upload": stats[2] if stats else None,
"last_upload": stats[3] if stats else None,
}