Files
orchard/backend/app/pypi_cache_worker.py
Mondo Diaz a485852a6f Add Active Workers table to Background Jobs dashboard
Shows currently processing cache tasks in a dynamic table with:
- Package name and version constraint being cached
- Recursion depth and attempt number
- Start timestamp
- Pulsing indicator to show live activity

Backend changes:
- Add get_active_tasks() function to pypi_cache_worker.py
- Add GET /pypi/cache/active endpoint to pypi_proxy.py

Frontend changes:
- Add PyPICacheActiveTask type
- Add getPyPICacheActiveTasks() API function
- Add Active Workers section with animated table
- Auto-refreshes every 5 seconds with existing data
2026-02-05 09:15:08 -06:00

626 lines
19 KiB
Python

"""
PyPI cache worker module.
Manages a thread pool for background caching of PyPI packages and their dependencies.
Replaces unbounded thread spawning with a managed queue-based approach.
"""
import logging
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from typing import List, Optional
from uuid import UUID
import httpx
from sqlalchemy import or_
from sqlalchemy.orm import Session
from .config import get_settings
settings = get_settings()
from .database import SessionLocal
from .models import PyPICacheTask, Package, Project, Tag
logger = logging.getLogger(__name__)
# Module-level worker pool state with lock for thread safety.
# _worker_lock is held by init_cache_worker_pool() and shutdown_cache_worker_pool()
# while they create or tear down the state below.
_worker_lock = threading.Lock()
# Executor that runs _process_cache_task jobs. None until init_cache_worker_pool()
# creates it, and reset to None by shutdown_cache_worker_pool().
_cache_worker_pool: Optional[ThreadPoolExecutor] = None
# Loop condition of _cache_dispatcher_loop(); set to False to ask the dispatcher to exit.
_cache_worker_running: bool = False
# Thread running _cache_dispatcher_loop(); None when no dispatcher has been started.
_dispatcher_thread: Optional[threading.Thread] = None
def init_cache_worker_pool(max_workers: Optional[int] = None):
    """
    Create the cache worker pool and start its dispatcher thread.

    Meant to be called on app startup. If a pool already exists the call
    only logs a warning and leaves the existing pool untouched.

    Args:
        max_workers: Number of concurrent workers. When omitted (or falsy)
            the PYPI_CACHE_WORKERS setting is used instead.
    """
    global _cache_worker_pool, _cache_worker_running, _dispatcher_thread

    with _worker_lock:
        # Guard: never build a second pool on top of a live one.
        if _cache_worker_pool is not None:
            logger.warning("Cache worker pool already initialized")
            return

        workers = max_workers if max_workers else settings.PYPI_CACHE_WORKERS

        pool_options = {
            "max_workers": workers,
            "thread_name_prefix": "pypi-cache-",
        }
        _cache_worker_pool = ThreadPoolExecutor(**pool_options)

        # The flag must be set before the dispatcher starts, because the
        # dispatcher loop exits as soon as it reads False.
        _cache_worker_running = True

        dispatcher = threading.Thread(
            name="pypi-cache-dispatcher",
            target=_cache_dispatcher_loop,
            daemon=True,
        )
        _dispatcher_thread = dispatcher
        dispatcher.start()

        logger.info(f"PyPI cache worker pool initialized with {workers} workers")
def shutdown_cache_worker_pool(wait: bool = True, timeout: float = 30.0):
    """
    Stop the dispatcher thread and shut the cache worker pool down.

    Does nothing when no pool has been initialized.

    Args:
        wait: Whether to wait for pending tasks to complete. When False,
            futures that have not started yet are cancelled.
        timeout: Maximum time to wait for shutdown. NOTE: this value is not
            read by the body; the dispatcher join uses a fixed 5 seconds.
    """
    global _cache_worker_pool, _cache_worker_running, _dispatcher_thread

    # Phase 1: signal the dispatcher loop to stop.
    with _worker_lock:
        if _cache_worker_pool is None:
            return
        logger.info("Shutting down PyPI cache worker pool...")
        _cache_worker_running = False

    # Phase 2: wait for the dispatcher without holding the lock, so the join
    # cannot deadlock against anything else that needs the lock.
    dispatcher = _dispatcher_thread
    if dispatcher is not None and dispatcher.is_alive():
        dispatcher.join(timeout=5.0)

    # Phase 3: tear down the executor and clear the module state.
    with _worker_lock:
        pool = _cache_worker_pool
        if pool is not None:
            pool.shutdown(wait=wait, cancel_futures=not wait)
            _cache_worker_pool = None
        _dispatcher_thread = None

    logger.info("PyPI cache worker pool shut down")
