Remove proactive PyPI dependency caching feature

The background task queue that proactively cached package dependencies was
causing server instability and unnecessary cache growth. The PyPI proxy now
caches packages only on demand, when users request them.

Removed:
- PyPI cache worker (background task queue and worker pool)
- PyPICacheTask model and related database schema
- Cache management API endpoints (/pypi/cache/*)
- Background Jobs admin dashboard
- Dependency extraction and queueing logic

Kept:
- On-demand package caching (still works when users request packages)
- Async httpx for non-blocking downloads (prevents health check failures)
- URL-based cache lookups for deduplication
commit 31edadf3ad
parent 2136e1f0c5
Author: Mondo Diaz
Date:   2026-02-02 16:17:33 -06:00

11 changed files with 4 additions and 2392 deletions
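For reference, the on-demand path that survives this commit looks roughly like the
sketch below. It is illustrative only: the in-memory dicts stand in for the CachedUrl
table and the S3-backed artifact store, and only the sha256-of-URL lookup scheme and
the async httpx download are taken from the diffs that follow; the function names and
helper shapes are assumptions, not the proxy's actual code.

"""Minimal, illustrative sketch of on-demand caching (not the actual proxy code)."""
import hashlib
import httpx

# Stand-ins for the CachedUrl table and S3-backed artifact storage.
_url_index: dict[str, str] = {}   # sha256(upstream URL) -> artifact id
_blobs: dict[str, bytes] = {}     # artifact id -> cached file bytes

def hash_url(url: str) -> str:
    # Same scheme as CachedUrl.hash_url in the models diff below.
    return hashlib.sha256(url.encode("utf-8")).hexdigest()

async def fetch_on_demand(upstream_url: str) -> bytes:
    """Serve a previously fetched URL from cache; otherwise download and record it."""
    key = hash_url(upstream_url)
    if key in _url_index:  # URL-based dedup lookup
        return _blobs[_url_index[key]]
    # Async httpx keeps the event loop free during large downloads,
    # which is what prevents the health-check failures mentioned above.
    async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
        resp = await client.get(upstream_url)
        resp.raise_for_status()
    artifact_id = hashlib.sha256(resp.content).hexdigest()
    _blobs[artifact_id] = resp.content
    _url_index[key] = artifact_id
    return resp.content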


@@ -15,7 +15,6 @@ from .pypi_proxy import router as pypi_router
from .seed import seed_database
from .auth import create_default_admin
from .rate_limit import limiter
from .pypi_cache_worker import init_cache_worker_pool, shutdown_cache_worker_pool
settings = get_settings()
logging.basicConfig(level=logging.INFO)
@@ -50,14 +49,8 @@ async def lifespan(app: FastAPI):
else:
logger.info(f"Running in {settings.env} mode - skipping seed data")
# Initialize PyPI cache worker pool
init_cache_worker_pool()
yield
# Shutdown: cleanup
shutdown_cache_worker_pool()
app = FastAPI(
title="Orchard",


@@ -803,70 +803,3 @@ class CachedUrl(Base):
return hashlib.sha256(url.encode("utf-8")).hexdigest()
class PyPICacheTask(Base):
"""Task for caching a PyPI package and its dependencies.
Tracks the status of background caching operations with retry support.
Used by the PyPI proxy to ensure reliable dependency caching.
"""
__tablename__ = "pypi_cache_tasks"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# What to cache
package_name = Column(String(255), nullable=False)
version_constraint = Column(String(255))
# Origin tracking
parent_task_id = Column(
UUID(as_uuid=True),
ForeignKey("pypi_cache_tasks.id", ondelete="SET NULL"),
)
depth = Column(Integer, nullable=False, default=0)
triggered_by_artifact = Column(
String(64),
ForeignKey("artifacts.id", ondelete="SET NULL"),
)
# Status
status = Column(String(20), nullable=False, default="pending")
attempts = Column(Integer, nullable=False, default=0)
max_attempts = Column(Integer, nullable=False, default=3)
# Results
cached_artifact_id = Column(
String(64),
ForeignKey("artifacts.id", ondelete="SET NULL"),
)
error_message = Column(Text)
# Timing
created_at = Column(DateTime(timezone=True), nullable=False, default=datetime.utcnow)
started_at = Column(DateTime(timezone=True))
completed_at = Column(DateTime(timezone=True))
next_retry_at = Column(DateTime(timezone=True))
# Relationships
parent_task = relationship(
"PyPICacheTask",
remote_side=[id],
backref="child_tasks",
)
__table_args__ = (
Index("idx_pypi_cache_tasks_status_retry", "status", "next_retry_at"),
Index("idx_pypi_cache_tasks_package_status", "package_name", "status"),
Index("idx_pypi_cache_tasks_parent", "parent_task_id"),
Index("idx_pypi_cache_tasks_triggered_by", "triggered_by_artifact"),
Index("idx_pypi_cache_tasks_cached_artifact", "cached_artifact_id"),
Index("idx_pypi_cache_tasks_depth_created", "depth", "created_at"),
CheckConstraint(
"status IN ('pending', 'in_progress', 'completed', 'failed')",
name="check_task_status",
),
CheckConstraint("depth >= 0", name="check_depth_non_negative"),
CheckConstraint("attempts >= 0", name="check_attempts_non_negative"),
)


@@ -1,735 +0,0 @@
"""
PyPI cache worker module.
Manages a thread pool for background caching of PyPI packages and their dependencies.
Replaces unbounded thread spawning with a managed queue-based approach.
"""
import logging
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from typing import List, Optional
from uuid import UUID
import httpx
from sqlalchemy import or_
from sqlalchemy.orm import Session
from .config import get_settings
settings = get_settings()
from .database import SessionLocal
from .models import PyPICacheTask, Package, Project, Tag
logger = logging.getLogger(__name__)
# Module-level worker pool state with lock for thread safety
_worker_lock = threading.Lock()
_cache_worker_pool: Optional[ThreadPoolExecutor] = None
_cache_worker_running: bool = False
_dispatcher_thread: Optional[threading.Thread] = None
def _recover_stale_tasks():
"""
Recover tasks stuck in 'in_progress' state from a previous crash.
Called on startup to reset tasks that were being processed when
the server crashed. Resets them to 'pending' so they can be retried.
"""
db = SessionLocal()
try:
# Find tasks that have been in_progress for more than 5 minutes
# These are likely from a crashed worker
stale_threshold = datetime.utcnow() - timedelta(minutes=5)
stale_count = (
db.query(PyPICacheTask)
.filter(
PyPICacheTask.status == "in_progress",
or_(
PyPICacheTask.started_at == None,
PyPICacheTask.started_at < stale_threshold,
),
)
.update(
{
"status": "pending",
"started_at": None,
}
)
)
db.commit()
if stale_count > 0:
logger.warning(f"Recovered {stale_count} stale in_progress tasks from previous crash")
except Exception as e:
logger.error(f"Error recovering stale tasks: {e}")
finally:
db.close()
def init_cache_worker_pool(max_workers: Optional[int] = None):
"""
Initialize the cache worker pool. Called on app startup.
Args:
max_workers: Number of concurrent workers. Defaults to PYPI_CACHE_WORKERS setting.
"""
global _cache_worker_pool, _cache_worker_running, _dispatcher_thread
with _worker_lock:
if _cache_worker_pool is not None:
logger.warning("Cache worker pool already initialized")
return
# Recover any stale tasks from previous crash before starting workers
_recover_stale_tasks()
workers = max_workers or settings.PYPI_CACHE_WORKERS
_cache_worker_pool = ThreadPoolExecutor(
max_workers=workers,
thread_name_prefix="pypi-cache-",
)
_cache_worker_running = True
# Start the dispatcher thread
_dispatcher_thread = threading.Thread(
target=_cache_dispatcher_loop,
daemon=True,
name="pypi-cache-dispatcher",
)
_dispatcher_thread.start()
logger.info(f"PyPI cache worker pool initialized with {workers} workers")
def shutdown_cache_worker_pool(wait: bool = True, timeout: float = 30.0):
"""
Shutdown the cache worker pool gracefully.
Args:
wait: Whether to wait for pending tasks to complete.
timeout: Maximum time to wait for shutdown.
"""
global _cache_worker_pool, _cache_worker_running, _dispatcher_thread
with _worker_lock:
if _cache_worker_pool is None:
return
logger.info("Shutting down PyPI cache worker pool...")
_cache_worker_running = False
# Wait for dispatcher to stop (outside lock to avoid deadlock)
if _dispatcher_thread and _dispatcher_thread.is_alive():
_dispatcher_thread.join(timeout=5.0)
with _worker_lock:
# Shutdown thread pool
if _cache_worker_pool:
_cache_worker_pool.shutdown(wait=wait, cancel_futures=not wait)
_cache_worker_pool = None
_dispatcher_thread = None
logger.info("PyPI cache worker pool shut down")
def _cache_dispatcher_loop():
"""
Main dispatcher loop: poll DB for pending tasks and submit to worker pool.
"""
logger.info("PyPI cache dispatcher started")
while _cache_worker_running:
try:
db = SessionLocal()
try:
tasks = _get_ready_tasks(db, limit=10)
for task in tasks:
# Mark in_progress before submitting
task.status = "in_progress"
task.started_at = datetime.utcnow()
db.commit()
# Submit to worker pool
_cache_worker_pool.submit(_process_cache_task, task.id)
# Sleep if no work (avoid busy loop)
if not tasks:
time.sleep(2.0)
else:
# Small delay between batches to avoid overwhelming
time.sleep(0.1)
finally:
db.close()
except Exception as e:
logger.error(f"PyPI cache dispatcher error: {e}")
time.sleep(5.0)
logger.info("PyPI cache dispatcher stopped")
def _get_ready_tasks(db: Session, limit: int = 10) -> List[PyPICacheTask]:
"""
Get tasks ready to process.
Returns pending tasks that are either new or ready for retry.
Orders by depth (shallow first) then creation time (FIFO).
"""
now = datetime.utcnow()
return (
db.query(PyPICacheTask)
.filter(
PyPICacheTask.status == "pending",
or_(
PyPICacheTask.next_retry_at == None, # New tasks
PyPICacheTask.next_retry_at <= now, # Retry tasks ready
),
)
.order_by(
PyPICacheTask.depth.asc(), # Prefer shallow deps first
PyPICacheTask.created_at.asc(), # FIFO within same depth
)
.limit(limit)
.all()
)
def _process_cache_task(task_id: UUID):
"""
Process a single cache task. Called by worker pool.
Args:
task_id: The ID of the task to process.
"""
db = SessionLocal()
try:
task = db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
if not task:
logger.warning(f"PyPI cache task {task_id} not found")
return
logger.info(
f"Processing cache task: {task.package_name} "
f"(depth={task.depth}, attempt={task.attempts + 1})"
)
# Check if already cached by another task (dedup)
existing_artifact = _find_cached_package(db, task.package_name)
if existing_artifact:
logger.info(f"Package {task.package_name} already cached, skipping")
_mark_task_completed(db, task, cached_artifact_id=existing_artifact)
return
# Check depth limit
max_depth = settings.PYPI_CACHE_MAX_DEPTH
if task.depth >= max_depth:
_mark_task_failed(db, task, f"Max depth {max_depth} exceeded")
return
# Do the actual caching - pass depth so nested deps are queued at depth+1
result = _fetch_and_cache_package(task.package_name, task.version_constraint, depth=task.depth)
if result["success"]:
_mark_task_completed(db, task, cached_artifact_id=result.get("artifact_id"))
logger.info(f"Successfully cached {task.package_name}")
else:
_handle_task_failure(db, task, result["error"])
except Exception as e:
logger.exception(f"Error processing cache task {task_id}")
# Use a fresh session for error handling to avoid transaction issues
recovery_db = SessionLocal()
try:
task = recovery_db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
if task:
_handle_task_failure(recovery_db, task, str(e))
finally:
recovery_db.close()
finally:
db.close()
def _find_cached_package(db: Session, package_name: str) -> Optional[str]:
"""
Check if a package is already cached.
Args:
db: Database session.
package_name: Normalized package name.
Returns:
Artifact ID if cached, None otherwise.
"""
# Normalize package name (PEP 503)
normalized = re.sub(r"[-_.]+", "-", package_name).lower()
# Check if _pypi project has this package with at least one tag
system_project = db.query(Project).filter(Project.name == "_pypi").first()
if not system_project:
return None
package = (
db.query(Package)
.filter(
Package.project_id == system_project.id,
Package.name == normalized,
)
.first()
)
if not package:
return None
# Check if package has any tags (cached files)
tag = db.query(Tag).filter(Tag.package_id == package.id).first()
if tag:
return tag.artifact_id
return None
def _fetch_and_cache_package(
package_name: str,
version_constraint: Optional[str] = None,
depth: int = 0,
) -> dict:
"""
Fetch and cache a PyPI package by making requests through our own proxy.
Args:
package_name: The package name to cache.
version_constraint: Optional version constraint (currently not used for selection).
depth: Current recursion depth for dependency tracking.
Returns:
Dict with "success" bool, "artifact_id" on success, "error" on failure.
"""
# Normalize package name (PEP 503)
normalized_name = re.sub(r"[-_.]+", "-", package_name).lower()
# Build the URL to our own proxy
# Use localhost since we're making internal requests
base_url = f"http://localhost:{settings.PORT}"
try:
with httpx.Client(timeout=60.0, follow_redirects=True) as client:
# Step 1: Get the simple index page
simple_url = f"{base_url}/pypi/simple/{normalized_name}/"
logger.debug(f"Fetching index: {simple_url}")
response = client.get(simple_url)
if response.status_code == 404:
return {"success": False, "error": f"Package {package_name} not found on upstream"}
if response.status_code != 200:
return {"success": False, "error": f"Failed to get index: HTTP {response.status_code}"}
# Step 2: Parse HTML to find downloadable files
html = response.text
# Create pattern that matches both normalized (hyphens) and original (underscores)
name_pattern = re.sub(r"[-_]+", "[-_]+", normalized_name)
# Look for wheel files first (preferred)
wheel_pattern = rf'href="([^"]*{name_pattern}[^"]*\.whl[^"]*)"'
matches = re.findall(wheel_pattern, html, re.IGNORECASE)
if not matches:
# Fall back to sdist
sdist_pattern = rf'href="([^"]*{name_pattern}[^"]*\.tar\.gz[^"]*)"'
matches = re.findall(sdist_pattern, html, re.IGNORECASE)
if not matches:
logger.warning(
f"No downloadable files found for {package_name}. "
f"Pattern: {wheel_pattern}, HTML preview: {html[:500]}"
)
return {"success": False, "error": "No downloadable files found"}
# Get the last match (usually latest version)
download_url = matches[-1]
# Make URL absolute if needed
if download_url.startswith("/"):
download_url = f"{base_url}{download_url}"
elif not download_url.startswith("http"):
download_url = f"{base_url}/pypi/simple/{normalized_name}/{download_url}"
# Add cache-depth query parameter to track recursion depth
# The proxy will queue dependencies at depth+1
separator = "&" if "?" in download_url else "?"
download_url = f"{download_url}{separator}cache-depth={depth}"
# Step 3: Download the file through our proxy (this caches it)
logger.debug(f"Downloading: {download_url}")
response = client.get(download_url)
if response.status_code != 200:
return {"success": False, "error": f"Download failed: HTTP {response.status_code}"}
# Get artifact ID from response header
artifact_id = response.headers.get("X-Checksum-SHA256")
return {"success": True, "artifact_id": artifact_id}
except httpx.TimeoutException as e:
return {"success": False, "error": f"Timeout: {e}"}
except httpx.ConnectError as e:
return {"success": False, "error": f"Connection failed: {e}"}
except Exception as e:
return {"success": False, "error": str(e)}
# Alias for backward compatibility and clearer naming
_fetch_and_cache_package_with_depth = _fetch_and_cache_package
def _mark_task_completed(
db: Session,
task: PyPICacheTask,
cached_artifact_id: Optional[str] = None,
):
"""Mark a task as completed."""
task.status = "completed"
task.completed_at = datetime.utcnow()
task.cached_artifact_id = cached_artifact_id
task.error_message = None
db.commit()
def _mark_task_failed(db: Session, task: PyPICacheTask, error: str):
"""Mark a task as permanently failed."""
task.status = "failed"
task.completed_at = datetime.utcnow()
task.error_message = error[:1000] if error else None
db.commit()
logger.warning(f"PyPI cache task failed permanently: {task.package_name} - {error}")
def _handle_task_failure(db: Session, task: PyPICacheTask, error: str):
"""
Handle a failed cache attempt with exponential backoff.
Args:
db: Database session.
task: The failed task.
error: Error message.
"""
task.attempts += 1
task.error_message = error[:1000] if error else None
max_attempts = task.max_attempts or settings.PYPI_CACHE_MAX_ATTEMPTS
if task.attempts >= max_attempts:
# Give up after max attempts
task.status = "failed"
task.completed_at = datetime.utcnow()
logger.warning(
f"PyPI cache task failed permanently: {task.package_name} - {error} "
f"(after {task.attempts} attempts)"
)
else:
# Schedule retry with exponential backoff
# Attempt 1 failed → retry in 30s
# Attempt 2 failed → retry in 60s
# Attempt 3 failed → permanent failure (if max_attempts=3)
backoff_seconds = 30 * (2 ** (task.attempts - 1))
task.status = "pending"
task.next_retry_at = datetime.utcnow() + timedelta(seconds=backoff_seconds)
logger.info(
f"PyPI cache task will retry: {task.package_name} in {backoff_seconds}s "
f"(attempt {task.attempts}/{max_attempts})"
)
db.commit()
def enqueue_cache_task(
db: Session,
package_name: str,
version_constraint: Optional[str] = None,
parent_task_id: Optional[UUID] = None,
depth: int = 0,
triggered_by_artifact: Optional[str] = None,
) -> Optional[PyPICacheTask]:
"""
Enqueue a package for caching.
Performs deduplication: won't create a task if one already exists
for the same package in pending/in_progress state, or if the package
is already cached.
Args:
db: Database session.
package_name: The package name to cache.
version_constraint: Optional version constraint.
parent_task_id: Parent task that spawned this one.
depth: Recursion depth.
triggered_by_artifact: Artifact that declared this dependency.
Returns:
The created or existing task, or None if already cached.
"""
# Normalize package name (PEP 503)
normalized = re.sub(r"[-_.]+", "-", package_name).lower()
# Check for existing pending/in_progress task
existing_task = (
db.query(PyPICacheTask)
.filter(
PyPICacheTask.package_name == normalized,
PyPICacheTask.status.in_(["pending", "in_progress"]),
)
.first()
)
if existing_task:
logger.debug(f"Task already exists for {normalized}: {existing_task.id}")
return existing_task
# Check if already cached
if _find_cached_package(db, normalized):
logger.debug(f"Package {normalized} already cached, skipping task creation")
return None
# Create new task
task = PyPICacheTask(
package_name=normalized,
version_constraint=version_constraint,
parent_task_id=parent_task_id,
depth=depth,
triggered_by_artifact=triggered_by_artifact,
max_attempts=settings.PYPI_CACHE_MAX_ATTEMPTS,
)
db.add(task)
db.flush()
logger.info(f"Enqueued cache task for {normalized} (depth={depth})")
return task
def get_cache_status(db: Session) -> dict:
"""
Get summary of cache task queue status.
Returns:
Dict with counts by status.
"""
from sqlalchemy import func
stats = (
db.query(PyPICacheTask.status, func.count(PyPICacheTask.id))
.group_by(PyPICacheTask.status)
.all()
)
return {
"pending": next((s[1] for s in stats if s[0] == "pending"), 0),
"in_progress": next((s[1] for s in stats if s[0] == "in_progress"), 0),
"completed": next((s[1] for s in stats if s[0] == "completed"), 0),
"failed": next((s[1] for s in stats if s[0] == "failed"), 0),
}
def get_failed_tasks(db: Session, limit: int = 50) -> List[dict]:
"""
Get list of failed tasks for debugging.
Args:
db: Database session.
limit: Maximum number of tasks to return.
Returns:
List of failed task info dicts.
"""
tasks = (
db.query(PyPICacheTask)
.filter(PyPICacheTask.status == "failed")
.order_by(PyPICacheTask.completed_at.desc())
.limit(limit)
.all()
)
return [
{
"id": str(task.id),
"package": task.package_name,
"error": task.error_message,
"attempts": task.attempts,
"depth": task.depth,
"failed_at": task.completed_at.isoformat() if task.completed_at else None,
}
for task in tasks
]
def get_active_tasks(db: Session, limit: int = 50) -> List[dict]:
"""
Get list of currently active (in_progress) tasks.
Args:
db: Database session.
limit: Maximum number of tasks to return.
Returns:
List of active task info dicts.
"""
tasks = (
db.query(PyPICacheTask)
.filter(PyPICacheTask.status == "in_progress")
.order_by(PyPICacheTask.started_at.desc())
.limit(limit)
.all()
)
return [
{
"id": str(task.id),
"package": task.package_name,
"version_constraint": task.version_constraint,
"depth": task.depth,
"attempts": task.attempts,
"started_at": task.started_at.isoformat() if task.started_at else None,
}
for task in tasks
]
def get_recent_activity(db: Session, limit: int = 20) -> List[dict]:
"""
Get recent task completions and failures for activity feed.
Args:
db: Database session.
limit: Maximum number of items to return.
Returns:
List of recent activity items sorted by time descending.
"""
# Get recently completed and failed tasks
tasks = (
db.query(PyPICacheTask)
.filter(PyPICacheTask.status.in_(["completed", "failed"]))
.filter(PyPICacheTask.completed_at != None)
.order_by(PyPICacheTask.completed_at.desc())
.limit(limit)
.all()
)
return [
{
"id": str(task.id),
"package": task.package_name,
"status": task.status,
"type": "pypi",
"error": task.error_message if task.status == "failed" else None,
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
}
for task in tasks
]
def retry_failed_task(db: Session, package_name: str) -> Optional[PyPICacheTask]:
"""
Reset a failed task to retry.
Args:
db: Database session.
package_name: The package name to retry.
Returns:
The reset task, or None if not found.
"""
normalized = re.sub(r"[-_.]+", "-", package_name).lower()
task = (
db.query(PyPICacheTask)
.filter(
PyPICacheTask.package_name == normalized,
PyPICacheTask.status == "failed",
)
.first()
)
if not task:
return None
task.status = "pending"
task.attempts = 0
task.next_retry_at = None
task.error_message = None
task.started_at = None
task.completed_at = None
db.commit()
logger.info(f"Reset failed task for retry: {normalized}")
return task
def retry_all_failed_tasks(db: Session) -> int:
"""
Reset all failed tasks to retry.
Args:
db: Database session.
Returns:
Number of tasks reset.
"""
count = (
db.query(PyPICacheTask)
.filter(PyPICacheTask.status == "failed")
.update(
{
"status": "pending",
"attempts": 0,
"next_retry_at": None,
"error_message": None,
"started_at": None,
"completed_at": None,
}
)
)
db.commit()
logger.info(f"Reset {count} failed tasks for retry")
return count
def cancel_cache_task(db: Session, package_name: str) -> Optional[PyPICacheTask]:
"""
Cancel an in-progress or pending cache task.
Args:
db: Database session.
package_name: The package name to cancel.
Returns:
The cancelled task, or None if not found.
"""
normalized = re.sub(r"[-_.]+", "-", package_name).lower()
task = (
db.query(PyPICacheTask)
.filter(
PyPICacheTask.package_name == normalized,
PyPICacheTask.status.in_(["pending", "in_progress"]),
)
.first()
)
if not task:
return None
task.status = "failed"
task.completed_at = datetime.utcnow()
task.error_message = "Cancelled by admin"
db.commit()
logger.info(f"Cancelled cache task: {normalized}")
return task


@@ -9,172 +9,25 @@ import hashlib
import logging
import os
import re
import tarfile
import tempfile
import zipfile
from io import BytesIO
from typing import Optional, List, Tuple
from typing import Optional
from urllib.parse import urljoin, urlparse, quote, unquote
import httpx
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request, Response
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse, HTMLResponse
from sqlalchemy.orm import Session
from .auth import require_admin
from .database import get_db
from .models import User, UpstreamSource, CachedUrl, Artifact, Project, Package, Tag, PackageVersion, ArtifactDependency
from .models import UpstreamSource, CachedUrl, Artifact, Project, Package, Tag, PackageVersion
from .storage import S3Storage, get_storage
from .config import get_env_upstream_sources
from .pypi_cache_worker import (
enqueue_cache_task,
get_cache_status,
get_failed_tasks,
get_active_tasks,
get_recent_activity,
retry_failed_task,
retry_all_failed_tasks,
cancel_cache_task,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/pypi", tags=["pypi-proxy"])
def _parse_requires_dist(requires_dist: str) -> Tuple[str, Optional[str]]:
"""Parse a Requires-Dist line into (package_name, version_constraint).
Examples:
"requests (>=2.25.0)" -> ("requests", ">=2.25.0")
"typing-extensions; python_version < '3.8'" -> ("typing-extensions", None)
"numpy>=1.21.0" -> ("numpy", ">=1.21.0")
"certifi" -> ("certifi", None)
Returns:
Tuple of (normalized_package_name, version_constraint or None)
"""
# Remove any environment markers (after semicolon)
if ';' in requires_dist:
requires_dist = requires_dist.split(';')[0].strip()
# Match patterns like "package (>=1.0)" or "package>=1.0" or "package"
# Pattern breakdown: package name, optional whitespace, optional version in parens or directly
match = re.match(
r'^([a-zA-Z0-9][-a-zA-Z0-9._]*)\s*(?:\(([^)]+)\)|([<>=!~][^\s;]+))?',
requires_dist.strip()
)
if not match:
return None, None
package_name = match.group(1)
# Version can be in parentheses (group 2) or directly after name (group 3)
version_constraint = match.group(2) or match.group(3)
# Normalize package name (PEP 503)
normalized_name = re.sub(r'[-_.]+', '-', package_name).lower()
# Clean up version constraint
if version_constraint:
version_constraint = version_constraint.strip()
return normalized_name, version_constraint
def _extract_requires_from_metadata(metadata_content: str) -> List[Tuple[str, Optional[str]]]:
"""Extract all Requires-Dist entries from METADATA/PKG-INFO content.
Args:
metadata_content: The content of a METADATA or PKG-INFO file
Returns:
List of (package_name, version_constraint) tuples
"""
dependencies = []
for line in metadata_content.split('\n'):
if line.startswith('Requires-Dist:'):
# Extract the value after "Requires-Dist:"
value = line[len('Requires-Dist:'):].strip()
pkg_name, version = _parse_requires_dist(value)
if pkg_name:
dependencies.append((pkg_name, version))
return dependencies
def _extract_metadata_from_wheel(content: bytes) -> Optional[str]:
"""Extract METADATA file content from a wheel (zip) file.
Wheel files have structure: {package}-{version}.dist-info/METADATA
Args:
content: The wheel file content as bytes
Returns:
METADATA file content as string, or None if not found
"""
try:
with zipfile.ZipFile(BytesIO(content)) as zf:
# Find the .dist-info directory
for name in zf.namelist():
if name.endswith('.dist-info/METADATA'):
return zf.read(name).decode('utf-8', errors='replace')
except Exception as e:
logger.warning(f"Failed to extract metadata from wheel: {e}")
return None
def _extract_metadata_from_sdist(content: bytes, filename: str) -> Optional[str]:
"""Extract PKG-INFO file content from a source distribution (.tar.gz).
Source distributions have structure: {package}-{version}/PKG-INFO
Args:
content: The tarball content as bytes
filename: The original filename (used to determine package name)
Returns:
PKG-INFO file content as string, or None if not found
"""
try:
with tarfile.open(fileobj=BytesIO(content), mode='r:gz') as tf:
# Find PKG-INFO in the root directory of the archive
for member in tf.getmembers():
if member.name.endswith('/PKG-INFO') and member.name.count('/') == 1:
f = tf.extractfile(member)
if f:
return f.read().decode('utf-8', errors='replace')
except Exception as e:
logger.warning(f"Failed to extract metadata from sdist {filename}: {e}")
return None
def _extract_dependencies(content: bytes, filename: str) -> List[Tuple[str, Optional[str]]]:
"""Extract dependencies from a PyPI package file.
Supports wheel (.whl) and source distribution (.tar.gz) formats.
Args:
content: The package file content as bytes
filename: The original filename
Returns:
List of (package_name, version_constraint) tuples
"""
metadata = None
if filename.endswith('.whl'):
metadata = _extract_metadata_from_wheel(content)
elif filename.endswith('.tar.gz'):
metadata = _extract_metadata_from_sdist(content, filename)
if metadata:
return _extract_requires_from_metadata(metadata)
return []
# Timeout configuration for proxy requests
PROXY_CONNECT_TIMEOUT = 30.0
PROXY_READ_TIMEOUT = 60.0
@@ -521,7 +374,6 @@ async def pypi_download_file(
package_name: str,
filename: str,
upstream: Optional[str] = None,
cache_depth: int = Query(default=0, ge=0, le=100, alias="cache-depth"),
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
):
@@ -532,7 +384,6 @@ async def pypi_download_file(
package_name: The package name
filename: The filename to download
upstream: URL-encoded upstream URL to fetch from
cache_depth: Current cache recursion depth (used by cache worker for nested deps)
"""
if not upstream:
raise HTTPException(
@@ -656,7 +507,7 @@ async def pypi_download_file(
sha256 = result.sha256
size = result.size
# Read content for metadata extraction and response
# Read content for response
with open(tmp_path, 'rb') as f:
content = f.read()
@@ -766,50 +617,6 @@ async def pypi_download_file(
)
db.add(cached_url_record)
# Extract and store dependencies
dependencies = _extract_dependencies(content, filename)
unique_deps = []
if dependencies:
# Deduplicate dependencies by package name (keep first occurrence)
seen_packages = set()
for dep_name, dep_version in dependencies:
if dep_name not in seen_packages:
seen_packages.add(dep_name)
unique_deps.append((dep_name, dep_version))
logger.info(f"PyPI proxy: extracted {len(unique_deps)} dependencies from {filename} (deduped from {len(dependencies)})")
for dep_name, dep_version in unique_deps:
# Check if this dependency already exists for this artifact
existing_dep = db.query(ArtifactDependency).filter(
ArtifactDependency.artifact_id == sha256,
ArtifactDependency.dependency_project == "_pypi",
ArtifactDependency.dependency_package == dep_name,
).first()
if not existing_dep:
dep = ArtifactDependency(
artifact_id=sha256,
dependency_project="_pypi",
dependency_package=dep_name,
version_constraint=dep_version if dep_version else "*",
)
db.add(dep)
# Proactively cache dependencies via task queue
# Dependencies are queued at cache_depth + 1 to track recursion
if unique_deps:
next_depth = cache_depth + 1
for dep_name, dep_version in unique_deps:
enqueue_cache_task(
db,
package_name=dep_name,
version_constraint=dep_version,
parent_task_id=None, # Top-level, triggered by user download
depth=next_depth,
triggered_by_artifact=sha256,
)
logger.info(f"PyPI proxy: queued {len(unique_deps)} dependencies for caching (depth={next_depth})")
db.commit()
# Return the file
@@ -833,119 +640,3 @@ async def pypi_download_file(
except Exception as e:
logger.exception(f"PyPI proxy: error downloading {filename}")
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Cache Status and Management Endpoints
# =============================================================================
@router.get("/cache/status")
async def pypi_cache_status(
db: Session = Depends(get_db),
_current_user: User = Depends(require_admin),
):
"""
Get summary of the PyPI cache task queue.
Returns counts of tasks by status (pending, in_progress, completed, failed).
Requires admin privileges.
"""
return get_cache_status(db)
@router.get("/cache/failed")
async def pypi_cache_failed(
limit: int = Query(default=50, ge=1, le=500),
db: Session = Depends(get_db),
_current_user: User = Depends(require_admin),
):
"""
Get list of failed cache tasks for debugging.
Args:
limit: Maximum number of tasks to return (default 50, max 500).
Requires admin privileges.
"""
return get_failed_tasks(db, limit=limit)
@router.get("/cache/active")
async def pypi_cache_active(
limit: int = Query(default=50, ge=1, le=500),
db: Session = Depends(get_db),
_current_user: User = Depends(require_admin),
):
"""
Get list of currently active (in_progress) cache tasks.
Shows what the cache workers are currently processing.
Args:
limit: Maximum number of tasks to return (default 50, max 500).
Requires admin privileges.
"""
return get_active_tasks(db, limit=limit)
@router.post("/cache/retry/{package_name}")
async def pypi_cache_retry(
package_name: str,
db: Session = Depends(get_db),
_current_user: User = Depends(require_admin),
):
"""
Reset a failed cache task to retry.
Args:
package_name: The package name to retry.
Requires admin privileges.
"""
task = retry_failed_task(db, package_name)
if not task:
raise HTTPException(
status_code=404,
detail=f"No failed cache task found for package '{package_name}'"
)
return {"message": f"Retry queued for {task.package_name}", "task_id": str(task.id)}
@router.post("/cache/retry-all")
async def pypi_cache_retry_all(
db: Session = Depends(get_db),
_current_user: User = Depends(require_admin),
):
"""
Reset all failed cache tasks to retry.
Returns the count of tasks that were reset.
Requires admin privileges.
"""
count = retry_all_failed_tasks(db)
return {"message": f"Queued {count} tasks for retry", "count": count}
@router.post("/cache/cancel/{package_name}")
async def pypi_cache_cancel(
package_name: str,
db: Session = Depends(get_db),
_current_user: User = Depends(require_admin),
):
"""
Cancel an in-progress or pending cache task.
Args:
package_name: The package name to cancel.
Requires admin privileges.
"""
task = cancel_cache_task(db, package_name)
if not task:
raise HTTPException(
status_code=404,
detail=f"No active cache task found for package '{package_name}'"
)
return {"message": f"Cancelled task for {task.package_name}", "task_id": str(task.id)}


@@ -1,364 +0,0 @@
"""Tests for PyPI cache worker module."""
import os
import pytest
import re
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from uuid import uuid4
import httpx
def get_base_url():
"""Get the base URL for the Orchard server from environment."""
return os.environ.get("ORCHARD_TEST_URL", "http://localhost:8080")
class TestPyPICacheTaskModel:
"""Tests for PyPICacheTask model."""
def test_model_creation(self):
"""Test that PyPICacheTask model can be instantiated with explicit values."""
from app.models import PyPICacheTask
task = PyPICacheTask(
package_name="requests",
version_constraint=">=2.25.0",
depth=0,
status="pending",
attempts=0,
max_attempts=3,
)
assert task.package_name == "requests"
assert task.version_constraint == ">=2.25.0"
assert task.depth == 0
assert task.status == "pending"
assert task.attempts == 0
assert task.max_attempts == 3
def test_model_fields_exist(self):
"""Test that PyPICacheTask has all expected fields."""
from app.models import PyPICacheTask
# Create with minimal required field
task = PyPICacheTask(package_name="urllib3")
# Verify all expected attributes exist (SQLAlchemy defaults apply on flush)
assert hasattr(task, "status")
assert hasattr(task, "depth")
assert hasattr(task, "attempts")
assert hasattr(task, "max_attempts")
assert hasattr(task, "version_constraint")
assert hasattr(task, "parent_task_id")
assert hasattr(task, "triggered_by_artifact")
class TestEnqueueCacheTask:
"""Tests for enqueue_cache_task function."""
def test_normalize_package_name(self):
"""Test that package names are normalized per PEP 503."""
# Test the normalization pattern used in the worker
test_cases = [
("Requests", "requests"),
("typing_extensions", "typing-extensions"),
("some.package", "some-package"),
("UPPER_CASE", "upper-case"),
("mixed-Case_name", "mixed-case-name"),
]
for input_name, expected in test_cases:
normalized = re.sub(r"[-_.]+", "-", input_name).lower()
assert normalized == expected, f"Failed for {input_name}"
class TestCacheWorkerFunctions:
"""Tests for cache worker helper functions."""
def test_exponential_backoff_calculation(self):
"""Test that exponential backoff is calculated correctly."""
# The formula is: 30 * (2 ** (attempts - 1))
# Attempt 1 failed → 30s
# Attempt 2 failed → 60s
# Attempt 3 failed → 120s
def calc_backoff(attempts):
return 30 * (2 ** (attempts - 1))
assert calc_backoff(1) == 30
assert calc_backoff(2) == 60
assert calc_backoff(3) == 120
class TestPyPICacheAPIEndpoints:
"""Integration tests for PyPI cache API endpoints."""
@pytest.mark.integration
def test_cache_status_endpoint(self):
"""Test GET /pypi/cache/status returns queue statistics."""
with httpx.Client(base_url=get_base_url(), timeout=30.0) as client:
response = client.get("/pypi/cache/status")
assert response.status_code == 200
data = response.json()
assert "pending" in data
assert "in_progress" in data
assert "completed" in data
assert "failed" in data
# All values should be non-negative integers
assert isinstance(data["pending"], int)
assert isinstance(data["in_progress"], int)
assert isinstance(data["completed"], int)
assert isinstance(data["failed"], int)
assert data["pending"] >= 0
assert data["in_progress"] >= 0
assert data["completed"] >= 0
assert data["failed"] >= 0
@pytest.mark.integration
def test_cache_failed_endpoint(self):
"""Test GET /pypi/cache/failed returns list of failed tasks."""
with httpx.Client(base_url=get_base_url(), timeout=30.0) as client:
response = client.get("/pypi/cache/failed")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
# If there are failed tasks, verify structure
if data:
task = data[0]
assert "id" in task
assert "package" in task
assert "error" in task
assert "attempts" in task
assert "depth" in task
@pytest.mark.integration
def test_cache_failed_with_limit(self):
"""Test GET /pypi/cache/failed respects limit parameter."""
with httpx.Client(base_url=get_base_url(), timeout=30.0) as client:
response = client.get("/pypi/cache/failed?limit=5")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) <= 5
@pytest.mark.integration
def test_cache_retry_nonexistent_package(self):
"""Test POST /pypi/cache/retry/{package} returns 404 for unknown package."""
with httpx.Client(base_url=get_base_url(), timeout=30.0) as client:
# Use a random package name that definitely doesn't exist
response = client.post(f"/pypi/cache/retry/nonexistent-package-{uuid4().hex[:8]}")
assert response.status_code == 404
# Check for "no failed" or "not found" in error message
detail = response.json()["detail"].lower()
assert "no failed" in detail or "not found" in detail
@pytest.mark.integration
def test_cache_retry_all_endpoint(self):
"""Test POST /pypi/cache/retry-all returns success."""
with httpx.Client(base_url=get_base_url(), timeout=30.0) as client:
response = client.post("/pypi/cache/retry-all")
assert response.status_code == 200
data = response.json()
assert "count" in data
assert "message" in data
assert isinstance(data["count"], int)
assert data["count"] >= 0
class TestCacheTaskDeduplication:
"""Tests for cache task deduplication logic."""
def test_find_cached_package_returns_none_for_uncached(self):
"""Test that _find_cached_package returns None for uncached packages."""
# This is a unit test pattern - mock the database
from unittest.mock import MagicMock
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = None
from app.pypi_cache_worker import _find_cached_package
result = _find_cached_package(mock_db, "nonexistent-package")
assert result is None
class TestCacheWorkerConfiguration:
"""Tests for cache worker configuration."""
def test_config_settings_exist(self):
"""Test that PyPI cache config settings are available."""
from app.config import get_settings
settings = get_settings()
# Check that settings exist and have reasonable defaults
assert hasattr(settings, "pypi_cache_workers")
assert hasattr(settings, "pypi_cache_max_depth")
assert hasattr(settings, "pypi_cache_max_attempts")
# Check aliases work
assert settings.PYPI_CACHE_WORKERS == settings.pypi_cache_workers
assert settings.PYPI_CACHE_MAX_DEPTH == settings.pypi_cache_max_depth
assert settings.PYPI_CACHE_MAX_ATTEMPTS == settings.pypi_cache_max_attempts
def test_config_default_values(self):
"""Test that PyPI cache config has sensible defaults."""
from app.config import get_settings
settings = get_settings()
# These are the defaults from our implementation
assert settings.pypi_cache_workers == 5
assert settings.pypi_cache_max_depth == 10
assert settings.pypi_cache_max_attempts == 3
class TestFetchAndCachePackage:
"""Tests for _fetch_and_cache_package function."""
def test_result_structure_success(self):
"""Test that success result has correct structure."""
# Mock a successful result
result = {"success": True, "artifact_id": "abc123"}
assert result["success"] is True
assert "artifact_id" in result
def test_result_structure_failure(self):
"""Test that failure result has correct structure."""
# Mock a failure result
result = {"success": False, "error": "Package not found"}
assert result["success"] is False
assert "error" in result
class TestWorkerPoolLifecycle:
"""Tests for worker pool initialization and shutdown."""
def test_init_shutdown_cycle(self):
"""Test that worker pool can be initialized and shut down cleanly."""
from app.pypi_cache_worker import (
init_cache_worker_pool,
shutdown_cache_worker_pool,
_cache_worker_pool,
_cache_worker_running,
)
# Note: We can't fully test this in isolation because the module
# has global state and may conflict with the running server.
# These tests verify the function signatures work.
# The pool should be initialized by main.py on startup
# We just verify the functions are callable
assert callable(init_cache_worker_pool)
assert callable(shutdown_cache_worker_pool)
class TestNestedDependencyDepthTracking:
"""Tests for nested dependency depth tracking.
When the cache worker downloads a package, its dependencies should be
queued with depth = current_task_depth + 1, not depth = 0.
"""
def test_enqueue_with_depth_increments_for_nested_deps(self):
"""Test that enqueue_cache_task properly tracks depth for nested dependencies.
When a task at depth=2 discovers a new dependency, that dependency
should be queued at depth=3.
"""
from unittest.mock import MagicMock, patch
from app.pypi_cache_worker import enqueue_cache_task
mock_db = MagicMock()
# No existing task for this package
mock_db.query.return_value.filter.return_value.first.return_value = None
# Mock _find_cached_package to return None (not cached)
with patch('app.pypi_cache_worker._find_cached_package', return_value=None):
task = enqueue_cache_task(
mock_db,
package_name="nested-dep",
version_constraint=">=1.0",
parent_task_id=None,
depth=3, # Parent task was at depth 2, so this dep is at depth 3
triggered_by_artifact="abc123",
)
# Verify db.add was called
mock_db.add.assert_called_once()
# Get the task that was added
added_task = mock_db.add.call_args[0][0]
# The task should have the correct depth
assert added_task.depth == 3, f"Expected depth=3, got depth={added_task.depth}"
assert added_task.package_name == "nested-dep"
def test_proxy_download_accepts_cache_depth_param(self):
"""Test that proxy download endpoint accepts cache-depth query parameter.
The cache worker should pass its current depth via query param so the proxy
can queue dependencies at the correct depth.
"""
# Verify that pypi_download_file has a cache_depth parameter
import inspect
from app.pypi_proxy import pypi_download_file
sig = inspect.signature(pypi_download_file)
params = list(sig.parameters.keys())
# The endpoint should accept a cache_depth parameter
assert 'cache_depth' in params, \
f"pypi_download_file should accept cache_depth parameter. Got params: {params}"
def test_worker_sends_depth_in_url_when_fetching(self):
"""Test that _fetch_and_cache_package includes depth in download URL.
When the worker fetches a package, it should include its current depth
in the URL query params so nested dependencies get queued at depth+1.
"""
from unittest.mock import patch, MagicMock
import httpx
# We need to verify that the httpx.Client.get call includes the depth in URL
with patch('app.pypi_cache_worker.httpx.Client') as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_class.return_value.__exit__ = MagicMock(return_value=False)
# Mock successful responses
mock_response_index = MagicMock()
mock_response_index.status_code = 200
mock_response_index.text = '''
<html><body>
<a href="/pypi/simple/test-pkg/test_pkg-1.0.0-py3-none-any.whl?upstream=http%3A%2F%2Fexample.com">test_pkg-1.0.0-py3-none-any.whl</a>
</body></html>
'''
mock_response_download = MagicMock()
mock_response_download.status_code = 200
mock_response_download.headers = {"X-Checksum-SHA256": "abc123"}
mock_client.get.side_effect = [mock_response_index, mock_response_download]
from app.pypi_cache_worker import _fetch_and_cache_package_with_depth
# This function should exist and accept depth parameter
result = _fetch_and_cache_package_with_depth("test-pkg", None, depth=2)
# Verify the download request included the cache-depth query param
download_call = mock_client.get.call_args_list[1]
download_url = download_call[0][0] # First positional arg is URL
assert "cache-depth=2" in download_url, \
f"Expected cache-depth=2 in URL, got: {download_url}"