Commit ec6b3f0ed8
Author: Mondo Diaz
Date: 2025-12-12 12:18:01 -06:00

Implement database storage layer
- Add connection pool configuration (pool_size, max_overflow, timeout, recycle)
- Add transaction management utilities (transaction, savepoint, retry_on_deadlock)
- Create repository pattern classes for all entities (Project, Package, Artifact, Tag, Upload)
- Implement ref_count decrement and cleanup service
- Add query helper functions (search, filtering, pagination, stats)
- Add database constraints (check_ref_count_non_negative, check_size_positive)
- Add performance indexes (idx_artifacts_ref_count, composite indexes for packages/tags)
- Initialize Alembic migrations for future schema changes
16 changed files with 1477 additions and 2 deletions

backend/alembic.ini (new file)

@@ -0,0 +1,83 @@
# Alembic Configuration File
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during the 'revision' command,
# regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without a source .py file
# to be detected as revisions in the versions/ directory
# sourceless = false
# version location specification
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator
# version_path_separator = :
# set to 'true' to search source files recursively
# in each "version_locations" directory
# recursive_version_locations = false
# the output encoding used when revision files are written from script.py.mako
# output_encoding = utf-8
# Database URL - will be overridden by env.py
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts.
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

backend/alembic/README (new file)

@@ -0,0 +1,27 @@
Alembic Migrations for Orchard
This directory contains database migration scripts managed by Alembic.
Common Commands:
# Generate a new migration (autogenerate from model changes)
alembic revision --autogenerate -m "description of changes"
# Apply all pending migrations
alembic upgrade head
# Rollback one migration
alembic downgrade -1
# Show current migration status
alembic current
# Show migration history
alembic history
# Generate SQL without applying (for review)
alembic upgrade head --sql
Notes:
- Always review autogenerated migrations before applying
- Test migrations in development before applying to production
- Migrations are stored in the versions/ directory

backend/alembic/env.py (new file)

@@ -0,0 +1,95 @@
"""
Alembic migration environment configuration.
"""
from logging.config import fileConfig
import sys
from pathlib import Path
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# Add the app directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from app.config import get_settings
from app.models import Base
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Get database URL from settings
settings = get_settings()
config.set_main_option("sqlalchemy.url", settings.database_url)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True, # Detect column type changes
compare_server_default=True, # Detect default value changes
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

backend/alembic/script.py.mako (new file)

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

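For reference, a revision rendered through this template might look like the following sketch (the revision ID and the column change are illustrative, not part of this commit):

"""add homepage to packages
Revision ID: 3f2a1c9d8b7e
Revises:
Create Date: 2025-12-12 12:30:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa

revision: str = "3f2a1c9d8b7e"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

def upgrade() -> None:
    # Autogenerated operations land here; review before applying
    op.add_column("packages", sa.Column("homepage", sa.String(), nullable=True))

def downgrade() -> None:
    op.drop_column("packages", "homepage")
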
backend/app/config.py

@@ -18,6 +18,12 @@ class Settings(BaseSettings):
database_dbname: str = "orchard"
database_sslmode: str = "disable"
# Database connection pool settings
database_pool_size: int = 5  # Number of connections to keep open
database_max_overflow: int = 10  # Max additional connections beyond pool_size
database_pool_timeout: int = 30  # Seconds to wait for a connection from pool
database_pool_recycle: int = 1800  # Recycle connections after this many seconds (30 min)
# S3
s3_endpoint: str = ""
s3_region: str = "us-east-1"

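Since Settings is a pydantic BaseSettings class, each pool knob can be overridden per environment without code changes. A minimal sketch, assuming the default env mapping (field name upper-cased, no env_prefix) and that get_settings() reads the environment on first call:

import os

# Hypothetical deployment override; values are strings, pydantic coerces to int
os.environ["DATABASE_POOL_SIZE"] = "10"
os.environ["DATABASE_POOL_RECYCLE"] = "900"

from app.config import get_settings
print(get_settings().database_pool_size)  # expected: 10
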
backend/app/database.py

@@ -1,7 +1,10 @@
from sqlalchemy import create_engine, text, event
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool
from typing import Generator
from contextlib import contextmanager
import logging
import time
from .config import get_settings
from .models import Base
@@ -9,10 +12,44 @@ from .models import Base
settings = get_settings()
logger = logging.getLogger(__name__)
# Create engine with connection pool configuration
engine = create_engine(
settings.database_url,
pool_pre_ping=True, # Check connection health before using
poolclass=QueuePool,
pool_size=settings.database_pool_size,
max_overflow=settings.database_max_overflow,
pool_timeout=settings.database_pool_timeout,
pool_recycle=settings.database_pool_recycle,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Connection pool monitoring
@event.listens_for(engine, "checkout")
def receive_checkout(dbapi_connection, connection_record, connection_proxy):
"""Log when a connection is checked out from the pool"""
logger.debug(f"Connection checked out from pool: {id(dbapi_connection)}")
@event.listens_for(engine, "checkin")
def receive_checkin(dbapi_connection, connection_record):
"""Log when a connection is returned to the pool"""
logger.debug(f"Connection returned to pool: {id(dbapi_connection)}")
def get_pool_status() -> dict:
"""Get current connection pool status for monitoring"""
pool = engine.pool
return {
"pool_size": pool.size(),
"checked_out": pool.checkedout(),
"overflow": pool.overflow(),
"checked_in": pool.checkedin(),
}
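
A pool snapshot like this is most useful wired into a health or metrics endpoint. A sketch, assuming this module is importable as app.database (the handler itself is not part of this commit):

from sqlalchemy import text
from app.database import engine, get_pool_status  # assumed module path

def database_health() -> dict:
    """Cheap connectivity probe plus pool statistics."""
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))
    return {"database": "ok", "pool": get_pool_status()}
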
def init_db():
"""Create all tables and run migrations"""
Base.metadata.create_all(bind=engine)
@@ -62,6 +99,51 @@ def _run_migrations():
END IF;
END $$;
""",
# Add ref_count index and constraints for artifacts
"""
DO $$
BEGIN
-- Add ref_count index
IF NOT EXISTS (
SELECT 1 FROM pg_indexes WHERE indexname = 'idx_artifacts_ref_count'
) THEN
CREATE INDEX idx_artifacts_ref_count ON artifacts(ref_count);
END IF;
-- Add ref_count >= 0 constraint
IF NOT EXISTS (
SELECT 1 FROM pg_constraint WHERE conname = 'check_ref_count_non_negative'
) THEN
ALTER TABLE artifacts ADD CONSTRAINT check_ref_count_non_negative CHECK (ref_count >= 0);
END IF;
END $$;
""",
# Add composite indexes for packages and tags
"""
DO $$
BEGIN
-- Composite index for package lookup by project and name
IF NOT EXISTS (
SELECT 1 FROM pg_indexes WHERE indexname = 'idx_packages_project_name'
) THEN
CREATE UNIQUE INDEX idx_packages_project_name ON packages(project_id, name);
END IF;
-- Composite index for tag lookup by package and name
IF NOT EXISTS (
SELECT 1 FROM pg_indexes WHERE indexname = 'idx_tags_package_name'
) THEN
CREATE UNIQUE INDEX idx_tags_package_name ON tags(package_id, name);
END IF;
-- Composite index for recent tags queries
IF NOT EXISTS (
SELECT 1 FROM pg_indexes WHERE indexname = 'idx_tags_package_created_at'
) THEN
CREATE INDEX idx_tags_package_created_at ON tags(package_id, created_at);
END IF;
END $$;
""",
]
with engine.connect() as conn:
@@ -80,3 +162,75 @@ def get_db() -> Generator[Session, None, None]:
yield db
finally:
db.close()
@contextmanager
def transaction(db: Session):
"""
Context manager for explicit transaction management.
Usage:
with transaction(db):
# operations here
# automatically commits on success, rolls back on exception
"""
try:
yield db
db.commit()
except Exception:
db.rollback()
raise
@contextmanager
def savepoint(db: Session):
"""
Create a savepoint for partial rollback support.
Usage:
with savepoint(db):
# operations here
# rolls back to savepoint on exception, but doesn't rollback whole transaction
"""
savepoint_obj = db.begin_nested()
try:
yield savepoint_obj
savepoint_obj.commit()
except Exception:
savepoint_obj.rollback()
raise
def retry_on_deadlock(func, max_retries: int = 3, delay: float = 0.1):
"""
Decorator/wrapper to retry operations on deadlock detection.
Usage:
@retry_on_deadlock
def my_operation(db):
...
Or:
retry_on_deadlock(lambda: my_operation(db))()
"""
import functools
from sqlalchemy.exc import OperationalError
@functools.wraps(func)
def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except OperationalError as e:
# Check for deadlock error codes (PostgreSQL: 40P01, MySQL: 1213)
error_str = str(e).lower()
if "deadlock" in error_str or "40p01" in error_str:
last_exception = e
logger.warning(f"Deadlock detected, retrying (attempt {attempt + 1}/{max_retries})")
time.sleep(delay * (attempt + 1))  # Linear backoff between retries
else:
raise
raise last_exception
return wrapper
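
The three utilities compose: the retry wrapper goes around the whole unit of work, and the transaction manager owns commit/rollback inside it. A usage sketch (the function and digest are hypothetical; app.database and app.repositories are assumed module paths):

from app.database import SessionLocal, transaction, retry_on_deadlock
from app.repositories import ArtifactRepository

@retry_on_deadlock
def bump_ref_count(db, sha256: str) -> None:
    with transaction(db):  # commits on success, rolls back on error
        repo = ArtifactRepository(db)
        artifact = repo.get_by_sha256(sha256)
        if artifact:
            repo.increment_ref_count(artifact)

db = SessionLocal()
try:
    bump_ref_count(db, "e3b0c44298fc1c149afbf4c8996fb924...")  # hypothetical digest
finally:
    db.close()
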

backend/app/models.py

@@ -53,6 +53,7 @@ class Package(Base):
Index("idx_packages_name", "name"), Index("idx_packages_name", "name"),
Index("idx_packages_format", "format"), Index("idx_packages_format", "format"),
Index("idx_packages_platform", "platform"), Index("idx_packages_platform", "platform"),
Index("idx_packages_project_name", "project_id", "name", unique=True), # Composite unique index
CheckConstraint( CheckConstraint(
"format IN ('generic', 'npm', 'pypi', 'docker', 'deb', 'rpm', 'maven', 'nuget', 'helm')", "format IN ('generic', 'npm', 'pypi', 'docker', 'deb', 'rpm', 'maven', 'nuget', 'helm')",
name="check_package_format" name="check_package_format"
@@ -84,6 +85,9 @@ class Artifact(Base):
__table_args__ = (
Index("idx_artifacts_created_at", "created_at"),
Index("idx_artifacts_created_by", "created_by"),
Index("idx_artifacts_ref_count", "ref_count"), # For cleanup queries
CheckConstraint("ref_count >= 0", name="check_ref_count_non_negative"),
CheckConstraint("size > 0", name="check_size_positive"),
)
@@ -104,6 +108,8 @@ class Tag(Base):
__table_args__ = (
Index("idx_tags_package_id", "package_id"),
Index("idx_tags_artifact_id", "artifact_id"),
Index("idx_tags_package_name", "package_id", "name", unique=True), # Composite unique index
Index("idx_tags_package_created_at", "package_id", "created_at"), # For recent tags queries
)

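Because the composite indexes are unique, the database itself rejects duplicate (project_id, name) packages and (package_id, name) tags, so racy create paths can lean on it. A hedged sketch of handling that race at the call site (the function name is illustrative, not part of this commit):

from sqlalchemy.exc import IntegrityError
from app.repositories import TagRepository  # assumed module path

def create_tag_idempotent(db, package_id, name, artifact_id, user):
    repo = TagRepository(db)
    try:
        tag = repo.create_tag(package_id, name, artifact_id, created_by=user)
        db.commit()
        return tag
    except IntegrityError:
        db.rollback()
        return repo.get_by_name(package_id, name)  # a concurrent writer won the race
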
backend/app/repositories/__init__.py (new file)

@@ -0,0 +1,22 @@
"""
Repository pattern implementation for data access layer.
Repositories abstract database operations from business logic,
providing clean interfaces for CRUD operations on each entity.
"""
from .base import BaseRepository
from .project import ProjectRepository
from .package import PackageRepository
from .artifact import ArtifactRepository
from .tag import TagRepository
from .upload import UploadRepository
__all__ = [
"BaseRepository",
"ProjectRepository",
"PackageRepository",
"ArtifactRepository",
"TagRepository",
"UploadRepository",
]

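Taken together, a read path composes these repositories over a single session. A minimal sketch of a project/package listing flow (the function shape is hypothetical; only the repository calls come from this commit):

from app.database import SessionLocal  # assumed module path
from app.repositories import ProjectRepository, PackageRepository

def list_project_packages(project_name: str, page: int = 1, limit: int = 20):
    db = SessionLocal()
    try:
        project = ProjectRepository(db).get_by_name(project_name)
        if project is None:
            return None
        packages, total = PackageRepository(db).list_by_project(
            project.id, page=page, limit=limit
        )
        return {"items": packages, "total": total, "page": page}
    finally:
        db.close()
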
backend/app/repositories/artifact.py (new file)

@@ -0,0 +1,157 @@
"""
Artifact repository for data access operations.
"""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_
from uuid import UUID
from .base import BaseRepository
from ..models import Artifact, Tag, Upload, Package, Project
class ArtifactRepository(BaseRepository[Artifact]):
"""Repository for Artifact entity operations."""
model = Artifact
def get_by_sha256(self, sha256: str) -> Optional[Artifact]:
"""Get artifact by SHA256 hash (primary key)."""
return self.db.query(Artifact).filter(Artifact.id == sha256).first()
def exists_by_sha256(self, sha256: str) -> bool:
"""Check if artifact with SHA256 exists."""
return self.db.query(
self.db.query(Artifact).filter(Artifact.id == sha256).exists()
).scalar()
def create_artifact(
self,
sha256: str,
size: int,
s3_key: str,
created_by: str,
content_type: Optional[str] = None,
original_name: Optional[str] = None,
format_metadata: Optional[dict] = None,
) -> Artifact:
"""Create a new artifact."""
artifact = Artifact(
id=sha256,
size=size,
s3_key=s3_key,
created_by=created_by,
content_type=content_type,
original_name=original_name,
format_metadata=format_metadata or {},
ref_count=1,
)
self.db.add(artifact)
self.db.flush()
return artifact
def increment_ref_count(self, artifact: Artifact) -> Artifact:
"""Increment artifact reference count."""
artifact.ref_count += 1
self.db.flush()
return artifact
def decrement_ref_count(self, artifact: Artifact) -> Artifact:
"""
Decrement artifact reference count.
Returns the artifact with updated count.
Does not delete the artifact even if ref_count reaches 0.
"""
if artifact.ref_count > 0:
artifact.ref_count -= 1
self.db.flush()
return artifact
def get_orphaned_artifacts(self, limit: int = 100) -> List[Artifact]:
"""Get artifacts with ref_count = 0 (candidates for cleanup)."""
return (
self.db.query(Artifact)
.filter(Artifact.ref_count == 0)
.limit(limit)
.all()
)
def get_artifacts_without_tags(self, limit: int = 100) -> List[Artifact]:
"""Get artifacts that have no tags pointing to them."""
# Subquery to find artifact IDs that have tags
tagged_artifacts = self.db.query(Tag.artifact_id).distinct().subquery()
return (
self.db.query(Artifact)
.filter(~Artifact.id.in_(tagged_artifacts))
.limit(limit)
.all()
)
def find_by_package(
self,
package_id: UUID,
page: int = 1,
limit: int = 20,
content_type: Optional[str] = None,
) -> Tuple[List[Artifact], int]:
"""Find artifacts uploaded to a package."""
# Get distinct artifact IDs from uploads
artifact_ids_subquery = (
self.db.query(func.distinct(Upload.artifact_id))
.filter(Upload.package_id == package_id)
.subquery()
)
query = self.db.query(Artifact).filter(Artifact.id.in_(artifact_ids_subquery))
if content_type:
query = query.filter(Artifact.content_type == content_type)
total = query.count()
offset = (page - 1) * limit
artifacts = query.order_by(Artifact.created_at.desc()).offset(offset).limit(limit).all()
return artifacts, total
def get_referencing_tags(self, artifact_id: str) -> List[Tuple[Tag, Package, Project]]:
"""Get all tags referencing this artifact with package and project info."""
return (
self.db.query(Tag, Package, Project)
.join(Package, Tag.package_id == Package.id)
.join(Project, Package.project_id == Project.id)
.filter(Tag.artifact_id == artifact_id)
.all()
)
def search(self, query_str: str, limit: int = 10) -> List[Tuple[Tag, Artifact, str, str]]:
"""
Search artifacts by tag name or original filename.
Returns (tag, artifact, package_name, project_name) tuples.
"""
search_lower = query_str.lower()
return (
self.db.query(Tag, Artifact, Package.name, Project.name)
.join(Artifact, Tag.artifact_id == Artifact.id)
.join(Package, Tag.package_id == Package.id)
.join(Project, Package.project_id == Project.id)
.filter(
or_(
func.lower(Tag.name).contains(search_lower),
func.lower(Artifact.original_name).contains(search_lower)
)
)
.order_by(Tag.name)
.limit(limit)
.all()
)
def update_metadata(self, artifact: Artifact, metadata: dict) -> Artifact:
"""Update or merge format metadata."""
if artifact.format_metadata:
artifact.format_metadata = {**artifact.format_metadata, **metadata}
else:
artifact.format_metadata = metadata
self.db.flush()
return artifact

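The repository is built for content-addressed deduplication: an upload handler hashes the payload, then either bumps the existing artifact's ref_count or creates a new row. A sketch under that assumption (S3 handling elided; names are illustrative):

from app.repositories import ArtifactRepository, UploadRepository

def register_upload(db, sha256, size, s3_key, package_id, user, original_name=None):
    artifacts = ArtifactRepository(db)
    existing = artifacts.get_by_sha256(sha256)
    if existing:
        artifact = artifacts.increment_ref_count(existing)  # dedup hit
    else:
        artifact = artifacts.create_artifact(
            sha256, size, s3_key, created_by=user, original_name=original_name
        )
    # Record the upload event either way, then persist
    UploadRepository(db).create_upload(
        artifact.id, package_id, uploaded_by=user, original_name=original_name
    )
    db.commit()
    return artifact
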
backend/app/repositories/base.py (new file)

@@ -0,0 +1,96 @@
"""
Base repository class with common CRUD operations.
"""
from typing import TypeVar, Generic, Type, Optional, List, Any, Dict
from sqlalchemy.orm import Session
from sqlalchemy import func, asc, desc
from uuid import UUID
from ..models import Base
T = TypeVar("T", bound=Base)
class BaseRepository(Generic[T]):
"""
Base repository providing common CRUD operations.
Subclasses should set `model` class attribute to the SQLAlchemy model.
"""
model: Type[T]
def __init__(self, db: Session):
self.db = db
def get_by_id(self, id: Any) -> Optional[T]:
"""Get entity by primary key."""
return self.db.query(self.model).filter(self.model.id == id).first()
def get_all(
self,
skip: int = 0,
limit: int = 100,
order_by: Optional[str] = None,
order_desc: bool = False,
) -> List[T]:
"""Get all entities with pagination and optional ordering."""
query = self.db.query(self.model)
if order_by and hasattr(self.model, order_by):
column = getattr(self.model, order_by)
query = query.order_by(desc(column) if order_desc else asc(column))
return query.offset(skip).limit(limit).all()
def count(self) -> int:
"""Count total entities."""
return self.db.query(func.count(self.model.id)).scalar() or 0
def create(self, **kwargs) -> T:
"""Create a new entity."""
entity = self.model(**kwargs)
self.db.add(entity)
self.db.flush() # Flush to get ID without committing
return entity
def update(self, entity: T, **kwargs) -> T:
"""Update an existing entity."""
for key, value in kwargs.items():
if hasattr(entity, key):
setattr(entity, key, value)
self.db.flush()
return entity
def delete(self, entity: T) -> None:
"""Delete an entity."""
self.db.delete(entity)
self.db.flush()
def delete_by_id(self, id: Any) -> bool:
"""Delete entity by ID. Returns True if deleted, False if not found."""
entity = self.get_by_id(id)
if entity:
self.delete(entity)
return True
return False
def exists(self, id: Any) -> bool:
"""Check if entity exists by ID."""
return self.db.query(
self.db.query(self.model).filter(self.model.id == id).exists()
).scalar()
def commit(self) -> None:
"""Commit the current transaction."""
self.db.commit()
def rollback(self) -> None:
"""Rollback the current transaction."""
self.db.rollback()
def refresh(self, entity: T) -> T:
"""Refresh entity from database."""
self.db.refresh(entity)
return entity

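Adding a repository for a new entity only requires assigning `model`; the generic CRUD surface comes along for free. For example, a hypothetical subclass over the existing TagHistory model (not part of this commit):

from app.repositories.base import BaseRepository  # assumed module path
from app.models import TagHistory

class TagHistoryRepository(BaseRepository[TagHistory]):
    """Hypothetical repository reusing the generic CRUD surface."""
    model = TagHistory

# repo = TagHistoryRepository(db)
# recent = repo.get_all(limit=10, order_by="changed_at", order_desc=True)
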
backend/app/repositories/package.py (new file)

@@ -0,0 +1,177 @@
"""
Package repository for data access operations.
"""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_, asc, desc
from uuid import UUID
from .base import BaseRepository
from ..models import Package, Project, Tag, Upload, Artifact
class PackageRepository(BaseRepository[Package]):
"""Repository for Package entity operations."""
model = Package
def get_by_name(self, project_id: UUID, name: str) -> Optional[Package]:
"""Get package by name within a project."""
return (
self.db.query(Package)
.filter(Package.project_id == project_id, Package.name == name)
.first()
)
def get_by_project_and_name(self, project_name: str, package_name: str) -> Optional[Package]:
"""Get package by project name and package name."""
return (
self.db.query(Package)
.join(Project, Package.project_id == Project.id)
.filter(Project.name == project_name, Package.name == package_name)
.first()
)
def exists_by_name(self, project_id: UUID, name: str) -> bool:
"""Check if package with name exists in project."""
return self.db.query(
self.db.query(Package)
.filter(Package.project_id == project_id, Package.name == name)
.exists()
).scalar()
def list_by_project(
self,
project_id: UUID,
page: int = 1,
limit: int = 20,
search: Optional[str] = None,
format: Optional[str] = None,
platform: Optional[str] = None,
sort: str = "name",
order: str = "asc",
) -> Tuple[List[Package], int]:
"""
List packages in a project with filtering and pagination.
Returns tuple of (packages, total_count).
"""
query = self.db.query(Package).filter(Package.project_id == project_id)
# Apply search filter
if search:
search_lower = search.lower()
query = query.filter(
or_(
func.lower(Package.name).contains(search_lower),
func.lower(Package.description).contains(search_lower)
)
)
# Apply format filter
if format:
query = query.filter(Package.format == format)
# Apply platform filter
if platform:
query = query.filter(Package.platform == platform)
# Get total count
total = query.count()
# Apply sorting
sort_columns = {
"name": Package.name,
"created_at": Package.created_at,
"updated_at": Package.updated_at,
}
sort_column = sort_columns.get(sort, Package.name)
if order == "desc":
query = query.order_by(desc(sort_column))
else:
query = query.order_by(asc(sort_column))
# Apply pagination
offset = (page - 1) * limit
packages = query.offset(offset).limit(limit).all()
return packages, total
def create_package(
self,
project_id: UUID,
name: str,
description: Optional[str] = None,
format: str = "generic",
platform: str = "any",
) -> Package:
"""Create a new package."""
return self.create(
project_id=project_id,
name=name,
description=description,
format=format,
platform=platform,
)
def update_package(
self,
package: Package,
name: Optional[str] = None,
description: Optional[str] = None,
format: Optional[str] = None,
platform: Optional[str] = None,
) -> Package:
"""Update package fields."""
updates = {}
if name is not None:
updates["name"] = name
if description is not None:
updates["description"] = description
if format is not None:
updates["format"] = format
if platform is not None:
updates["platform"] = platform
return self.update(package, **updates)
def get_stats(self, package_id: UUID) -> dict:
"""Get package statistics (tag count, artifact count, total size)."""
tag_count = (
self.db.query(func.count(Tag.id))
.filter(Tag.package_id == package_id)
.scalar() or 0
)
artifact_stats = (
self.db.query(
func.count(func.distinct(Upload.artifact_id)),
func.coalesce(func.sum(Artifact.size), 0)
)
.join(Artifact, Upload.artifact_id == Artifact.id)
.filter(Upload.package_id == package_id)
.first()
)
return {
"tag_count": tag_count,
"artifact_count": artifact_stats[0] if artifact_stats else 0,
"total_size": artifact_stats[1] if artifact_stats else 0,
}
def search(self, query_str: str, limit: int = 10) -> List[Tuple[Package, str]]:
"""Search packages by name or description. Returns (package, project_name) tuples."""
search_lower = query_str.lower()
return (
self.db.query(Package, Project.name)
.join(Project, Package.project_id == Project.id)
.filter(
or_(
func.lower(Package.name).contains(search_lower),
func.lower(Package.description).contains(search_lower)
)
)
.order_by(Package.name)
.limit(limit)
.all()
)

backend/app/repositories/project.py (new file)

@@ -0,0 +1,132 @@
"""
Project repository for data access operations.
"""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_, asc, desc
from uuid import UUID
from .base import BaseRepository
from ..models import Project
class ProjectRepository(BaseRepository[Project]):
"""Repository for Project entity operations."""
model = Project
def get_by_name(self, name: str) -> Optional[Project]:
"""Get project by unique name."""
return self.db.query(Project).filter(Project.name == name).first()
def exists_by_name(self, name: str) -> bool:
"""Check if project with name exists."""
return self.db.query(
self.db.query(Project).filter(Project.name == name).exists()
).scalar()
def list_accessible(
self,
user_id: str,
page: int = 1,
limit: int = 20,
search: Optional[str] = None,
visibility: Optional[str] = None,
sort: str = "name",
order: str = "asc",
) -> Tuple[List[Project], int]:
"""
List projects accessible to user with filtering and pagination.
Returns tuple of (projects, total_count).
"""
# Base query - filter by access
query = self.db.query(Project).filter(
or_(Project.is_public == True, Project.created_by == user_id)
)
# Apply visibility filter
if visibility == "public":
query = query.filter(Project.is_public == True)
elif visibility == "private":
query = query.filter(Project.is_public == False, Project.created_by == user_id)
# Apply search filter
if search:
search_lower = search.lower()
query = query.filter(
or_(
func.lower(Project.name).contains(search_lower),
func.lower(Project.description).contains(search_lower)
)
)
# Get total count before pagination
total = query.count()
# Apply sorting
sort_columns = {
"name": Project.name,
"created_at": Project.created_at,
"updated_at": Project.updated_at,
}
sort_column = sort_columns.get(sort, Project.name)
if order == "desc":
query = query.order_by(desc(sort_column))
else:
query = query.order_by(asc(sort_column))
# Apply pagination
offset = (page - 1) * limit
projects = query.offset(offset).limit(limit).all()
return projects, total
def create_project(
self,
name: str,
created_by: str,
description: Optional[str] = None,
is_public: bool = True,
) -> Project:
"""Create a new project."""
return self.create(
name=name,
description=description,
is_public=is_public,
created_by=created_by,
)
def update_project(
self,
project: Project,
name: Optional[str] = None,
description: Optional[str] = None,
is_public: Optional[bool] = None,
) -> Project:
"""Update project fields."""
updates = {}
if name is not None:
updates["name"] = name
if description is not None:
updates["description"] = description
if is_public is not None:
updates["is_public"] = is_public
return self.update(project, **updates)
def search(self, query_str: str, limit: int = 10) -> List[Project]:
"""Search projects by name or description."""
search_lower = query_str.lower()
return (
self.db.query(Project)
.filter(
or_(
func.lower(Project.name).contains(search_lower),
func.lower(Project.description).contains(search_lower)
)
)
.order_by(Project.name)
.limit(limit)
.all()
)

backend/app/repositories/tag.py (new file)

@@ -0,0 +1,168 @@
"""
Tag repository for data access operations.
"""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, or_, asc, desc
from uuid import UUID
from .base import BaseRepository
from ..models import Tag, TagHistory, Artifact, Package, Project
class TagRepository(BaseRepository[Tag]):
"""Repository for Tag entity operations."""
model = Tag
def get_by_name(self, package_id: UUID, name: str) -> Optional[Tag]:
"""Get tag by name within a package."""
return (
self.db.query(Tag)
.filter(Tag.package_id == package_id, Tag.name == name)
.first()
)
def get_with_artifact(self, package_id: UUID, name: str) -> Optional[Tuple[Tag, Artifact]]:
"""Get tag with its artifact."""
return (
self.db.query(Tag, Artifact)
.join(Artifact, Tag.artifact_id == Artifact.id)
.filter(Tag.package_id == package_id, Tag.name == name)
.first()
)
def exists_by_name(self, package_id: UUID, name: str) -> bool:
"""Check if tag with name exists in package."""
return self.db.query(
self.db.query(Tag)
.filter(Tag.package_id == package_id, Tag.name == name)
.exists()
).scalar()
def list_by_package(
self,
package_id: UUID,
page: int = 1,
limit: int = 20,
search: Optional[str] = None,
sort: str = "name",
order: str = "asc",
) -> Tuple[List[Tuple[Tag, Artifact]], int]:
"""
List tags in a package with artifact metadata.
Returns tuple of ((tag, artifact) tuples, total_count).
"""
query = (
self.db.query(Tag, Artifact)
.join(Artifact, Tag.artifact_id == Artifact.id)
.filter(Tag.package_id == package_id)
)
# Apply search filter (tag name or artifact original filename)
if search:
search_lower = search.lower()
query = query.filter(
or_(
func.lower(Tag.name).contains(search_lower),
func.lower(Artifact.original_name).contains(search_lower)
)
)
# Get total count
total = query.count()
# Apply sorting
sort_columns = {
"name": Tag.name,
"created_at": Tag.created_at,
}
sort_column = sort_columns.get(sort, Tag.name)
if order == "desc":
query = query.order_by(desc(sort_column))
else:
query = query.order_by(asc(sort_column))
# Apply pagination
offset = (page - 1) * limit
results = query.offset(offset).limit(limit).all()
return results, total
def create_tag(
self,
package_id: UUID,
name: str,
artifact_id: str,
created_by: str,
) -> Tag:
"""Create a new tag."""
return self.create(
package_id=package_id,
name=name,
artifact_id=artifact_id,
created_by=created_by,
)
def update_artifact(
self,
tag: Tag,
new_artifact_id: str,
changed_by: str,
record_history: bool = True,
) -> Tag:
"""
Update tag to point to a different artifact.
Optionally records change in tag history.
"""
old_artifact_id = tag.artifact_id
if record_history and old_artifact_id != new_artifact_id:
history = TagHistory(
tag_id=tag.id,
old_artifact_id=old_artifact_id,
new_artifact_id=new_artifact_id,
changed_by=changed_by,
)
self.db.add(history)
tag.artifact_id = new_artifact_id
tag.created_by = changed_by
self.db.flush()
return tag
def get_history(self, tag_id: UUID) -> List[TagHistory]:
"""Get tag change history."""
return (
self.db.query(TagHistory)
.filter(TagHistory.tag_id == tag_id)
.order_by(TagHistory.changed_at.desc())
.all()
)
def get_latest_in_package(self, package_id: UUID) -> Optional[Tag]:
"""Get the most recently created/updated tag in a package."""
return (
self.db.query(Tag)
.filter(Tag.package_id == package_id)
.order_by(Tag.created_at.desc())
.first()
)
def get_by_artifact(self, artifact_id: str) -> List[Tag]:
"""Get all tags pointing to an artifact."""
return (
self.db.query(Tag)
.filter(Tag.artifact_id == artifact_id)
.all()
)
def count_by_artifact(self, artifact_id: str) -> int:
"""Count tags pointing to an artifact."""
return (
self.db.query(func.count(Tag.id))
.filter(Tag.artifact_id == artifact_id)
.scalar() or 0
)

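Retargeting a tag touches two pieces of bookkeeping: update_artifact records the change in TagHistory, and the cleanup service (further below) adjusts ref_counts. A sketch of the combined flow (the wiring is an assumption; the commit does not show an endpoint):

from app.repositories import TagRepository  # assumed module paths
from app.services import ArtifactCleanupService

def retag(db, tag, new_artifact_id: str, user: str):
    old_artifact_id = tag.artifact_id
    TagRepository(db).update_artifact(tag, new_artifact_id, changed_by=user)
    ArtifactCleanupService(db).on_tag_updated(old_artifact_id, new_artifact_id)
    db.commit()
    return tag
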
backend/app/repositories/upload.py (new file)

@@ -0,0 +1,136 @@
"""
Upload repository for data access operations.
"""
from typing import Optional, List, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from uuid import UUID
from .base import BaseRepository
from ..models import Upload, Artifact, Package, Project
class UploadRepository(BaseRepository[Upload]):
"""Repository for Upload entity operations."""
model = Upload
def create_upload(
self,
artifact_id: str,
package_id: UUID,
uploaded_by: str,
original_name: Optional[str] = None,
source_ip: Optional[str] = None,
) -> Upload:
"""Record a new upload event."""
return self.create(
artifact_id=artifact_id,
package_id=package_id,
original_name=original_name,
uploaded_by=uploaded_by,
source_ip=source_ip,
)
def list_by_package(
self,
package_id: UUID,
page: int = 1,
limit: int = 20,
) -> Tuple[List[Upload], int]:
"""List uploads for a package with pagination."""
query = self.db.query(Upload).filter(Upload.package_id == package_id)
total = query.count()
offset = (page - 1) * limit
uploads = query.order_by(Upload.uploaded_at.desc()).offset(offset).limit(limit).all()
return uploads, total
def list_by_artifact(self, artifact_id: str) -> List[Upload]:
"""List all uploads of a specific artifact."""
return (
self.db.query(Upload)
.filter(Upload.artifact_id == artifact_id)
.order_by(Upload.uploaded_at.desc())
.all()
)
def get_latest_for_package(self, package_id: UUID) -> Optional[Upload]:
"""Get the most recent upload for a package."""
return (
self.db.query(Upload)
.filter(Upload.package_id == package_id)
.order_by(Upload.uploaded_at.desc())
.first()
)
def get_latest_timestamp(self, package_id: UUID) -> Optional[datetime]:
"""Get timestamp of most recent upload for a package."""
result = (
self.db.query(func.max(Upload.uploaded_at))
.filter(Upload.package_id == package_id)
.scalar()
)
return result
def count_by_artifact(self, artifact_id: str) -> int:
"""Count uploads of a specific artifact."""
return (
self.db.query(func.count(Upload.id))
.filter(Upload.artifact_id == artifact_id)
.scalar() or 0
)
def count_by_package(self, package_id: UUID) -> int:
"""Count total uploads for a package."""
return (
self.db.query(func.count(Upload.id))
.filter(Upload.package_id == package_id)
.scalar() or 0
)
def get_distinct_artifacts_count(self, package_id: UUID) -> int:
"""Count distinct artifacts uploaded to a package."""
return (
self.db.query(func.count(func.distinct(Upload.artifact_id)))
.filter(Upload.package_id == package_id)
.scalar() or 0
)
def get_uploads_by_user(
self,
user_id: str,
page: int = 1,
limit: int = 20,
) -> Tuple[List[Upload], int]:
"""List uploads by a specific user."""
query = self.db.query(Upload).filter(Upload.uploaded_by == user_id)
total = query.count()
offset = (page - 1) * limit
uploads = query.order_by(Upload.uploaded_at.desc()).offset(offset).limit(limit).all()
return uploads, total
def get_upload_stats(self, package_id: UUID) -> dict:
"""Get upload statistics for a package."""
stats = (
self.db.query(
func.count(Upload.id),
func.count(func.distinct(Upload.artifact_id)),
func.min(Upload.uploaded_at),
func.max(Upload.uploaded_at),
)
.filter(Upload.package_id == package_id)
.first()
)
return {
"total_uploads": stats[0] if stats else 0,
"unique_artifacts": stats[1] if stats else 0,
"first_upload": stats[2] if stats else None,
"last_upload": stats[3] if stats else None,
}

backend/app/services/__init__.py (new file)

@@ -0,0 +1,9 @@
"""
Service layer for business logic.
"""
from .artifact_cleanup import ArtifactCleanupService
__all__ = [
"ArtifactCleanupService",
]

backend/app/services/artifact_cleanup.py (new file)

@@ -0,0 +1,181 @@
"""
Service for artifact reference counting and cleanup.
"""
from typing import List, Optional, Tuple
from uuid import UUID
from sqlalchemy import func
from sqlalchemy.orm import Session
import logging
from ..models import Artifact, Tag, Upload, Package
from ..repositories.artifact import ArtifactRepository
from ..repositories.tag import TagRepository
from ..storage import S3Storage
logger = logging.getLogger(__name__)
class ArtifactCleanupService:
"""
Service for managing artifact reference counts and cleaning up orphaned artifacts.
Reference counting rules:
- ref_count starts at 1 when artifact is first uploaded
- ref_count increments when the same artifact is uploaded again (deduplication)
- ref_count decrements when a tag is deleted or updated to point elsewhere
- ref_count decrements when a package is deleted (for each tag pointing to artifact)
- When ref_count reaches 0, artifact is a candidate for deletion from S3
"""
def __init__(self, db: Session, storage: Optional[S3Storage] = None):
self.db = db
self.storage = storage
self.artifact_repo = ArtifactRepository(db)
self.tag_repo = TagRepository(db)
def on_tag_deleted(self, artifact_id: str) -> Artifact:
"""
Called when a tag is deleted.
Decrements ref_count for the artifact the tag was pointing to.
"""
artifact = self.artifact_repo.get_by_sha256(artifact_id)
if artifact:
artifact = self.artifact_repo.decrement_ref_count(artifact)
logger.info(f"Decremented ref_count for artifact {artifact_id}: now {artifact.ref_count}")
return artifact
def on_tag_updated(self, old_artifact_id: str, new_artifact_id: str) -> Tuple[Optional[Artifact], Optional[Artifact]]:
"""
Called when a tag is updated to point to a different artifact.
Decrements ref_count for old artifact, increments for new (if different).
Returns (old_artifact, new_artifact) tuple.
"""
old_artifact = None
new_artifact = None
if old_artifact_id != new_artifact_id:
# Decrement old artifact ref_count
old_artifact = self.artifact_repo.get_by_sha256(old_artifact_id)
if old_artifact:
old_artifact = self.artifact_repo.decrement_ref_count(old_artifact)
logger.info(f"Decremented ref_count for old artifact {old_artifact_id}: now {old_artifact.ref_count}")
# Increment new artifact ref_count
new_artifact = self.artifact_repo.get_by_sha256(new_artifact_id)
if new_artifact:
new_artifact = self.artifact_repo.increment_ref_count(new_artifact)
logger.info(f"Incremented ref_count for new artifact {new_artifact_id}: now {new_artifact.ref_count}")
return old_artifact, new_artifact
def on_package_deleted(self, package_id: UUID) -> List[str]:
"""
Called when a package is deleted.
Decrements ref_count for all artifacts that had tags in the package.
Returns list of artifact IDs that were affected.
"""
# Get all tags in the package before deletion
tags = self.db.query(Tag).filter(Tag.package_id == package_id).all()
affected_artifacts = []
for tag in tags:
artifact = self.artifact_repo.get_by_sha256(tag.artifact_id)
if artifact:
self.artifact_repo.decrement_ref_count(artifact)
affected_artifacts.append(tag.artifact_id)
logger.info(f"Decremented ref_count for artifact {tag.artifact_id} (package delete)")
return affected_artifacts
def cleanup_orphaned_artifacts(self, batch_size: int = 100, dry_run: bool = False) -> List[str]:
"""
Find and delete artifacts with ref_count = 0.
Args:
batch_size: Maximum number of artifacts to process
dry_run: If True, only report what would be deleted without actually deleting
Returns:
List of artifact IDs that were (or would be) deleted
"""
orphaned = self.artifact_repo.get_orphaned_artifacts(limit=batch_size)
deleted_ids = []
for artifact in orphaned:
if dry_run:
logger.info(f"[DRY RUN] Would delete orphaned artifact: {artifact.id}")
deleted_ids.append(artifact.id)
else:
try:
# Delete from S3 first
if self.storage:
self.storage.delete(artifact.s3_key)
logger.info(f"Deleted artifact from S3: {artifact.s3_key}")
# Then delete from database
self.artifact_repo.delete(artifact)
deleted_ids.append(artifact.id)
logger.info(f"Deleted orphaned artifact from database: {artifact.id}")
except Exception as e:
logger.error(f"Failed to delete artifact {artifact.id}: {e}")
if not dry_run and deleted_ids:
self.db.commit()
return deleted_ids
def get_orphaned_count(self) -> int:
"""Get count of artifacts with ref_count = 0."""
return (
self.db.query(func.count(Artifact.id))
.filter(Artifact.ref_count == 0)
.scalar() or 0
)
def verify_ref_counts(self, fix: bool = False) -> List[dict]:
"""
Verify that ref_counts match actual tag references.
Args:
fix: If True, fix any mismatched ref_counts
Returns:
List of artifacts with mismatched ref_counts
"""
# Get actual tag counts per artifact
tag_counts = (
self.db.query(Tag.artifact_id, func.count(Tag.id).label("tag_count"))
.group_by(Tag.artifact_id)
.all()
)
tag_count_map = {artifact_id: count for artifact_id, count in tag_counts}
# Check all artifacts
artifacts = self.db.query(Artifact).all()
mismatches = []
for artifact in artifacts:
actual_count = tag_count_map.get(artifact.id, 0)
# ref_count should be at least 1 (initial upload) + additional uploads
# But tags are the primary reference, so we check against tag count
if artifact.ref_count < actual_count:
mismatch = {
"artifact_id": artifact.id,
"stored_ref_count": artifact.ref_count,
"actual_tag_count": actual_count,
}
mismatches.append(mismatch)
if fix:
artifact.ref_count = max(actual_count, 1)
logger.warning(f"Fixed ref_count for artifact {artifact.id}: {mismatch['stored_ref_count']} -> {artifact.ref_count}")
if fix and mismatches:
self.db.commit()
return mismatches
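
A deployment would likely drive this service from a scheduled job. A sketch (the scheduling mechanism and S3 wiring are assumptions; pass a storage instance so orphaned blobs are purged from S3 as well):

import logging
from app.database import SessionLocal  # assumed module path
from app.services import ArtifactCleanupService

def run_cleanup(dry_run: bool = False) -> None:
    db = SessionLocal()
    try:
        # storage=S3Storage(...) would also delete orphaned blobs from S3
        service = ArtifactCleanupService(db)
        mismatches = service.verify_ref_counts(fix=False)
        if mismatches:
            logging.warning("found %d ref_count mismatches", len(mismatches))
        deleted = service.cleanup_orphaned_artifacts(batch_size=100, dry_run=dry_run)
        logging.info("cleanup processed %d orphaned artifacts", len(deleted))
    finally:
        db.close()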