def _dispatch_ready_tasks(limit: int = 10) -> int:
    """
    Claim one batch of ready tasks and hand them to the worker pool.

    Each task is committed as "in_progress" before it is submitted. If the
    pool has gone away (shutdown in progress) the batch stops early, and a
    task whose submission is rejected is put back to "pending" so that it is
    not stranded in "in_progress", which _get_ready_tasks() never selects.

    Args:
        limit: Maximum number of tasks to claim in this batch.

    Returns:
        Number of tasks actually submitted to the pool.
    """
    db = SessionLocal()
    try:
        submitted = 0
        for task in _get_ready_tasks(db, limit=limit):
            # Snapshot the global: shutdown may reset it to None at any time.
            pool = _cache_worker_pool
            if pool is None or not _cache_worker_running:
                break

            task_id = task.id
            previous_started_at = task.started_at

            # Mark in_progress before submitting
            task.status = "in_progress"
            task.started_at = datetime.utcnow()
            db.commit()

            try:
                pool.submit(_process_cache_task, task_id)
            except RuntimeError:
                # The executor refuses new work once it has been shut down.
                # Release the claim so the task is picked up again later.
                task.status = "pending"
                task.started_at = previous_started_at
                db.commit()
                break
            submitted += 1
        return submitted
    finally:
        db.close()


def _cache_dispatcher_loop():
    """
    Main dispatcher loop: poll DB for pending tasks and submit to worker pool.

    Runs until the module-level _cache_worker_running flag becomes False.
    The database session is released before the loop sleeps. Any exception
    raised while dispatching is logged with its traceback and followed by a
    longer pause before the next poll.
    """
    logger.info("PyPI cache dispatcher started")
    while _cache_worker_running:
        try:
            dispatched = _dispatch_ready_tasks(limit=10)
        except Exception as e:
            logger.exception(f"PyPI cache dispatcher error: {e}")
            time.sleep(5.0)
            continue

        if dispatched:
            # Small delay between batches to avoid overwhelming
            time.sleep(0.1)
        else:
            # Sleep if no work (avoid busy loop)
            time.sleep(2.0)
    logger.info("PyPI cache dispatcher stopped")
def _get_ready_tasks(db: Session, limit: int = 10) -> List[PyPICacheTask]:
    """
    Fetch the next batch of tasks that may be processed now.

    A task qualifies when its status is "pending" and it either has no
    retry time set or its retry time has already passed. Results are ordered
    by depth (shallow first), then by creation time (FIFO).

    Args:
        db: Database session.
        limit: Maximum number of tasks to return.

    Returns:
        List of matching tasks, at most ``limit`` long.
    """
    now = datetime.utcnow()

    is_pending = PyPICacheTask.status == "pending"
    never_scheduled = PyPICacheTask.next_retry_at.is_(None)  # New tasks
    retry_due = PyPICacheTask.next_retry_at <= now  # Retry tasks ready

    query = db.query(PyPICacheTask).filter(is_pending, or_(never_scheduled, retry_due))
    query = query.order_by(
        PyPICacheTask.depth.asc(),  # Prefer shallow deps first
        PyPICacheTask.created_at.asc(),  # FIFO within same depth
    )
    return query.limit(limit).all()
