"""
PyPI cache worker module.

Manages a thread pool for background caching of PyPI packages and their
dependencies. Replaces unbounded thread spawning with a managed queue-based
approach: a single dispatcher thread polls the database for ready
``PyPICacheTask`` rows and submits at most one task per worker to a
``ThreadPoolExecutor``, so pending work stays in the database rather than
piling up in the executor's in-memory queue.
"""

import html
import logging
import re
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime, timedelta
from typing import List, Optional, Set
from urllib.parse import urljoin
from uuid import UUID

import httpx
from sqlalchemy import func, or_
from sqlalchemy.orm import Session

from .config import get_settings

settings = get_settings()

from .database import SessionLocal
from .models import PyPICacheTask, Package, Project, Tag

logger = logging.getLogger(__name__)

# Module-level worker pool state with lock for thread safety
_worker_lock = threading.Lock()
_cache_worker_pool: Optional[ThreadPoolExecutor] = None
_cache_worker_running: bool = False
_dispatcher_thread: Optional[threading.Thread] = None

# Dispatcher tuning (seconds / counts).
_DISPATCH_BATCH_SIZE = 10  # max tasks claimed per DB poll
_IDLE_POLL_SECONDS = 2.0  # sleep when no task is ready (avoid busy loop)
_BUSY_POLL_SECONDS = 0.1  # small pause between batches when work was found
_SATURATED_POLL_SECONDS = 0.5  # sleep while every worker slot is occupied
_ERROR_BACKOFF_SECONDS = 5.0  # sleep after an unexpected dispatcher error


def _normalize_package_name(name: str) -> str:
    """Return *name* normalized per PEP 503: runs of ``-``, ``_``, ``.`` become ``-``, lowercased."""
    return re.sub(r"[-_.]+", "-", name).lower()


def init_cache_worker_pool(max_workers: Optional[int] = None):
    """
    Initialize the cache worker pool. Called on app startup.

    Creates the ``ThreadPoolExecutor`` and starts the daemon dispatcher
    thread. Calling it again while a pool exists only logs a warning.

    Args:
        max_workers: Number of concurrent workers. Defaults to PYPI_CACHE_WORKERS setting.
    """
    global _cache_worker_pool, _cache_worker_running, _dispatcher_thread

    with _worker_lock:
        if _cache_worker_pool is not None:
            logger.warning("Cache worker pool already initialized")
            return

        workers = max_workers or settings.PYPI_CACHE_WORKERS
        _cache_worker_pool = ThreadPoolExecutor(
            max_workers=workers,
            thread_name_prefix="pypi-cache-",
        )
        _cache_worker_running = True

        # The dispatcher only claims as many tasks as there are workers, so
        # it is told the pool size up front.
        _dispatcher_thread = threading.Thread(
            target=_cache_dispatcher_loop,
            args=(workers,),
            daemon=True,
            name="pypi-cache-dispatcher",
        )
        _dispatcher_thread.start()

    logger.info(f"PyPI cache worker pool initialized with {workers} workers")


def shutdown_cache_worker_pool(wait: bool = True, timeout: float = 30.0):
    """
    Shutdown the cache worker pool gracefully.

    Args:
        wait: Whether to wait for pending tasks to complete. When False,
            futures queued in the executor but not yet started are cancelled.
        timeout: Maximum time (seconds) to wait for the dispatcher thread to
            stop. ``ThreadPoolExecutor.shutdown`` itself takes no timeout, so
            with ``wait=True`` tasks already running are waited for in full.
    """
    global _cache_worker_pool, _cache_worker_running, _dispatcher_thread

    with _worker_lock:
        if _cache_worker_pool is None:
            return
        logger.info("Shutting down PyPI cache worker pool...")
        _cache_worker_running = False
        dispatcher = _dispatcher_thread

    # Wait for dispatcher to stop (outside lock to avoid deadlock / blocking
    # other init/shutdown callers for the whole join).
    if dispatcher is not None and dispatcher.is_alive():
        dispatcher.join(timeout=timeout)
        if dispatcher.is_alive():
            logger.warning(
                f"PyPI cache dispatcher did not stop within {timeout}s; "
                "shutting down the pool anyway"
            )

    # Detach the pool under the lock, but run the (possibly long) executor
    # shutdown outside it.
    with _worker_lock:
        pool = _cache_worker_pool
        _cache_worker_pool = None
        _dispatcher_thread = None

    if pool is not None:
        pool.shutdown(wait=wait, cancel_futures=not wait)

    logger.info("PyPI cache worker pool shut down")


