from sqlalchemy import create_engine, text, event
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool
from sqlalchemy.exc import OperationalError
from typing import Generator, Optional
from contextlib import contextmanager
import functools
import logging
import time

from .config import get_settings
from .models import Base

settings = get_settings()
logger = logging.getLogger(__name__)

# Create engine with connection pool configuration
engine = create_engine(
    settings.database_url,
    pool_pre_ping=True,  # Check connection health before using
    poolclass=QueuePool,
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
    pool_timeout=settings.database_pool_timeout,
    pool_recycle=settings.database_pool_recycle,
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# Connection pool monitoring
@event.listens_for(engine, "checkout")
def receive_checkout(dbapi_connection, connection_record, connection_proxy):
    """Log when a connection is checked out from the pool."""
    logger.debug(f"Connection checked out from pool: {id(dbapi_connection)}")


@event.listens_for(engine, "checkin")
def receive_checkin(dbapi_connection, connection_record):
    """Log when a connection is returned to the pool."""
    logger.debug(f"Connection returned to pool: {id(dbapi_connection)}")


def get_pool_status() -> dict:
    """Get current connection pool status for monitoring."""
    pool = engine.pool
    return {
        "pool_size": pool.size(),
        "checked_out": pool.checkedout(),
        "overflow": pool.overflow(),
        "checked_in": pool.checkedin(),
    }
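
# Example (illustrative sketch, not part of this module): get_pool_status()
# is convenient to expose from a health or metrics endpoint. The FastAPI app
# and route path below are assumptions -- adapt them to however this service
# publishes diagnostics.
#
#     from fastapi import FastAPI
#
#     app = FastAPI()
#
#     @app.get("/health/db-pool")
#     def db_pool_health() -> dict:
#         # Returns the pool_size / checked_out / overflow / checked_in counters
#         return get_pool_status()
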
def init_db():
    """Create all tables and run migrations."""
    Base.metadata.create_all(bind=engine)
    # Run migrations for schema updates
    _run_migrations()


def _run_migrations():
    """Run manual migrations for schema updates."""
    migrations = [
        # Add format_metadata column to artifacts table
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'artifacts' AND column_name = 'format_metadata'
            ) THEN
                ALTER TABLE artifacts ADD COLUMN format_metadata JSONB DEFAULT '{}';
            END IF;
        END $$;
        """,
        # Add format column to packages table
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'packages' AND column_name = 'format'
            ) THEN
                ALTER TABLE packages ADD COLUMN format VARCHAR(50) DEFAULT 'generic' NOT NULL;
                CREATE INDEX IF NOT EXISTS idx_packages_format ON packages(format);
            END IF;
        END $$;
        """,
        # Add platform column to packages table
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'packages' AND column_name = 'platform'
            ) THEN
                ALTER TABLE packages ADD COLUMN platform VARCHAR(50) DEFAULT 'any' NOT NULL;
                CREATE INDEX IF NOT EXISTS idx_packages_platform ON packages(platform);
            END IF;
        END $$;
        """,
        # Add ref_count index and constraints for artifacts
        """
        DO $$
        BEGIN
            -- Add ref_count index
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes WHERE indexname = 'idx_artifacts_ref_count'
            ) THEN
                CREATE INDEX idx_artifacts_ref_count ON artifacts(ref_count);
            END IF;

            -- Add ref_count >= 0 constraint
            IF NOT EXISTS (
                SELECT 1 FROM pg_constraint WHERE conname = 'check_ref_count_non_negative'
            ) THEN
                ALTER TABLE artifacts
                    ADD CONSTRAINT check_ref_count_non_negative CHECK (ref_count >= 0);
            END IF;
        END $$;
        """,
        # Add composite indexes for packages and tags
        """
        DO $$
        BEGIN
            -- Composite index for package lookup by project and name
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes WHERE indexname = 'idx_packages_project_name'
            ) THEN
                CREATE UNIQUE INDEX idx_packages_project_name ON packages(project_id, name);
            END IF;

            -- Composite index for tag lookup by package and name
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes WHERE indexname = 'idx_tags_package_name'
            ) THEN
                CREATE UNIQUE INDEX idx_tags_package_name ON tags(package_id, name);
            END IF;

            -- Composite index for recent tags queries
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes WHERE indexname = 'idx_tags_package_created_at'
            ) THEN
                CREATE INDEX idx_tags_package_created_at ON tags(package_id, created_at);
            END IF;
        END $$;
        """,
    ]

    with engine.connect() as conn:
        for migration in migrations:
            try:
                conn.execute(text(migration))
                conn.commit()
            except Exception as e:
                # Roll back so the failed statement doesn't leave the
                # connection in an aborted state for the remaining migrations
                conn.rollback()
                logger.warning(f"Migration failed (may already be applied): {e}")


def get_db() -> Generator[Session, None, None]:
    """Dependency for getting database sessions."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


@contextmanager
def transaction(db: Session):
    """
    Context manager for explicit transaction management.

    Usage:
        with transaction(db):
            # operations here
            # automatically commits on success, rolls back on exception
    """
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise


@contextmanager
def savepoint(db: Session, name: Optional[str] = None):
    """
    Create a savepoint for partial rollback support.

    The optional ``name`` is informational only; SQLAlchemy generates its own
    identifier for the savepoint created by ``begin_nested()``.

    Usage:
        with savepoint(db):
            # operations here
            # rolls back to the savepoint on exception without rolling back
            # the enclosing transaction
    """
    savepoint_obj = db.begin_nested()
    try:
        yield savepoint_obj
        savepoint_obj.commit()
    except Exception:
        savepoint_obj.rollback()
        raise


def retry_on_deadlock(func, max_retries: int = 3, delay: float = 0.1):
    """
    Decorator/wrapper to retry operations on deadlock detection.

    Usage:
        @retry_on_deadlock
        def my_operation(db):
            ...

    Or:
        retry_on_deadlock(lambda: my_operation(db))()
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        last_exception = None
        for attempt in range(max_retries):
            try:
                return func(*args, **kwargs)
            except OperationalError as e:
                # PostgreSQL reports SQLSTATE 40P01; both PostgreSQL and MySQL
                # (error 1213) include "deadlock" in the message text
                error_str = str(e).lower()
                if "deadlock" in error_str or "40p01" in error_str:
                    last_exception = e
                    logger.warning(
                        f"Deadlock detected, retrying (attempt {attempt + 1}/{max_retries})"
                    )
                    time.sleep(delay * (attempt + 1))  # Linear backoff between retries
                else:
                    raise
        raise last_exception

    return wrapper
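
# Example (illustrative sketch): combining transaction() and savepoint() so a
# failing sub-step rolls back without discarding the outer work. The Package
# and Tag models and the unique index on tags are assumptions for illustration;
# IntegrityError comes from sqlalchemy.exc.
#
#     from sqlalchemy.exc import IntegrityError
#
#     with transaction(db):
#         db.add(Package(name="demo"))       # kept if the outer block commits
#         try:
#             with savepoint(db):
#                 db.add(Tag(name="dup"))    # may violate a unique index
#                 db.flush()
#         except IntegrityError:
#             pass  # only the savepoint rolled back; the Package insert survives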
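
# Example (illustrative sketch): retry_on_deadlock as a decorator around a
# small unit of work. increment_ref_count and its UPDATE are hypothetical; the
# decorated function re-runs from the top on a detected deadlock, so each
# attempt clears any failed transaction and commits its own work.
#
#     @retry_on_deadlock
#     def increment_ref_count(db: Session, artifact_id: int) -> None:
#         db.rollback()  # no-op on a clean session; clears a failed prior attempt
#         db.execute(
#             text("UPDATE artifacts SET ref_count = ref_count + 1 WHERE id = :id"),
#             {"id": artifact_id},
#         )
#         db.commit()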