def _process_cache_task(task_id: UUID):
    """
    Run one cache task to completion. Invoked by the worker pool.

    The task is loaded in its own session, then one of the following happens:
    it is completed immediately if the package is already cached, failed
    permanently if its depth reached PYPI_CACHE_MAX_DEPTH, or fetched through
    _fetch_and_cache_package() and recorded as completed or as a failed
    attempt. An unexpected exception is logged and recorded as a failed
    attempt using a separate session.

    Args:
        task_id: The ID of the task to process.
    """
    db = SessionLocal()
    try:
        task = db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
        if not task:
            logger.warning(f"PyPI cache task {task_id} not found")
            return

        logger.info(
            f"Processing cache task: {task.package_name} "
            f"(depth={task.depth}, attempt={task.attempts + 1})"
        )

        # Dedup: another task may have cached this package in the meantime.
        cached_id = _find_cached_package(db, task.package_name)
        if cached_id:
            logger.info(f"Package {task.package_name} already cached, skipping")
            _mark_task_completed(db, task, cached_artifact_id=cached_id)
            return

        # Refuse to go deeper than the configured limit.
        max_depth = settings.PYPI_CACHE_MAX_DEPTH
        if task.depth >= max_depth:
            _mark_task_failed(db, task, f"Max depth {max_depth} exceeded")
            return

        # Pass the depth along so nested deps are queued at depth+1.
        outcome = _fetch_and_cache_package(
            task.package_name,
            task.version_constraint,
            depth=task.depth,
        )
        if not outcome["success"]:
            _handle_task_failure(db, task, outcome["error"])
            return

        _mark_task_completed(db, task, cached_artifact_id=outcome.get("artifact_id"))
        logger.info(f"Successfully cached {task.package_name}")
    except Exception as e:
        logger.exception(f"Error processing cache task {task_id}")
        # The main session may be in a broken transaction; record the failure
        # through a fresh one instead.
        fallback_session = SessionLocal()
        try:
            failed_task = (
                fallback_session.query(PyPICacheTask)
                .filter(PyPICacheTask.id == task_id)
                .first()
            )
            if failed_task:
                _handle_task_failure(fallback_session, failed_task, str(e))
        finally:
            fallback_session.close()
    finally:
        db.close()
def _find_cached_package(db: Session, package_name: str) -> Optional[str]:
    """
    Look up whether a package already has a cached file.

    A package counts as cached when the "_pypi" project contains a package
    with the normalized name and that package has at least one tag.

    Args:
        db: Database session.
        package_name: Package name; it is normalized (PEP 503) before lookup.

    Returns:
        The artifact ID of the first tag found, or None if the project,
        the package or any tag is missing.
    """
    # PEP 503: collapse runs of "-", "_" and "." into "-" and lowercase.
    normalized = re.sub(r"[-_.]+", "-", package_name).lower()

    pypi_project = db.query(Project).filter(Project.name == "_pypi").first()
    if not pypi_project:
        return None

    package_query = (
        db.query(Package)
        .filter(Package.project_id == pypi_project.id)
        .filter(Package.name == normalized)
    )
    cached_package = package_query.first()
    if not cached_package:
        return None

    # Any tag on the package means at least one file has been cached.
    first_tag = db.query(Tag).filter(Tag.package_id == cached_package.id).first()
    return first_tag.artifact_id if first_tag else None