def _reap_finished_futures(futures: Set[Future]) -> Set[Future]:
    """Return the still-running subset of *futures*, logging any finished one that raised."""
    still_running: Set[Future] = set()
    for future in futures:
        if not future.done():
            still_running.add(future)
        elif not future.cancelled():
            exc = future.exception()
            if exc is not None:
                # Without this, exceptions escaping a worker are silently lost.
                logger.error("PyPI cache worker raised an exception", exc_info=exc)
    return still_running


def _claim_and_submit(
    db: Session,
    pool: ThreadPoolExecutor,
    task: PyPICacheTask,
) -> Optional[Future]:
    """
    Mark *task* in_progress, commit, and submit it to *pool*.

    Returns:
        The submitted future, or None if the pool was already shut down (in
        which case the task is reverted to pending and committed).
    """
    task_id = task.id
    # Mark in_progress before submitting
    task.status = "in_progress"
    task.started_at = datetime.utcnow()
    db.commit()
    try:
        return pool.submit(_process_cache_task, task_id)
    except RuntimeError:
        # ThreadPoolExecutor.submit raises RuntimeError after shutdown. Put the
        # task back so it is not stranded in_progress (enqueue dedup would
        # otherwise keep returning it and the package would never be cached).
        task.status = "pending"
        task.started_at = None
        db.commit()
        logger.info(f"PyPI cache pool is shut down; returned task {task_id} to pending")
        return None


def _cache_dispatcher_loop(max_in_flight: Optional[int] = None):
    """
    Main dispatcher loop: poll DB for pending tasks and submit to worker pool.

    Only claims (marks in_progress) as many tasks as there are free worker
    slots, so ready tasks remain pending in the database instead of queueing
    inside the executor, where a cancelling shutdown would strand them.

    Args:
        max_in_flight: Maximum number of submitted-but-unfinished tasks.
            Defaults to the PYPI_CACHE_WORKERS setting.
    """
    logger.info("PyPI cache dispatcher started")
    capacity = max(1, max_in_flight or settings.PYPI_CACHE_WORKERS)
    in_flight: Set[Future] = set()

    while _cache_worker_running:
        # Snapshot the pool: shutdown may reset the global to None at any time.
        pool = _cache_worker_pool
        if pool is None:
            break

        in_flight = _reap_finished_futures(in_flight)
        free_slots = capacity - len(in_flight)
        if free_slots <= 0:
            time.sleep(_SATURATED_POLL_SECONDS)
            continue

        try:
            db = SessionLocal()
            try:
                tasks = _get_ready_tasks(db, limit=min(_DISPATCH_BATCH_SIZE, free_slots))
                for task in tasks:
                    future = _claim_and_submit(db, pool, task)
                    if future is None:
                        break  # pool shut down; stop claiming work
                    in_flight.add(future)
            finally:
                # Close before sleeping so no connection/transaction is held idle.
                db.close()

            # Sleep if no work (avoid busy loop); otherwise a small delay
            # between batches to avoid overwhelming.
            time.sleep(_BUSY_POLL_SECONDS if tasks else _IDLE_POLL_SECONDS)
        except Exception as e:
            logger.exception(f"PyPI cache dispatcher error: {e}")
            time.sleep(_ERROR_BACKOFF_SECONDS)

    logger.info("PyPI cache dispatcher stopped")


def _get_ready_tasks(db: Session, limit: int = 10) -> List[PyPICacheTask]:
    """
    Get tasks ready to process.

    Returns pending tasks that are either new or ready for retry.
    Orders by depth (shallow first) then creation time (FIFO).
    """
    now = datetime.utcnow()
    return (
        db.query(PyPICacheTask)
        .filter(
            PyPICacheTask.status == "pending",
            or_(
                PyPICacheTask.next_retry_at.is_(None),  # New tasks
                PyPICacheTask.next_retry_at <= now,  # Retry tasks ready
            ),
        )
        .order_by(
            PyPICacheTask.depth.asc(),  # Prefer shallow deps first
            PyPICacheTask.created_at.asc(),  # FIFO within same depth
        )
        .limit(limit)
        .all()
    )


