Add security fixes and code cleanup for PyPI cache

- Add require_admin authentication to cache management endpoints
- Add limit validation (1-500) on failed tasks query
- Add thread lock for worker pool thread safety
- Fix exception handling with separate recovery DB session
- Remove obsolete design doc
Mondo Diaz
2026-02-02 11:37:25 -06:00
parent ba708332a5
commit 97b39d000b
3 changed files with 58 additions and 288 deletions

View File

@@ -26,7 +26,8 @@ from .models import PyPICacheTask, Package, Project, Tag
 logger = logging.getLogger(__name__)
 
-# Module-level worker pool state
+# Module-level worker pool state with lock for thread safety
+_worker_lock = threading.Lock()
 _cache_worker_pool: Optional[ThreadPoolExecutor] = None
 _cache_worker_running: bool = False
 _dispatcher_thread: Optional[threading.Thread] = None
@@ -41,6 +42,7 @@ def init_cache_worker_pool(max_workers: Optional[int] = None):
     """
     global _cache_worker_pool, _cache_worker_running, _dispatcher_thread
 
+    with _worker_lock:
         if _cache_worker_pool is not None:
             logger.warning("Cache worker pool already initialized")
             return
@@ -73,17 +75,20 @@ def shutdown_cache_worker_pool(wait: bool = True, timeout: float = 30.0):
     """
     global _cache_worker_pool, _cache_worker_running, _dispatcher_thread
 
+    with _worker_lock:
         if _cache_worker_pool is None:
             return
 
         logger.info("Shutting down PyPI cache worker pool...")
         _cache_worker_running = False
 
-    # Wait for dispatcher to stop
+    # Wait for dispatcher to stop (outside lock to avoid deadlock)
     if _dispatcher_thread and _dispatcher_thread.is_alive():
         _dispatcher_thread.join(timeout=5.0)
 
+    with _worker_lock:
         # Shutdown thread pool
+        if _cache_worker_pool:
             _cache_worker_pool.shutdown(wait=wait, cancel_futures=not wait)
         _cache_worker_pool = None
         _dispatcher_thread = None
@@ -198,13 +203,14 @@ def _process_cache_task(task_id: UUID):
     except Exception as e:
         logger.exception(f"Error processing cache task {task_id}")
-        db = SessionLocal()  # Get fresh session after exception
+        # Use a fresh session for error handling to avoid transaction issues
+        recovery_db = SessionLocal()
         try:
-            task = db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
+            task = recovery_db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
             if task:
-                _handle_task_failure(db, task, str(e))
+                _handle_task_failure(recovery_db, task, str(e))
         finally:
-            db.close()
+            recovery_db.close()
     finally:
         db.close()
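For context on the last hunk: after an error, the original SQLAlchemy session may be left in a failed-transaction state and refuse further queries until it is rolled back, so the failure is now recorded on a brand-new session. A minimal sketch of the pattern, assuming the module's `SessionLocal` factory; `do_cache_work` and `record_failure` are hypothetical stand-ins (the latter for `_handle_task_failure`):

```python
# Sketch only: shows why a separate recovery session is used, not the literal file contents.
db = SessionLocal()
try:
    do_cache_work(db)                      # hypothetical work that raises mid-transaction
except Exception as exc:
    recovery_db = SessionLocal()           # fresh session with a clean transaction state
    try:
        record_failure(recovery_db, exc)   # stand-in for _handle_task_failure
        recovery_db.commit()
    finally:
        recovery_db.close()
finally:
    db.close()
```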

View File

@@ -15,12 +15,13 @@ from typing import Optional, List, Tuple
 from urllib.parse import urljoin, urlparse, quote, unquote
 
 import httpx
-from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, Response
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request, Response
 from fastapi.responses import StreamingResponse, HTMLResponse
 from sqlalchemy.orm import Session
 
+from .auth import require_admin
 from .database import get_db
-from .models import UpstreamSource, CachedUrl, Artifact, Project, Package, Tag, PackageVersion, ArtifactDependency
+from .models import User, UpstreamSource, CachedUrl, Artifact, Project, Package, Tag, PackageVersion, ArtifactDependency
 from .storage import S3Storage, get_storage
 from .config import get_env_upstream_sources
 from .pypi_cache_worker import (
@@ -814,25 +815,32 @@ async def pypi_download_file(
 @router.get("/cache/status")
-async def pypi_cache_status(db: Session = Depends(get_db)):
+async def pypi_cache_status(
+    db: Session = Depends(get_db),
+    _current_user: User = Depends(require_admin),
+):
     """
     Get summary of the PyPI cache task queue.
 
     Returns counts of tasks by status (pending, in_progress, completed, failed).
+    Requires admin privileges.
     """
     return get_cache_status(db)
 
 
 @router.get("/cache/failed")
 async def pypi_cache_failed(
-    limit: int = 50,
+    limit: int = Query(default=50, ge=1, le=500),
     db: Session = Depends(get_db),
+    _current_user: User = Depends(require_admin),
 ):
     """
     Get list of failed cache tasks for debugging.
 
     Args:
-        limit: Maximum number of tasks to return (default 50).
+        limit: Maximum number of tasks to return (default 50, max 500).
+    Requires admin privileges.
     """
     return get_failed_tasks(db, limit=limit)
@@ -841,12 +849,15 @@ async def pypi_cache_failed(
 async def pypi_cache_retry(
     package_name: str,
     db: Session = Depends(get_db),
+    _current_user: User = Depends(require_admin),
 ):
     """
     Reset a failed cache task to retry.
 
     Args:
         package_name: The package name to retry.
+    Requires admin privileges.
     """
     task = retry_failed_task(db, package_name)
     if not task:
@@ -858,11 +869,15 @@ async def pypi_cache_retry(
 @router.post("/cache/retry-all")
-async def pypi_cache_retry_all(db: Session = Depends(get_db)):
+async def pypi_cache_retry_all(
+    db: Session = Depends(get_db),
+    _current_user: User = Depends(require_admin),
+):
     """
     Reset all failed cache tasks to retry.
 
     Returns the count of tasks that were reset.
+    Requires admin privileges.
     """
     count = retry_all_failed_tasks(db)
     return {"message": f"Queued {count} tasks for retry", "count": count}
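Two behavioral notes on the endpoints above. With `limit: int = Query(default=50, ge=1, le=500)`, FastAPI rejects out-of-range values with a 422 before the handler runs, so `get_failed_tasks` never receives an unbounded limit. The `require_admin` dependency is imported from `.auth` and is not part of this diff; a dependency of roughly the following shape would produce the documented behavior (hypothetical sketch, including the assumed `get_current_user` dependency and `is_admin` flag):

```python
# Hypothetical sketch of an admin-gating dependency; the real require_admin in .auth is not
# shown in this commit and may use roles, scopes, or a different user-model attribute.
from fastapi import Depends, HTTPException, status

def require_admin(current_user: User = Depends(get_current_user)) -> User:
    # Assumes a get_current_user dependency and an is_admin flag on the User model.
    if not getattr(current_user, "is_admin", False):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")
    return current_user
```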

View File

@@ -1,251 +0,0 @@
# PyPI Cache Robustness Design
**Date:** 2026-02-02
**Status:** Approved
**Branch:** fix/pypi-proxy-timeout
## Problem
The current PyPI proxy proactive caching has reliability issues:
- Unbounded thread spawning for each dependency
- Silent failures (logged but not tracked or retried)
- No visibility into cache completeness
- Deps-of-deps often missing due to untracked failures
## Solution
Database-backed task queue with managed worker pool, automatic retries, and visibility API.
---
## Data Model
New table `pypi_cache_tasks`:
```sql
CREATE TABLE pypi_cache_tasks (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
-- What to cache
package_name VARCHAR(255) NOT NULL,
version_constraint VARCHAR(255),
-- Origin tracking
parent_task_id UUID REFERENCES pypi_cache_tasks(id) ON DELETE SET NULL,
depth INTEGER NOT NULL DEFAULT 0,
triggered_by_artifact VARCHAR(64),
-- Status
status VARCHAR(20) NOT NULL DEFAULT 'pending',
attempts INTEGER NOT NULL DEFAULT 0,
max_attempts INTEGER NOT NULL DEFAULT 3,
-- Results
cached_artifact_id VARCHAR(64),
error_message TEXT,
-- Timing
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
started_at TIMESTAMP WITH TIME ZONE,
completed_at TIMESTAMP WITH TIME ZONE,
next_retry_at TIMESTAMP WITH TIME ZONE
);
-- Indexes
CREATE INDEX idx_pypi_cache_tasks_status_retry ON pypi_cache_tasks(status, next_retry_at);
CREATE INDEX idx_pypi_cache_tasks_package_status ON pypi_cache_tasks(package_name, status);
CREATE INDEX idx_pypi_cache_tasks_parent ON pypi_cache_tasks(parent_task_id);
-- Constraints
ALTER TABLE pypi_cache_tasks ADD CONSTRAINT check_task_status
CHECK (status IN ('pending', 'in_progress', 'completed', 'failed'));
```
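The worker imports `PyPICacheTask` from `.models`; a minimal sketch of an ORM mapping matching the DDL above (the actual model in `backend/app/models.py` may differ in base class, server defaults, and helpers):

```python
# Sketch of the ORM mapping for the table above; not necessarily the real PyPICacheTask model.
import uuid
from datetime import datetime
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class PyPICacheTask(Base):
    __tablename__ = "pypi_cache_tasks"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    package_name = Column(String(255), nullable=False)
    version_constraint = Column(String(255))
    parent_task_id = Column(UUID(as_uuid=True), ForeignKey("pypi_cache_tasks.id", ondelete="SET NULL"))
    depth = Column(Integer, nullable=False, default=0)
    triggered_by_artifact = Column(String(64))
    status = Column(String(20), nullable=False, default="pending")
    attempts = Column(Integer, nullable=False, default=0)
    max_attempts = Column(Integer, nullable=False, default=3)
    cached_artifact_id = Column(String(64))
    error_message = Column(Text)
    created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
    started_at = Column(DateTime(timezone=True))
    completed_at = Column(DateTime(timezone=True))
    next_retry_at = Column(DateTime(timezone=True))
```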
---
## Worker Architecture
### Thread Pool (5 workers default)
```python
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

_cache_worker_pool: Optional[ThreadPoolExecutor] = None
_cache_worker_running: bool = False
def init_cache_worker_pool(max_workers: int = 5):
global _cache_worker_pool, _cache_worker_running
_cache_worker_pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="pypi-cache-")
_cache_worker_running = True
threading.Thread(target=_cache_dispatcher_loop, daemon=True).start()
```
### Dispatcher Loop
- Polls DB every 2 seconds when idle
- Fetches batch of 10 ready tasks
- Marks tasks in_progress before submitting to pool
- Orders by depth (shallow first) then FIFO
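A minimal sketch of that loop (`_fetch_ready_tasks` and `_mark_in_progress` are hypothetical helper names; the real dispatcher may be structured differently):

```python
# Sketch of the dispatcher loop described above; helper names are hypothetical.
import time

POLL_INTERVAL_SECONDS = 2
BATCH_SIZE = 10

def _cache_dispatcher_loop():
    while _cache_worker_running:
        tasks = []
        db = SessionLocal()
        try:
            # Ready = status 'pending' and next_retry_at NULL or in the past,
            # ordered by depth ASC (shallow first) then created_at ASC, LIMIT 10.
            tasks = _fetch_ready_tasks(db, limit=BATCH_SIZE)
            for task in tasks:
                _mark_in_progress(db, task)                       # flip status before handing off
                _cache_worker_pool.submit(_process_cache_task, task.id)
        finally:
            db.close()
        if not tasks:
            time.sleep(POLL_INTERVAL_SECONDS)                     # poll every 2 seconds when idle
```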
### Task Processing
1. Dedup check - skip if package already cached
2. Dedup check - skip if pending/in_progress task exists for same package
3. Depth check - fail if >= 10 levels deep
4. Fetch package index page
5. Download best matching file (prefer wheels)
6. Store artifact, extract dependencies
7. Queue child tasks for each dependency
8. Mark completed or handle failure
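A compact skeleton of these steps (`_has_active_duplicate`, `_fetch_and_store`, and `_handle_task_failure_with_fresh_session` are hypothetical stand-ins; the other names appear elsewhere in this design):

```python
# Skeleton of the processing steps listed above; the real worker may organize this differently.
def _process_cache_task(task_id):
    db = SessionLocal()
    try:
        task = db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
        if task is None:
            return

        # 1-2. Dedup: already cached, or another pending/in_progress task covers this package
        if _find_cached_package(db, task.package_name) or _has_active_duplicate(db, task):
            _mark_task_completed(db, task)
            return

        # 3. Depth guard
        if task.depth >= 10:
            _handle_task_failure(db, task, "max dependency depth exceeded")
            return

        # 4-6. Fetch index page, download best match (prefer wheels), store artifact + deps
        artifact_id, deps = _fetch_and_store(db, task)

        # 7. Queue child tasks for each extracted dependency
        for dep_name, dep_constraint in deps:
            _enqueue_cache_task(db, package_name=dep_name, version_constraint=dep_constraint,
                                parent_task_id=task.id, depth=task.depth + 1)

        # 8. Mark completed
        _mark_task_completed(db, task, cached_artifact_id=artifact_id)
    except Exception as e:
        # Failure path uses a fresh session; see the recovery-session hunk earlier in this commit
        _handle_task_failure_with_fresh_session(task_id, str(e))
    finally:
        db.close()
```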
---
## Retry Logic
Exponential backoff with 3 attempts:
| Attempt | Outcome |
|---------|---------|
| 1 fails | 30 seconds |
| 2 fails | 60 seconds |
| 3 fails | Permanent failure |
```python
backoff_seconds = 30 * (2 ** (attempts - 1))
task.next_retry_at = datetime.utcnow() + timedelta(seconds=backoff_seconds)
```
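Putting the table and formula together, a sketch of the failure handler (column names follow the schema above; the real `_handle_task_failure` may differ in details):

```python
# Sketch of failure handling consistent with the backoff table above.
from datetime import datetime, timedelta

def _handle_task_failure(db, task, error_message: str):
    task.attempts += 1
    task.error_message = error_message
    if task.attempts >= task.max_attempts:
        task.status = "failed"                               # permanent failure after 3 attempts
        task.completed_at = datetime.utcnow()
    else:
        backoff_seconds = 30 * (2 ** (task.attempts - 1))    # 30s after 1st failure, 60s after 2nd
        task.status = "pending"
        task.next_retry_at = datetime.utcnow() + timedelta(seconds=backoff_seconds)
    db.commit()
```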
---
## API Endpoints
| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/pypi/cache/status` | GET | Queue health summary |
| `/pypi/cache/failed` | GET | List failed tasks with errors |
| `/pypi/cache/retry/{package}` | POST | Retry single failed package |
| `/pypi/cache/retry-all` | POST | Retry all failed packages |
### Response Examples
**GET /pypi/cache/status**
```json
{
"pending": 12,
"in_progress": 3,
"completed": 847,
"failed": 5
}
```
**GET /pypi/cache/failed**
```json
[
{
"package": "some-obscure-pkg",
"error": "Timeout connecting to upstream",
"attempts": 3,
"failed_at": "2026-02-02T10:30:00Z"
}
]
```
---
## Integration Points
### Replace Thread Spawning (pypi_proxy.py)
```python
# OLD: _start_background_dependency_caching(base_url, unique_deps)
# NEW:
for dep_name, dep_version in unique_deps:
_enqueue_cache_task(
db,
package_name=dep_name,
version_constraint=dep_version,
parent_task_id=None,
depth=0,
triggered_by_artifact=sha256,
)
```
### App Startup (main.py)
```python
@app.on_event("startup")
async def startup():
init_cache_worker_pool(max_workers=settings.PYPI_CACHE_WORKERS)
@app.on_event("shutdown")
async def shutdown():
shutdown_cache_worker_pool()
```
### Configuration (config.py)
```python
PYPI_CACHE_WORKERS = int(os.getenv("ORCHARD_PYPI_CACHE_WORKERS", "5"))
PYPI_CACHE_MAX_DEPTH = int(os.getenv("ORCHARD_PYPI_CACHE_MAX_DEPTH", "10"))
PYPI_CACHE_MAX_ATTEMPTS = int(os.getenv("ORCHARD_PYPI_CACHE_MAX_ATTEMPTS", "3"))
```
---
## Files to Create/Modify
| File | Action |
|------|--------|
| `migrations/0XX_pypi_cache_tasks.sql` | Create - new table |
| `backend/app/models.py` | Modify - add PyPICacheTask model |
| `backend/app/pypi_cache_worker.py` | Create - worker pool + processing |
| `backend/app/pypi_proxy.py` | Modify - replace threads, add API |
| `backend/app/main.py` | Modify - init worker on startup |
| `backend/app/config.py` | Modify - add config variables |
| `backend/tests/test_pypi_cache_worker.py` | Create - unit tests |
| `backend/tests/integration/test_pypi_cache_api.py` | Create - API tests |
---
## Deduplication Strategy
### At Task Creation Time
```python
def _enqueue_cache_task(db, package_name, ...):
# Check for existing pending/in_progress task
existing_task = db.query(PyPICacheTask).filter(
PyPICacheTask.package_name == package_name,
PyPICacheTask.status.in_(["pending", "in_progress"])
).first()
if existing_task:
return existing_task
# Check if already cached
if _find_cached_package(db, package_name):
return None
# Create new task
...
```
### At Processing Time (safety check)
```python
def _process_cache_task(task_id):
# Double-check in case of race
if _find_cached_package(db, task.package_name):
_mark_task_completed(db, task, cached_artifact_id=existing.artifact_id)
return
```
---
## Success Criteria
- [ ] No unbounded thread creation
- [ ] All dependency caching attempts tracked in database
- [ ] Failed tasks automatically retry with backoff
- [ ] API provides visibility into queue status
- [ ] Manual retry capability for failed packages
- [ ] Existing pip install workflow unchanged (transparent)
- [ ] Tests cover worker, retry, and API functionality