def _fetch_and_cache_package(
    package_name: str,
    version_constraint: Optional[str] = None,
    depth: int = 0,
) -> dict:
    """
    Fetch and cache a PyPI package by making requests through our own proxy.

    The simple index page of the package is requested from the local proxy,
    one file link is picked from it (wheels preferred, ``.tar.gz`` sdists as
    fallback, last matching link wins) and that file is downloaded through
    the proxy with a ``cache-depth`` query parameter attached.

    Args:
        package_name: The package name to cache.
        version_constraint: Optional version constraint (currently not used for selection).
        depth: Current recursion depth for dependency tracking.

    Returns:
        Dict with "success" bool, "artifact_id" on success, "error" on failure.
        HTTP and transport problems are reported through the dict, not raised.
    """
    # Normalize package name (PEP 503)
    normalized_name = re.sub(r"[-_.]+", "-", package_name).lower()

    # Build the URL to our own proxy
    # Use localhost since we're making internal requests
    base_url = f"http://localhost:{settings.PORT}"

    try:
        with httpx.Client(timeout=60.0, follow_redirects=True) as client:
            # Step 1: Get the simple index page
            simple_url = f"{base_url}/pypi/simple/{normalized_name}/"
            logger.debug(f"Fetching index: {simple_url}")
            response = client.get(simple_url)
            if response.status_code == 404:
                return {"success": False, "error": f"Package {package_name} not found on upstream"}
            if response.status_code != 200:
                return {"success": False, "error": f"Failed to get index: HTTP {response.status_code}"}

            # Step 2: Parse HTML to find downloadable files
            html = response.text

            # File names may spell the project with "-", "_" or "." wherever
            # the normalized name has "-" (normalization treats runs of all
            # three as equivalent), so accept any run of those separators.
            # Each literal segment is escaped so that unusual characters in
            # the name cannot change the meaning of the regex.
            name_pattern = r"[-_.]+".join(
                re.escape(segment) for segment in normalized_name.split("-")
            )

            # Look for wheel files first (preferred)
            wheel_pattern = rf'href="([^"]*{name_pattern}[^"]*\.whl[^"]*)"'
            matches = re.findall(wheel_pattern, html, re.IGNORECASE)
            if not matches:
                # Fall back to sdist
                sdist_pattern = rf'href="([^"]*{name_pattern}[^"]*\.tar\.gz[^"]*)"'
                matches = re.findall(sdist_pattern, html, re.IGNORECASE)
            if not matches:
                logger.warning(
                    f"No downloadable files found for {package_name}. "
                    f"Pattern: {wheel_pattern}, HTML preview: {html[:500]}"
                )
                return {"success": False, "error": "No downloadable files found"}

            # Get the last match (usually latest version)
            download_url = matches[-1]

            # Make URL absolute if needed
            if download_url.startswith("/"):
                download_url = f"{base_url}{download_url}"
            elif not download_url.startswith("http"):
                download_url = f"{base_url}/pypi/simple/{normalized_name}/{download_url}"

            # Add cache-depth query parameter to track recursion depth
            # The proxy will queue dependencies at depth+1
            # Index links can end in a "#<hash>=<value>" fragment. The query
            # parameter must go before the fragment, because the fragment is
            # never sent to the server and the parameter would be lost with it.
            url_part, hash_mark, fragment = download_url.partition("#")
            separator = "&" if "?" in url_part else "?"
            download_url = f"{url_part}{separator}cache-depth={depth}{hash_mark}{fragment}"

            # Step 3: Download the file through our proxy (this caches it)
            logger.debug(f"Downloading: {download_url}")
            response = client.get(download_url)
            if response.status_code != 200:
                return {"success": False, "error": f"Download failed: HTTP {response.status_code}"}

            # Get artifact ID from response header
            artifact_id = response.headers.get("X-Checksum-SHA256")
            return {"success": True, "artifact_id": artifact_id}
    except httpx.TimeoutException as e:
        return {"success": False, "error": f"Timeout: {e}"}
    except httpx.ConnectError as e:
        return {"success": False, "error": f"Connection failed: {e}"}
    except Exception as e:
        # Best-effort boundary: any other failure becomes an error result so
        # the caller can schedule a retry.
        return {"success": False, "error": str(e)}


# Alias for backward compatibility and clearer naming
_fetch_and_cache_package_with_depth = _fetch_and_cache_package
def _mark_task_completed(
    db: Session,
    task: PyPICacheTask,
    cached_artifact_id: Optional[str] = None,
):
    """
    Record a task as completed and commit the change.

    Sets the status to "completed", stamps the completion time, stores the
    artifact ID (may be None) and clears any previous error message.

    Args:
        db: Database session used for the commit.
        task: The task to update.
        cached_artifact_id: ID of the cached artifact, if known.
    """
    final_state = {
        "status": "completed",
        "completed_at": datetime.utcnow(),
        "cached_artifact_id": cached_artifact_id,
        "error_message": None,
    }
    for field_name, value in final_state.items():
        setattr(task, field_name, value)
    db.commit()
def _mark_task_failed(db: Session, task: PyPICacheTask, error: str):
    """
    Record a task as permanently failed and commit the change.

    No retry is scheduled. The stored error message is cut to at most
    1000 characters; an empty error is stored as None.

    Args:
        db: Database session used for the commit.
        task: The task to update.
        error: Error message describing the failure.
    """
    task.error_message = None if not error else error[:1000]
    task.completed_at = datetime.utcnow()
    task.status = "failed"
    db.commit()
    logger.warning(f"PyPI cache task failed permanently: {task.package_name} - {error}")
