from sqlalchemy import create_engine, text, event
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool
from typing import Generator
from contextlib import contextmanager
import logging
import time

from .config import get_settings
from .models import Base

settings = get_settings()
logger = logging.getLogger(__name__)

# Build connect_args with query timeout if configured
connect_args = {}
if settings.database_query_timeout > 0:
    # PostgreSQL statement_timeout is in milliseconds
    connect_args["options"] = f"-c statement_timeout={settings.database_query_timeout * 1000}"

# Create engine with connection pool configuration
engine = create_engine(
    settings.database_url,
    pool_pre_ping=True,  # Check connection health before using
    poolclass=QueuePool,
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
    pool_timeout=settings.database_pool_timeout,
    pool_recycle=settings.database_pool_recycle,
    connect_args=connect_args,
)
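# Note (illustrative numbers): QueuePool opens at most pool_size + max_overflow
# connections; e.g. pool_size=5 with max_overflow=10 caps this engine at 15
# concurrent connections, and further checkouts wait up to pool_timeout seconds
# before raising a TimeoutError.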

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# Connection pool monitoring
@event.listens_for(engine, "checkout")
def receive_checkout(dbapi_connection, connection_record, connection_proxy):
    """Log when a connection is checked out from the pool"""
    logger.debug(f"Connection checked out from pool: {id(dbapi_connection)}")


@event.listens_for(engine, "checkin")
def receive_checkin(dbapi_connection, connection_record):
    """Log when a connection is returned to the pool"""
    logger.debug(f"Connection returned to pool: {id(dbapi_connection)}")


def get_pool_status() -> dict:
    """Get current connection pool status for monitoring"""
    pool = engine.pool
    return {
        "pool_size": pool.size(),
        "checked_out": pool.checkedout(),
        "overflow": pool.overflow(),
        "checked_in": pool.checkedin(),
    }
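
# Usage sketch (illustrative): the snapshot can back a health or metrics
# endpoint, or simply be logged periodically:
#
#     logger.info("DB pool status: %s", get_pool_status())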


def init_db():
    """Create all tables and run migrations"""
    Base.metadata.create_all(bind=engine)

    # Run migrations for schema updates
    _run_migrations()


def _run_migrations():
    """Run manual migrations for schema updates"""
    migrations = [
        # Add format_metadata column to artifacts table
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'artifacts' AND column_name = 'format_metadata'
            ) THEN
                ALTER TABLE artifacts ADD COLUMN format_metadata JSONB DEFAULT '{}';
            END IF;
        END $$;
        """,
        # Add format column to packages table
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'packages' AND column_name = 'format'
            ) THEN
                ALTER TABLE packages ADD COLUMN format VARCHAR(50) DEFAULT 'generic' NOT NULL;
                CREATE INDEX IF NOT EXISTS idx_packages_format ON packages(format);
            END IF;
        END $$;
        """,
        # Add platform column to packages table
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'packages' AND column_name = 'platform'
            ) THEN
                ALTER TABLE packages ADD COLUMN platform VARCHAR(50) DEFAULT 'any' NOT NULL;
                CREATE INDEX IF NOT EXISTS idx_packages_platform ON packages(platform);
            END IF;
        END $$;
        """,
        # Add ref_count index and constraints for artifacts
        """
        DO $$
        BEGIN
            -- Add ref_count index
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes WHERE indexname = 'idx_artifacts_ref_count'
            ) THEN
                CREATE INDEX idx_artifacts_ref_count ON artifacts(ref_count);
            END IF;

            -- Add ref_count >= 0 constraint
            IF NOT EXISTS (
                SELECT 1 FROM pg_constraint WHERE conname = 'check_ref_count_non_negative'
            ) THEN
                ALTER TABLE artifacts ADD CONSTRAINT check_ref_count_non_negative CHECK (ref_count >= 0);
            END IF;
        END $$;
        """,
        # Add composite indexes for packages and tags
        """
        DO $$
        BEGIN
            -- Composite index for package lookup by project and name
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes WHERE indexname = 'idx_packages_project_name'
            ) THEN
                CREATE UNIQUE INDEX idx_packages_project_name ON packages(project_id, name);
            END IF;

            -- Composite index for tag lookup by package and name
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes WHERE indexname = 'idx_tags_package_name'
            ) THEN
                CREATE UNIQUE INDEX idx_tags_package_name ON tags(package_id, name);
            END IF;

            -- Composite index for recent tags queries
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes WHERE indexname = 'idx_tags_package_created_at'
            ) THEN
                CREATE INDEX idx_tags_package_created_at ON tags(package_id, created_at);
            END IF;
        END $$;
        """,
        # Add package_versions indexes and triggers (007_package_versions.sql)
        """
        DO $$
        BEGIN
            -- Create indexes for package_versions if table exists
            IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'package_versions') THEN
                -- Indexes for common queries
                IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_package_versions_package_id') THEN
                    CREATE INDEX idx_package_versions_package_id ON package_versions(package_id);
                END IF;
                IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_package_versions_artifact_id') THEN
                    CREATE INDEX idx_package_versions_artifact_id ON package_versions(artifact_id);
                END IF;
                IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_package_versions_package_version') THEN
                    CREATE INDEX idx_package_versions_package_version ON package_versions(package_id, version);
                END IF;
            END IF;
        END $$;
        """,
        # Create ref_count trigger functions for package_versions
        """
        CREATE OR REPLACE FUNCTION increment_version_ref_count()
        RETURNS TRIGGER AS $$
        BEGIN
            UPDATE artifacts SET ref_count = ref_count + 1 WHERE id = NEW.artifact_id;
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;
        """,
        """
        CREATE OR REPLACE FUNCTION decrement_version_ref_count()
        RETURNS TRIGGER AS $$
        BEGIN
            UPDATE artifacts SET ref_count = ref_count - 1 WHERE id = OLD.artifact_id;
            RETURN OLD;
        END;
        $$ LANGUAGE plpgsql;
        """,
        # Create triggers for package_versions ref_count
        """
        DO $$
        BEGIN
            IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'package_versions') THEN
                -- Drop and recreate triggers to ensure they're current
                DROP TRIGGER IF EXISTS package_versions_ref_count_insert ON package_versions;
                CREATE TRIGGER package_versions_ref_count_insert
                    AFTER INSERT ON package_versions
                    FOR EACH ROW
                    EXECUTE FUNCTION increment_version_ref_count();

                DROP TRIGGER IF EXISTS package_versions_ref_count_delete ON package_versions;
                CREATE TRIGGER package_versions_ref_count_delete
                    AFTER DELETE ON package_versions
                    FOR EACH ROW
                    EXECUTE FUNCTION decrement_version_ref_count();
            END IF;
        END $$;
        """,
        # Migrate existing semver tags to package_versions
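        # Raw string: the regex below contains backslash escapes ('\.') that
        # would otherwise trigger a Python SyntaxWarning about an invalid
        # escape sequence in a normal string literal.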
r"""
|
|
DO $$
|
|
BEGIN
|
|
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'package_versions') THEN
|
|
-- Migrate tags that look like versions (v1.0.0, 1.2.3, 2.0.0-beta, etc.)
|
|
INSERT INTO package_versions (package_id, artifact_id, version, version_source, created_by, created_at)
|
|
SELECT
|
|
t.package_id,
|
|
t.artifact_id,
|
|
CASE WHEN t.name LIKE 'v%' THEN substring(t.name from 2) ELSE t.name END,
|
|
'migrated_from_tag',
|
|
t.created_by,
|
|
t.created_at
|
|
FROM tags t
|
|
WHERE t.name ~ '^v?[0-9]+\.[0-9]+(\.[0-9]+)?([-.][a-zA-Z0-9]+)?$'
|
|
ON CONFLICT (package_id, version) DO NOTHING;
|
|
END IF;
|
|
END $$;
|
|
""",
|
|
]
|
|
|
|
with engine.connect() as conn:
|
|
for migration in migrations:
|
|
try:
|
|
conn.execute(text(migration))
|
|
conn.commit()
|
|
except Exception as e:
|
|
logger.warning(f"Migration failed (may already be applied): {e}")
|
|
|
|
|
|


def get_db() -> Generator[Session, None, None]:
    """Dependency for getting database sessions"""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
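
# Usage sketch (illustrative; assumes a FastAPI app and a hypothetical Package
# model elsewhere in the project):
#
#     from fastapi import Depends
#
#     @app.get("/packages")
#     def list_packages(db: Session = Depends(get_db)):
#         return db.query(Package).all()
#
# Each request gets its own session, and the finally block guarantees the
# session is closed (returned to the pool) even if the handler raises.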


@contextmanager
def transaction(db: Session):
    """
    Context manager for explicit transaction management (see savepoint() below
    for partial rollback support).

    Usage:
        with transaction(db):
            # operations here
            # automatically commits on success, rolls back on exception
    """
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise


@contextmanager
def savepoint(db: Session, name: str = None):
    """
    Create a savepoint for partial rollback support.

    Usage:
        with savepoint(db, "my_savepoint"):
            # operations here
            # rolls back to the savepoint on exception without rolling back
            # the whole transaction
    """
    savepoint_obj = db.begin_nested()
    try:
        yield savepoint_obj
        savepoint_obj.commit()
    except Exception:
        savepoint_obj.rollback()
        raise


def retry_on_deadlock(func, max_retries: int = 3, delay: float = 0.1):
    """
    Decorator/wrapper to retry operations on deadlock detection.

    Usage:
        @retry_on_deadlock
        def my_operation(db):
            ...

    Or:
        retry_on_deadlock(lambda: my_operation(db))()
    """
    import functools
    from sqlalchemy.exc import OperationalError

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        last_exception = None
        for attempt in range(max_retries):
            try:
                return func(*args, **kwargs)
            except OperationalError as e:
                # Check for deadlock error codes (PostgreSQL: 40P01, MySQL: 1213)
                error_str = str(e).lower()
                if "deadlock" in error_str or "40p01" in error_str:
                    last_exception = e
                    logger.warning(f"Deadlock detected, retrying (attempt {attempt + 1}/{max_retries})")
                    time.sleep(delay * (attempt + 1))  # Back off longer on each retry (linear backoff)
                else:
                    raise
        raise last_exception

    return wrapper
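
# Usage sketch (illustrative): wrap an existing callable with custom retry
# parameters instead of using the bare decorator form:
#
#     publish_with_retry = retry_on_deadlock(publish_artifact, max_retries=5, delay=0.2)
#     publish_with_retry(db, artifact_id)
#
# publish_artifact and artifact_id are hypothetical names; any callable that
# may hit a deadlock can be wrapped this way.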