def _process_cache_task(task_id: UUID):
    """
    Process a single cache task. Called by worker pool.

    Args:
        task_id: The ID of the task to process.
    """
    db = SessionLocal()
    try:
        task = db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
        if not task:
            logger.warning(f"PyPI cache task {task_id} not found")
            return

        logger.info(
            f"Processing cache task: {task.package_name} "
            f"(depth={task.depth}, attempt={task.attempts + 1})"
        )

        # Check if already cached by another task (dedup)
        existing_artifact = _find_cached_package(db, task.package_name)
        if existing_artifact:
            logger.info(f"Package {task.package_name} already cached, skipping")
            _mark_task_completed(db, task, cached_artifact_id=existing_artifact)
            return

        # Check depth limit
        max_depth = settings.PYPI_CACHE_MAX_DEPTH
        if task.depth >= max_depth:
            _mark_task_failed(db, task, f"Max depth {max_depth} exceeded")
            return

        # Do the actual caching
        result = _fetch_and_cache_package(task.package_name, task.version_constraint)

        if result["success"]:
            _mark_task_completed(db, task, cached_artifact_id=result.get("artifact_id"))
            logger.info(f"Successfully cached {task.package_name}")
        else:
            _handle_task_failure(db, task, result["error"])

    except Exception as e:
        logger.exception(f"Error processing cache task {task_id}")
        # Release the failed transaction (and any row lock it may hold on this
        # task) before updating the same row from a second session; otherwise
        # the recovery UPDATE can block on that lock.
        try:
            db.rollback()
        except Exception:
            logger.exception(f"Rollback failed for cache task {task_id}")
        # Use a fresh session for error handling to avoid transaction issues
        recovery_db = SessionLocal()
        try:
            task = recovery_db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
            if task:
                _handle_task_failure(recovery_db, task, str(e))
        finally:
            recovery_db.close()
    finally:
        db.close()


def _find_cached_package(db: Session, package_name: str) -> Optional[str]:
    """
    Check if a package is already cached.

    A package counts as cached when the ``_pypi`` project has a Package with
    the normalized name that has at least one Tag.

    Args:
        db: Database session.
        package_name: Package name (normalized here per PEP 503).

    Returns:
        Artifact ID of the first matching tag if cached, None otherwise.
    """
    normalized = _normalize_package_name(package_name)

    # Check if _pypi project has this package with at least one tag
    system_project = db.query(Project).filter(Project.name == "_pypi").first()
    if not system_project:
        return None

    package = (
        db.query(Package)
        .filter(
            Package.project_id == system_project.id,
            Package.name == normalized,
        )
        .first()
    )
    if not package:
        return None

    # Check if package has any tags (cached files)
    tag = db.query(Tag).filter(Tag.package_id == package.id).first()
    if tag:
        return tag.artifact_id

    return None