def _handle_task_failure(db: Session, task: PyPICacheTask, error: str):
    """
    Record a failed attempt and either schedule a retry or give up.

    The attempt counter is incremented and the error stored (cut to 1000
    characters). While the counter is below the attempt limit the task goes
    back to "pending" with a retry time of 30s * 2^(attempts - 1) from now;
    otherwise it is marked "failed". The change is committed either way.

    Args:
        db: Database session.
        task: The failed task.
        error: Error message.
    """
    task.attempts = task.attempts + 1
    task.error_message = None if not error else error[:1000]

    # Per-task limit wins; fall back to the global setting when it is unset.
    max_attempts = task.max_attempts or settings.PYPI_CACHE_MAX_ATTEMPTS

    if task.attempts < max_attempts:
        # Exponential backoff:
        #   attempt 1 failed -> retry in 30s
        #   attempt 2 failed -> retry in 60s
        #   attempt 3 failed -> permanent failure (if max_attempts=3)
        backoff_seconds = 30 * 2 ** (task.attempts - 1)
        retry_delay = timedelta(seconds=backoff_seconds)
        task.status = "pending"
        task.next_retry_at = datetime.utcnow() + retry_delay
        logger.info(
            f"PyPI cache task will retry: {task.package_name} in {backoff_seconds}s "
            f"(attempt {task.attempts}/{max_attempts})"
        )
    else:
        # Attempt budget exhausted: stop retrying.
        task.status = "failed"
        task.completed_at = datetime.utcnow()
        logger.warning(
            f"PyPI cache task failed permanently: {task.package_name} - {error} "
            f"(after {task.attempts} attempts)"
        )

    db.commit()
def enqueue_cache_task(
    db: Session,
    package_name: str,
    version_constraint: Optional[str] = None,
    parent_task_id: Optional[UUID] = None,
    depth: int = 0,
    triggered_by_artifact: Optional[str] = None,
) -> Optional[PyPICacheTask]:
    """
    Queue a package for background caching, with deduplication.

    No new task is created when a pending/in_progress task already exists
    for the same normalized package name, or when the package is already
    cached. A new task is added to the session and flushed, but not
    committed here.

    Args:
        db: Database session.
        package_name: The package name to cache.
        version_constraint: Optional version constraint.
        parent_task_id: Parent task that spawned this one.
        depth: Recursion depth.
        triggered_by_artifact: Artifact that declared this dependency.

    Returns:
        The created or existing task, or None if already cached.
    """
    # PEP 503 normalization so that name variants dedupe to one task.
    normalized = re.sub(r"[-_.]+", "-", package_name).lower()

    # Dedup step 1: reuse a task that is still queued or running.
    open_tasks = db.query(PyPICacheTask).filter(
        PyPICacheTask.package_name == normalized,
        PyPICacheTask.status.in_(["pending", "in_progress"]),
    )
    existing_task = open_tasks.first()
    if existing_task:
        logger.debug(f"Task already exists for {normalized}: {existing_task.id}")
        return existing_task

    # Dedup step 2: nothing to do when the package is already cached.
    if _find_cached_package(db, normalized):
        logger.debug(f"Package {normalized} already cached, skipping task creation")
        return None

    new_task = PyPICacheTask(
        package_name=normalized,
        version_constraint=version_constraint,
        parent_task_id=parent_task_id,
        depth=depth,
        triggered_by_artifact=triggered_by_artifact,
        max_attempts=settings.PYPI_CACHE_MAX_ATTEMPTS,
    )
    db.add(new_task)
    db.flush()
    logger.info(f"Enqueued cache task for {normalized} (depth={depth})")
    return new_task
