- Add require_admin authentication to cache management endpoints
- Add limit validation (1-500) on failed tasks query
- Add thread lock for worker pool thread safety
- Fix exception handling with separate recovery DB session
- Remove obsolete design doc
"""
|
|
PyPI cache worker module.
|
|
|
|
Manages a thread pool for background caching of PyPI packages and their dependencies.
|
|
Replaces unbounded thread spawning with a managed queue-based approach.
|
|
"""
|
|
|
|
import logging
|
|
import re
|
|
import threading
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from datetime import datetime, timedelta
|
|
from typing import List, Optional
|
|
from uuid import UUID
|
|
|
|
import httpx
|
|
from sqlalchemy import or_
|
|
from sqlalchemy.orm import Session
|
|
|
|
from .config import get_settings
|
|
|
|
settings = get_settings()
|
|
from .database import SessionLocal
|
|
from .models import PyPICacheTask, Package, Project, Tag
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Module-level worker pool state with lock for thread safety
|
|
_worker_lock = threading.Lock()
|
|
_cache_worker_pool: Optional[ThreadPoolExecutor] = None
|
|
_cache_worker_running: bool = False
|
|
_dispatcher_thread: Optional[threading.Thread] = None
|
|
|
|
|
|
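# Task lifecycle, as implemented by the functions below (a summary of existing behavior,
# not new logic):
#
#     pending --(dispatcher picks up)-------------------> in_progress
#     in_progress --(success or already cached)---------> completed
#     in_progress --(error, attempts < max_attempts)----> pending (next_retry_at set, exponential backoff)
#     in_progress --(error, attempts >= max_attempts)---> failed
#     failed --(retry_failed_task / retry_all_failed_tasks)--> pending
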
def init_cache_worker_pool(max_workers: Optional[int] = None):
    """
    Initialize the cache worker pool. Called on app startup.

    Args:
        max_workers: Number of concurrent workers. Defaults to PYPI_CACHE_WORKERS setting.
    """
    global _cache_worker_pool, _cache_worker_running, _dispatcher_thread

    with _worker_lock:
        if _cache_worker_pool is not None:
            logger.warning("Cache worker pool already initialized")
            return

        workers = max_workers or settings.PYPI_CACHE_WORKERS
        _cache_worker_pool = ThreadPoolExecutor(
            max_workers=workers,
            thread_name_prefix="pypi-cache-",
        )
        _cache_worker_running = True

        # Start the dispatcher thread
        _dispatcher_thread = threading.Thread(
            target=_cache_dispatcher_loop,
            daemon=True,
            name="pypi-cache-dispatcher",
        )
        _dispatcher_thread.start()

        logger.info(f"PyPI cache worker pool initialized with {workers} workers")

def shutdown_cache_worker_pool(wait: bool = True, timeout: float = 30.0):
    """
    Shutdown the cache worker pool gracefully.

    Args:
        wait: Whether to wait for pending tasks to complete.
        timeout: Maximum time to wait for shutdown.
    """
    global _cache_worker_pool, _cache_worker_running, _dispatcher_thread

    with _worker_lock:
        if _cache_worker_pool is None:
            return

        logger.info("Shutting down PyPI cache worker pool...")
        _cache_worker_running = False

    # Wait for dispatcher to stop (outside lock to avoid deadlock)
    if _dispatcher_thread and _dispatcher_thread.is_alive():
        _dispatcher_thread.join(timeout=5.0)

    with _worker_lock:
        # Shutdown thread pool
        if _cache_worker_pool:
            _cache_worker_pool.shutdown(wait=wait, cancel_futures=not wait)
        _cache_worker_pool = None
        _dispatcher_thread = None

    logger.info("PyPI cache worker pool shut down")

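# Rough lifecycle wiring sketch (illustrative only; the real startup/shutdown hooks live in
# the application module, which is not part of this file):
#
#     init_cache_worker_pool()               # on app startup; pool size from PYPI_CACHE_WORKERS
#     ...
#     shutdown_cache_worker_pool(wait=True)  # on app shutdown; drains in-flight tasks
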
def _cache_dispatcher_loop():
    """
    Main dispatcher loop: poll DB for pending tasks and submit to worker pool.
    """
    logger.info("PyPI cache dispatcher started")

    while _cache_worker_running:
        try:
            db = SessionLocal()
            try:
                tasks = _get_ready_tasks(db, limit=10)

                for task in tasks:
                    # Mark in_progress before submitting
                    task.status = "in_progress"
                    task.started_at = datetime.utcnow()
                    db.commit()

                    # Submit to worker pool
                    _cache_worker_pool.submit(_process_cache_task, task.id)

                # Sleep if no work (avoid busy loop)
                if not tasks:
                    time.sleep(2.0)
                else:
                    # Small delay between batches to avoid overwhelming
                    time.sleep(0.1)

            finally:
                db.close()

        except Exception as e:
            logger.error(f"PyPI cache dispatcher error: {e}")
            time.sleep(5.0)

    logger.info("PyPI cache dispatcher stopped")

def _get_ready_tasks(db: Session, limit: int = 10) -> List[PyPICacheTask]:
    """
    Get tasks ready to process.

    Returns pending tasks that are either new or ready for retry.
    Orders by depth (shallow first) then creation time (FIFO).
    """
    now = datetime.utcnow()
    return (
        db.query(PyPICacheTask)
        .filter(
            PyPICacheTask.status == "pending",
            or_(
                PyPICacheTask.next_retry_at.is_(None),  # New tasks
                PyPICacheTask.next_retry_at <= now,  # Retry tasks ready
            ),
        )
        .order_by(
            PyPICacheTask.depth.asc(),  # Prefer shallow deps first
            PyPICacheTask.created_at.asc(),  # FIFO within same depth
        )
        .limit(limit)
        .all()
    )

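# The query above corresponds roughly to this SQL (table/column names assumed from the
# PyPICacheTask model; the actual table name may differ):
#
#     SELECT * FROM pypi_cache_tasks
#     WHERE status = 'pending'
#       AND (next_retry_at IS NULL OR next_retry_at <= :now)
#     ORDER BY depth ASC, created_at ASC
#     LIMIT :limit;
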
def _process_cache_task(task_id: UUID):
    """
    Process a single cache task. Called by worker pool.

    Args:
        task_id: The ID of the task to process.
    """
    db = SessionLocal()
    try:
        task = db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
        if not task:
            logger.warning(f"PyPI cache task {task_id} not found")
            return

        logger.info(
            f"Processing cache task: {task.package_name} "
            f"(depth={task.depth}, attempt={task.attempts + 1})"
        )

        # Check if already cached by another task (dedup)
        existing_artifact = _find_cached_package(db, task.package_name)
        if existing_artifact:
            logger.info(f"Package {task.package_name} already cached, skipping")
            _mark_task_completed(db, task, cached_artifact_id=existing_artifact)
            return

        # Check depth limit
        max_depth = settings.PYPI_CACHE_MAX_DEPTH
        if task.depth >= max_depth:
            _mark_task_failed(db, task, f"Max depth {max_depth} exceeded")
            return

        # Do the actual caching
        result = _fetch_and_cache_package(task.package_name, task.version_constraint)

        if result["success"]:
            _mark_task_completed(db, task, cached_artifact_id=result.get("artifact_id"))
            logger.info(f"Successfully cached {task.package_name}")
        else:
            _handle_task_failure(db, task, result["error"])

    except Exception as e:
        logger.exception(f"Error processing cache task {task_id}")
        # Use a fresh session for error handling to avoid transaction issues
        recovery_db = SessionLocal()
        try:
            task = recovery_db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
            if task:
                _handle_task_failure(recovery_db, task, str(e))
        finally:
            recovery_db.close()
    finally:
        db.close()

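# Note on the recovery session above (per the commit message: "Fix exception handling with
# separate recovery DB session"): after an exception the original session may be left in a
# failed/aborted transaction, so failure bookkeeping is done on a fresh SessionLocal()
# instead of reusing `db`.
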
def _find_cached_package(db: Session, package_name: str) -> Optional[str]:
    """
    Check if a package is already cached.

    Args:
        db: Database session.
        package_name: Normalized package name.

    Returns:
        Artifact ID if cached, None otherwise.
    """
    # Normalize package name (PEP 503)
    normalized = re.sub(r"[-_.]+", "-", package_name).lower()

    # Check if _pypi project has this package with at least one tag
    system_project = db.query(Project).filter(Project.name == "_pypi").first()
    if not system_project:
        return None

    package = (
        db.query(Package)
        .filter(
            Package.project_id == system_project.id,
            Package.name == normalized,
        )
        .first()
    )
    if not package:
        return None

    # Check if package has any tags (cached files)
    tag = db.query(Tag).filter(Tag.package_id == package.id).first()
    if tag:
        return tag.artifact_id

    return None

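# PEP 503 normalization examples for the re.sub above (standard behavior, shown for clarity):
#
#     "Pillow"            -> "pillow"
#     "typing_extensions" -> "typing-extensions"
#     "zope.interface"    -> "zope-interface"
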
def _fetch_and_cache_package(
    package_name: str,
    version_constraint: Optional[str] = None,
) -> dict:
    """
    Fetch and cache a PyPI package by making requests through our own proxy.

    Args:
        package_name: The package name to cache.
        version_constraint: Optional version constraint (currently not used for selection).

    Returns:
        Dict with "success" bool, "artifact_id" on success, "error" on failure.
    """
    # Normalize package name (PEP 503)
    normalized_name = re.sub(r"[-_.]+", "-", package_name).lower()

    # Build the URL to our own proxy
    # Use localhost since we're making internal requests
    base_url = f"http://localhost:{settings.PORT}"

    try:
        with httpx.Client(timeout=60.0, follow_redirects=True) as client:
            # Step 1: Get the simple index page
            simple_url = f"{base_url}/pypi/simple/{normalized_name}/"
            logger.debug(f"Fetching index: {simple_url}")

            response = client.get(simple_url)
            if response.status_code == 404:
                return {"success": False, "error": f"Package {package_name} not found on upstream"}
            if response.status_code != 200:
                return {"success": False, "error": f"Failed to get index: HTTP {response.status_code}"}

            # Step 2: Parse HTML to find downloadable files
            html = response.text

            # Create pattern that matches both normalized (hyphens) and original (underscores)
            name_pattern = re.sub(r"[-_]+", "[-_]+", normalized_name)

            # Look for wheel files first (preferred)
            wheel_pattern = rf'href="([^"]*{name_pattern}[^"]*\.whl[^"]*)"'
            matches = re.findall(wheel_pattern, html, re.IGNORECASE)

            if not matches:
                # Fall back to sdist
                sdist_pattern = rf'href="([^"]*{name_pattern}[^"]*\.tar\.gz[^"]*)"'
                matches = re.findall(sdist_pattern, html, re.IGNORECASE)

            if not matches:
                logger.warning(
                    f"No downloadable files found for {package_name}. "
                    f"Pattern: {wheel_pattern}, HTML preview: {html[:500]}"
                )
                return {"success": False, "error": "No downloadable files found"}

            # Get the last match (usually latest version)
            download_url = matches[-1]

            # Make URL absolute if needed
            if download_url.startswith("/"):
                download_url = f"{base_url}{download_url}"
            elif not download_url.startswith("http"):
                download_url = f"{base_url}/pypi/simple/{normalized_name}/{download_url}"

            # Step 3: Download the file through our proxy (this caches it)
            logger.debug(f"Downloading: {download_url}")
            response = client.get(download_url)

            if response.status_code != 200:
                return {"success": False, "error": f"Download failed: HTTP {response.status_code}"}

            # Get artifact ID from response header
            artifact_id = response.headers.get("X-Checksum-SHA256")

            return {"success": True, "artifact_id": artifact_id}

    except httpx.TimeoutException as e:
        return {"success": False, "error": f"Timeout: {e}"}
    except httpx.ConnectError as e:
        return {"success": False, "error": f"Connection failed: {e}"}
    except Exception as e:
        return {"success": False, "error": str(e)}

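# Illustrative match for the wheel regex above, assuming a simple-index entry shaped like
# the following (the exact href layout depends on the proxy's index rendering):
#
#     <a href="/pypi/files/requests-2.31.0-py3-none-any.whl#sha256=...">...</a>
#
# Group 1 would capture "/pypi/files/requests-2.31.0-py3-none-any.whl#sha256=...", which the
# code then makes absolute against base_url before downloading.
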
def _mark_task_completed(
    db: Session,
    task: PyPICacheTask,
    cached_artifact_id: Optional[str] = None,
):
    """Mark a task as completed."""
    task.status = "completed"
    task.completed_at = datetime.utcnow()
    task.cached_artifact_id = cached_artifact_id
    task.error_message = None
    db.commit()


def _mark_task_failed(db: Session, task: PyPICacheTask, error: str):
    """Mark a task as permanently failed."""
    task.status = "failed"
    task.completed_at = datetime.utcnow()
    task.error_message = error[:1000] if error else None
    db.commit()
    logger.warning(f"PyPI cache task failed permanently: {task.package_name} - {error}")

def _handle_task_failure(db: Session, task: PyPICacheTask, error: str):
    """
    Handle a failed cache attempt with exponential backoff.

    Args:
        db: Database session.
        task: The failed task.
        error: Error message.
    """
    task.attempts += 1
    task.error_message = error[:1000] if error else None

    max_attempts = task.max_attempts or settings.PYPI_CACHE_MAX_ATTEMPTS

    if task.attempts >= max_attempts:
        # Give up after max attempts
        task.status = "failed"
        task.completed_at = datetime.utcnow()
        logger.warning(
            f"PyPI cache task failed permanently: {task.package_name} - {error} "
            f"(after {task.attempts} attempts)"
        )
    else:
        # Schedule retry with exponential backoff
        # Attempt 1 failed → retry in 30s
        # Attempt 2 failed → retry in 60s
        # Attempt 3 failed → permanent failure (if max_attempts=3)
        backoff_seconds = 30 * (2 ** (task.attempts - 1))
        task.status = "pending"
        task.next_retry_at = datetime.utcnow() + timedelta(seconds=backoff_seconds)
        logger.info(
            f"PyPI cache task will retry: {task.package_name} in {backoff_seconds}s "
            f"(attempt {task.attempts}/{max_attempts})"
        )

    db.commit()

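# Worked backoff schedule for the formula above, backoff = 30 * 2**(attempts - 1), shown for
# an assumed max_attempts of 5 (the real value comes from PYPI_CACHE_MAX_ATTEMPTS / task.max_attempts):
#
#     attempt 1 -> 30s, attempt 2 -> 60s, attempt 3 -> 120s, attempt 4 -> 240s,
#     attempt 5 -> permanent failure
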
def enqueue_cache_task(
    db: Session,
    package_name: str,
    version_constraint: Optional[str] = None,
    parent_task_id: Optional[UUID] = None,
    depth: int = 0,
    triggered_by_artifact: Optional[str] = None,
) -> Optional[PyPICacheTask]:
    """
    Enqueue a package for caching.

    Performs deduplication: won't create a task if one already exists
    for the same package in pending/in_progress state, or if the package
    is already cached.

    Args:
        db: Database session.
        package_name: The package name to cache.
        version_constraint: Optional version constraint.
        parent_task_id: Parent task that spawned this one.
        depth: Recursion depth.
        triggered_by_artifact: Artifact that declared this dependency.

    Returns:
        The created or existing task, or None if already cached.
    """
    # Normalize package name (PEP 503)
    normalized = re.sub(r"[-_.]+", "-", package_name).lower()

    # Check for existing pending/in_progress task
    existing_task = (
        db.query(PyPICacheTask)
        .filter(
            PyPICacheTask.package_name == normalized,
            PyPICacheTask.status.in_(["pending", "in_progress"]),
        )
        .first()
    )
    if existing_task:
        logger.debug(f"Task already exists for {normalized}: {existing_task.id}")
        return existing_task

    # Check if already cached
    if _find_cached_package(db, normalized):
        logger.debug(f"Package {normalized} already cached, skipping task creation")
        return None

    # Create new task
    task = PyPICacheTask(
        package_name=normalized,
        version_constraint=version_constraint,
        parent_task_id=parent_task_id,
        depth=depth,
        triggered_by_artifact=triggered_by_artifact,
        max_attempts=settings.PYPI_CACHE_MAX_ATTEMPTS,
    )
    db.add(task)
    db.flush()

    logger.info(f"Enqueued cache task for {normalized} (depth={depth})")
    return task

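# Hypothetical caller sketch (not part of this module): enqueue_cache_task() only flush()es
# the new row, so the caller owns the transaction and must commit:
#
#     task = enqueue_cache_task(db, "requests", depth=0)
#     db.commit()
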
def get_cache_status(db: Session) -> dict:
    """
    Get summary of cache task queue status.

    Returns:
        Dict with counts by status.
    """
    from sqlalchemy import func

    stats = (
        db.query(PyPICacheTask.status, func.count(PyPICacheTask.id))
        .group_by(PyPICacheTask.status)
        .all()
    )

    return {
        "pending": next((s[1] for s in stats if s[0] == "pending"), 0),
        "in_progress": next((s[1] for s in stats if s[0] == "in_progress"), 0),
        "completed": next((s[1] for s in stats if s[0] == "completed"), 0),
        "failed": next((s[1] for s in stats if s[0] == "failed"), 0),
    }

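# Example return value (illustrative numbers only):
#
#     {"pending": 12, "in_progress": 3, "completed": 240, "failed": 1}
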
def get_failed_tasks(db: Session, limit: int = 50) -> List[dict]:
    """
    Get list of failed tasks for debugging.

    Args:
        db: Database session.
        limit: Maximum number of tasks to return.

    Returns:
        List of failed task info dicts.
    """
    tasks = (
        db.query(PyPICacheTask)
        .filter(PyPICacheTask.status == "failed")
        .order_by(PyPICacheTask.completed_at.desc())
        .limit(limit)
        .all()
    )

    return [
        {
            "id": str(task.id),
            "package": task.package_name,
            "error": task.error_message,
            "attempts": task.attempts,
            "depth": task.depth,
            "failed_at": task.completed_at.isoformat() if task.completed_at else None,
        }
        for task in tasks
    ]

def retry_failed_task(db: Session, package_name: str) -> Optional[PyPICacheTask]:
    """
    Reset a failed task to retry.

    Args:
        db: Database session.
        package_name: The package name to retry.

    Returns:
        The reset task, or None if not found.
    """
    normalized = re.sub(r"[-_.]+", "-", package_name).lower()

    task = (
        db.query(PyPICacheTask)
        .filter(
            PyPICacheTask.package_name == normalized,
            PyPICacheTask.status == "failed",
        )
        .first()
    )

    if not task:
        return None

    task.status = "pending"
    task.attempts = 0
    task.next_retry_at = None
    task.error_message = None
    task.started_at = None
    task.completed_at = None
    db.commit()

    logger.info(f"Reset failed task for retry: {normalized}")
    return task

def retry_all_failed_tasks(db: Session) -> int:
    """
    Reset all failed tasks to retry.

    Args:
        db: Database session.

    Returns:
        Number of tasks reset.
    """
    count = (
        db.query(PyPICacheTask)
        .filter(PyPICacheTask.status == "failed")
        .update(
            {
                "status": "pending",
                "attempts": 0,
                "next_retry_at": None,
                "error_message": None,
                "started_at": None,
                "completed_at": None,
            }
        )
    )
    db.commit()

    logger.info(f"Reset {count} failed tasks for retry")
    return count
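

# Hypothetical admin flow sketch (endpoint wiring lives elsewhere and, per the commit
# message, is protected by require_admin):
#
#     failed = get_failed_tasks(db, limit=50)          # inspect recent failures
#     retry_failed_task(db, failed[0]["package"])      # retry a single package
#     retry_all_failed_tasks(db)                       # or retry everything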