def _fetch_and_cache_package(
    package_name: str,
    version_constraint: Optional[str] = None,
) -> dict:
    """
    Fetch and cache a PyPI package by making requests through our own proxy.

    Reads the proxy's simple index page, picks the last wheel link (falling
    back to the last ``.tar.gz`` sdist link), and downloads it through the
    proxy so the proxy stores it. The body is streamed and discarded rather
    than buffered, since only the side effect of the download is needed.

    Args:
        package_name: The package name to cache.
        version_constraint: Optional version constraint (currently not used for selection).

    Returns:
        Dict with "success" bool, "artifact_id" on success, "error" on failure.
    """
    normalized_name = _normalize_package_name(package_name)
    if not normalized_name.strip("-"):
        # An empty name would otherwise yield a pattern matching every file.
        return {"success": False, "error": f"Invalid package name: {package_name!r}"}

    # Build the URL to our own proxy
    # Use localhost since we're making internal requests
    base_url = f"http://localhost:{settings.PORT}"

    try:
        with httpx.Client(timeout=60.0, follow_redirects=True) as client:
            # Step 1: Get the simple index page
            simple_url = f"{base_url}/pypi/simple/{normalized_name}/"
            logger.debug(f"Fetching index: {simple_url}")
            response = client.get(simple_url)

            if response.status_code == 404:
                return {"success": False, "error": f"Package {package_name} not found on upstream"}

            if response.status_code != 200:
                return {"success": False, "error": f"Failed to get index: HTTP {response.status_code}"}

            # Step 2: Parse HTML to find downloadable files
            page = response.text
            # Filenames use the un-normalized project name, where any run of
            # "-", "_" or "." may stand in for a normalized "-"
            # (e.g. zope-interface -> zope.interface-5.4.0.tar.gz).
            name_pattern = "[-_.]+".join(
                re.escape(part) for part in normalized_name.split("-")
            )

            # Look for wheel files first (preferred)
            wheel_pattern = rf'href="([^"]*{name_pattern}[^"]*\.whl[^"]*)"'
            matches = re.findall(wheel_pattern, page, re.IGNORECASE)

            if not matches:
                # Fall back to sdist
                sdist_pattern = rf'href="([^"]*{name_pattern}[^"]*\.tar\.gz[^"]*)"'
                matches = re.findall(sdist_pattern, page, re.IGNORECASE)

            if not matches:
                logger.warning(
                    f"No downloadable files found for {package_name}. "
                    f"Pattern: {wheel_pattern}, HTML preview: {page[:500]}"
                )
                return {"success": False, "error": "No downloadable files found"}

            # Get the last match (usually latest version). Hrefs are HTML
            # attribute values (entities such as &amp; must be decoded) and may
            # be relative to the index page's final URL.
            download_url = urljoin(str(response.url), html.unescape(matches[-1]))

            # Step 3: Download the file through our proxy (this caches it)
            logger.debug(f"Downloading: {download_url}")
            with client.stream("GET", download_url) as download:
                if download.status_code != 200:
                    return {"success": False, "error": f"Download failed: HTTP {download.status_code}"}
                # Drain the body without holding it in memory (wheels can be
                # very large); the proxy caches as it serves.
                for _ in download.iter_raw():
                    pass
                # Get artifact ID from response header
                artifact_id = download.headers.get("X-Checksum-SHA256")

            return {"success": True, "artifact_id": artifact_id}

    except httpx.TimeoutException as e:
        return {"success": False, "error": f"Timeout: {e}"}
    except httpx.ConnectError as e:
        return {"success": False, "error": f"Connection failed: {e}"}
    except Exception as e:
        return {"success": False, "error": str(e)}


def _mark_task_completed(
    db: Session,
    task: PyPICacheTask,
    cached_artifact_id: Optional[str] = None,
):
    """Mark a task as completed, clear its error message, and commit."""
    task.status = "completed"
    task.completed_at = datetime.utcnow()
    task.cached_artifact_id = cached_artifact_id
    task.error_message = None
    db.commit()


def _mark_task_failed(db: Session, task: PyPICacheTask, error: str):
    """Mark a task as permanently failed (error truncated to 1000 chars) and commit."""
    task.status = "failed"
    task.completed_at = datetime.utcnow()
    task.error_message = error[:1000] if error else None
    db.commit()
    logger.warning(f"PyPI cache task failed permanently: {task.package_name} - {error}")


def _handle_task_failure(db: Session, task: PyPICacheTask, error: str):
    """
    Handle a failed cache attempt with exponential backoff.

    Increments the attempt count; once it reaches the task's max_attempts
    (or the PYPI_CACHE_MAX_ATTEMPTS setting if unset) the task is failed,
    otherwise it is returned to pending with a next_retry_at in the future.

    Args:
        db: Database session.
        task: The failed task.
        error: Error message.
    """
    task.attempts = (task.attempts or 0) + 1
    task.error_message = error[:1000] if error else None

    max_attempts = task.max_attempts or settings.PYPI_CACHE_MAX_ATTEMPTS

    if task.attempts >= max_attempts:
        # Give up after max attempts
        task.status = "failed"
        task.completed_at = datetime.utcnow()
        logger.warning(
            f"PyPI cache task failed permanently: {task.package_name} - {error} "
            f"(after {task.attempts} attempts)"
        )
    else:
        # Schedule retry with exponential backoff
        # Attempt 1 failed → retry in 30s
        # Attempt 2 failed → retry in 60s
        # Attempt 3 failed → permanent failure (if max_attempts=3)
        backoff_seconds = 30 * (2 ** (task.attempts - 1))
        task.status = "pending"
        task.next_retry_at = datetime.utcnow() + timedelta(seconds=backoff_seconds)
        logger.info(
            f"PyPI cache task will retry: {task.package_name} in {backoff_seconds}s "
            f"(attempt {task.attempts}/{max_attempts})"
        )

    db.commit()