def get_cache_status(db: Session) -> dict:
    """
    Summarize the cache task queue by status.

    Args:
        db: Database session.

    Returns:
        Dict with the task counts for "pending", "in_progress", "completed"
        and "failed". A status with no tasks is reported as 0.
    """
    from sqlalchemy import func

    rows = (
        db.query(PyPICacheTask.status, func.count(PyPICacheTask.id))
        .group_by(PyPICacheTask.status)
        .all()
    )

    # Index the (status, count) rows once; the first row for a status wins.
    count_by_status = {}
    for row in rows:
        count_by_status.setdefault(row[0], row[1])

    reported = ("pending", "in_progress", "completed", "failed")
    return {status: count_by_status.get(status, 0) for status in reported}
def get_failed_tasks(db: Session, limit: int = 50) -> List[dict]:
    """
    List failed tasks for debugging, most recently completed first.

    Args:
        db: Database session.
        limit: Maximum number of tasks to return.

    Returns:
        List of dicts with the keys "id", "package", "error", "attempts",
        "depth" and "failed_at" (ISO timestamp, or None when unset).
    """
    failed_query = db.query(PyPICacheTask).filter(PyPICacheTask.status == "failed")
    newest_first = failed_query.order_by(PyPICacheTask.completed_at.desc())

    summaries = []
    for task in newest_first.limit(limit).all():
        summary = {
            "id": str(task.id),
            "package": task.package_name,
            "error": task.error_message,
            "attempts": task.attempts,
            "depth": task.depth,
            "failed_at": None,
        }
        if task.completed_at:
            summary["failed_at"] = task.completed_at.isoformat()
        summaries.append(summary)
    return summaries
def get_active_tasks(db: Session, limit: int = 50) -> List[dict]:
    """
    List tasks that are currently in_progress, most recently started first.

    Args:
        db: Database session.
        limit: Maximum number of tasks to return.

    Returns:
        List of dicts with the keys "id", "package", "version_constraint",
        "depth", "attempts" and "started_at" (ISO timestamp, or None when unset).
    """
    running_query = db.query(PyPICacheTask).filter(PyPICacheTask.status == "in_progress")
    newest_first = running_query.order_by(PyPICacheTask.started_at.desc())

    snapshots = []
    for task in newest_first.limit(limit).all():
        snapshot = {
            "id": str(task.id),
            "package": task.package_name,
            "version_constraint": task.version_constraint,
            "depth": task.depth,
            "attempts": task.attempts,
            "started_at": None,
        }
        if task.started_at:
            snapshot["started_at"] = task.started_at.isoformat()
        snapshots.append(snapshot)
    return snapshots
def retry_failed_task(db: Session, package_name: str) -> Optional[PyPICacheTask]:
    """
    Put one failed task for a package back into the pending queue.

    The first task found with the normalized package name and status
    "failed" has its attempt counter, retry time, error message and
    timestamps cleared; the change is committed.

    Args:
        db: Database session.
        package_name: The package name to retry.

    Returns:
        The reset task, or None if not found.
    """
    normalized = re.sub(r"[-_.]+", "-", package_name).lower()

    failed_for_package = db.query(PyPICacheTask).filter(
        PyPICacheTask.package_name == normalized,
        PyPICacheTask.status == "failed",
    )
    task = failed_for_package.first()
    if not task:
        return None

    # Return the task to the state of a freshly queued one.
    fresh_state = {
        "status": "pending",
        "attempts": 0,
        "next_retry_at": None,
        "error_message": None,
        "started_at": None,
        "completed_at": None,
    }
    for field_name, value in fresh_state.items():
        setattr(task, field_name, value)
    db.commit()

    logger.info(f"Reset failed task for retry: {normalized}")
    return task
def retry_all_failed_tasks(db: Session) -> int:
    """
    Put every failed task back into the pending queue with one bulk update.

    Attempt counter, retry time, error message and timestamps are cleared
    for all tasks whose status is "failed"; the change is committed.

    Args:
        db: Database session.

    Returns:
        Number of tasks reset.
    """
    fresh_state = {
        "status": "pending",
        "attempts": 0,
        "next_retry_at": None,
        "error_message": None,
        "started_at": None,
        "completed_at": None,
    }
    failed_tasks = db.query(PyPICacheTask).filter(PyPICacheTask.status == "failed")
    count = failed_tasks.update(fresh_state)
    db.commit()
    logger.info(f"Reset {count} failed tasks for retry")
    return count