def enqueue_cache_task(
    db: Session,
    package_name: str,
    version_constraint: Optional[str] = None,
    parent_task_id: Optional[UUID] = None,
    depth: int = 0,
    triggered_by_artifact: Optional[str] = None,
) -> Optional[PyPICacheTask]:
    """
    Enqueue a package for caching.

    Performs deduplication: won't create a task if one already exists for
    the same package in pending/in_progress state, or if the package is
    already cached. The new task is only flushed; the caller commits.

    Args:
        db: Database session.
        package_name: The package name to cache.
        version_constraint: Optional version constraint.
        parent_task_id: Parent task that spawned this one.
        depth: Recursion depth.
        triggered_by_artifact: Artifact that declared this dependency.

    Returns:
        The created or existing task, or None if already cached.
    """
    normalized = _normalize_package_name(package_name)

    # Check for existing pending/in_progress task
    existing_task = (
        db.query(PyPICacheTask)
        .filter(
            PyPICacheTask.package_name == normalized,
            PyPICacheTask.status.in_(["pending", "in_progress"]),
        )
        .first()
    )
    if existing_task:
        logger.debug(f"Task already exists for {normalized}: {existing_task.id}")
        return existing_task

    # Check if already cached
    if _find_cached_package(db, normalized):
        logger.debug(f"Package {normalized} already cached, skipping task creation")
        return None

    # Create new task
    task = PyPICacheTask(
        package_name=normalized,
        version_constraint=version_constraint,
        parent_task_id=parent_task_id,
        depth=depth,
        triggered_by_artifact=triggered_by_artifact,
        max_attempts=settings.PYPI_CACHE_MAX_ATTEMPTS,
    )
    db.add(task)
    db.flush()

    logger.info(f"Enqueued cache task for {normalized} (depth={depth})")
    return task


def get_cache_status(db: Session) -> dict:
    """
    Get summary of cache task queue status.

    Returns:
        Dict with counts for "pending", "in_progress", "completed" and
        "failed" (0 when a status has no rows).
    """
    counts = dict(
        db.query(PyPICacheTask.status, func.count(PyPICacheTask.id))
        .group_by(PyPICacheTask.status)
        .all()
    )
    return {
        status: counts.get(status, 0)
        for status in ("pending", "in_progress", "completed", "failed")
    }


def get_failed_tasks(db: Session, limit: int = 50) -> List[dict]:
    """
    Get list of failed tasks for debugging, most recently failed first.

    Args:
        db: Database session.
        limit: Maximum number of tasks to return.

    Returns:
        List of failed task info dicts.
    """
    tasks = (
        db.query(PyPICacheTask)
        .filter(PyPICacheTask.status == "failed")
        .order_by(PyPICacheTask.completed_at.desc())
        .limit(limit)
        .all()
    )
    return [
        {
            "id": str(task.id),
            "package": task.package_name,
            "error": task.error_message,
            "attempts": task.attempts,
            "depth": task.depth,
            "failed_at": task.completed_at.isoformat() if task.completed_at else None,
        }
        for task in tasks
    ]


def retry_failed_task(db: Session, package_name: str) -> Optional[PyPICacheTask]:
    """
    Reset a failed task to retry.

    Args:
        db: Database session.
        package_name: The package name to retry.

    Returns:
        The reset task, or None if not found.
    """
    normalized = _normalize_package_name(package_name)

    task = (
        db.query(PyPICacheTask)
        .filter(
            PyPICacheTask.package_name == normalized,
            PyPICacheTask.status == "failed",
        )
        .first()
    )
    if not task:
        return None

    task.status = "pending"
    task.attempts = 0
    task.next_retry_at = None
    task.error_message = None
    task.started_at = None
    task.completed_at = None
    db.commit()

    logger.info(f"Reset failed task for retry: {normalized}")
    return task


def retry_all_failed_tasks(db: Session) -> int:
    """
    Reset all failed tasks to retry.

    Args:
        db: Database session.

    Returns:
        Number of tasks reset.
    """
    count = (
        db.query(PyPICacheTask)
        .filter(PyPICacheTask.status == "failed")
        .update(
            {
                "status": "pending",
                "attempts": 0,
                "next_retry_at": None,
                "error_message": None,
                "started_at": None,
                "completed_at": None,
            }
        )
    )
    db.commit()

    logger.info(f"Reset {count} failed tasks for retry")
    return count