Add security fixes and code cleanup for PyPI cache

- Add require_admin authentication to cache management endpoints
- Add limit validation (1-500) on failed tasks query
- Add thread lock for worker pool thread safety
- Fix exception handling with separate recovery DB session
- Remove obsolete design doc
Mondo Diaz
2026-02-02 11:37:25 -06:00
parent ba708332a5
commit 97b39d000b
3 changed files with 58 additions and 288 deletions
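
The core of the thread-safety change is the standard lock-guarded lazy-initialization pattern. As a minimal standalone sketch of the idea (names simplified from the actual module, which also manages a dispatcher thread):

import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

_lock = threading.Lock()
_pool: Optional[ThreadPoolExecutor] = None

def init_pool(workers: int = 4) -> None:
    """Idempotent, thread-safe initialization."""
    global _pool
    with _lock:
        # The check and the create happen under one lock, so two
        # threads racing here can no longer both build a pool.
        if _pool is not None:
            return
        _pool = ThreadPoolExecutor(max_workers=workers)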

View File

@@ -26,7 +26,8 @@ from .models import PyPICacheTask, Package, Project, Tag

 logger = logging.getLogger(__name__)

-# Module-level worker pool state
+# Module-level worker pool state with lock for thread safety
+_worker_lock = threading.Lock()
 _cache_worker_pool: Optional[ThreadPoolExecutor] = None
 _cache_worker_running: bool = False
 _dispatcher_thread: Optional[threading.Thread] = None
@@ -41,26 +42,27 @@ def init_cache_worker_pool(max_workers: Optional[int] = None):
     """
     global _cache_worker_pool, _cache_worker_running, _dispatcher_thread

-    if _cache_worker_pool is not None:
-        logger.warning("Cache worker pool already initialized")
-        return
+    with _worker_lock:
+        if _cache_worker_pool is not None:
+            logger.warning("Cache worker pool already initialized")
+            return

-    workers = max_workers or settings.PYPI_CACHE_WORKERS
-    _cache_worker_pool = ThreadPoolExecutor(
-        max_workers=workers,
-        thread_name_prefix="pypi-cache-",
-    )
-    _cache_worker_running = True
+        workers = max_workers or settings.PYPI_CACHE_WORKERS
+        _cache_worker_pool = ThreadPoolExecutor(
+            max_workers=workers,
+            thread_name_prefix="pypi-cache-",
+        )
+        _cache_worker_running = True

-    # Start the dispatcher thread
-    _dispatcher_thread = threading.Thread(
-        target=_cache_dispatcher_loop,
-        daemon=True,
-        name="pypi-cache-dispatcher",
-    )
-    _dispatcher_thread.start()
+        # Start the dispatcher thread
+        _dispatcher_thread = threading.Thread(
+            target=_cache_dispatcher_loop,
+            daemon=True,
+            name="pypi-cache-dispatcher",
+        )
+        _dispatcher_thread.start()

-    logger.info(f"PyPI cache worker pool initialized with {workers} workers")
+        logger.info(f"PyPI cache worker pool initialized with {workers} workers")


 def shutdown_cache_worker_pool(wait: bool = True, timeout: float = 30.0):
@@ -73,20 +75,23 @@ def shutdown_cache_worker_pool(wait: bool = True, timeout: float = 30.0):
     """
     global _cache_worker_pool, _cache_worker_running, _dispatcher_thread

-    if _cache_worker_pool is None:
-        return
+    with _worker_lock:
+        if _cache_worker_pool is None:
+            return

-    logger.info("Shutting down PyPI cache worker pool...")
-    _cache_worker_running = False
+        logger.info("Shutting down PyPI cache worker pool...")
+        _cache_worker_running = False

-    # Wait for dispatcher to stop
+    # Wait for dispatcher to stop (outside lock to avoid deadlock)
     if _dispatcher_thread and _dispatcher_thread.is_alive():
         _dispatcher_thread.join(timeout=5.0)

-    # Shutdown thread pool
-    _cache_worker_pool.shutdown(wait=wait, cancel_futures=not wait)
-    _cache_worker_pool = None
-    _dispatcher_thread = None
+    with _worker_lock:
+        # Shutdown thread pool
+        if _cache_worker_pool:
+            _cache_worker_pool.shutdown(wait=wait, cancel_futures=not wait)
+        _cache_worker_pool = None
+        _dispatcher_thread = None

     logger.info("PyPI cache worker pool shut down")
@@ -198,13 +203,14 @@ def _process_cache_task(task_id: UUID):
     except Exception as e:
         logger.exception(f"Error processing cache task {task_id}")
-        db = SessionLocal()  # Get fresh session after exception
+        # Use a fresh session for error handling to avoid transaction issues
+        recovery_db = SessionLocal()
         try:
-            task = db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
+            task = recovery_db.query(PyPICacheTask).filter(PyPICacheTask.id == task_id).first()
             if task:
-                _handle_task_failure(db, task, str(e))
+                _handle_task_failure(recovery_db, task, str(e))
         finally:
-            db.close()
+            recovery_db.close()
     finally:
         db.close()
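
This fix matters because after an exception the original session's transaction may be in a failed state, and the old code's rebinding of db also meant the outer finally closed the recovery session while leaking the first one. A condensed sketch of the corrected pattern, reusing the module's SessionLocal, PyPICacheTask, and _handle_task_failure:

from contextlib import closing

def record_failure(task_id, message: str) -> None:
    # Never reuse the session that raised; open a short-lived
    # recovery session and guarantee it is closed.
    with closing(SessionLocal()) as recovery_db:
        task = (
            recovery_db.query(PyPICacheTask)
            .filter(PyPICacheTask.id == task_id)
            .first()
        )
        if task:
            _handle_task_failure(recovery_db, task, message)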

View File

@@ -15,12 +15,13 @@ from typing import Optional, List, Tuple
 from urllib.parse import urljoin, urlparse, quote, unquote

 import httpx
-from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, Response
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request, Response
 from fastapi.responses import StreamingResponse, HTMLResponse
 from sqlalchemy.orm import Session

+from .auth import require_admin
 from .database import get_db
-from .models import UpstreamSource, CachedUrl, Artifact, Project, Package, Tag, PackageVersion, ArtifactDependency
+from .models import User, UpstreamSource, CachedUrl, Artifact, Project, Package, Tag, PackageVersion, ArtifactDependency
 from .storage import S3Storage, get_storage
 from .config import get_env_upstream_sources
 from .pypi_cache_worker import (
@@ -814,25 +815,32 @@ async def pypi_download_file(

 @router.get("/cache/status")
-async def pypi_cache_status(db: Session = Depends(get_db)):
+async def pypi_cache_status(
+    db: Session = Depends(get_db),
+    _current_user: User = Depends(require_admin),
+):
     """
     Get summary of the PyPI cache task queue.

     Returns counts of tasks by status (pending, in_progress, completed, failed).
+
+    Requires admin privileges.
     """
     return get_cache_status(db)


 @router.get("/cache/failed")
 async def pypi_cache_failed(
-    limit: int = 50,
+    limit: int = Query(default=50, ge=1, le=500),
     db: Session = Depends(get_db),
+    _current_user: User = Depends(require_admin),
 ):
     """
     Get list of failed cache tasks for debugging.

     Args:
-        limit: Maximum number of tasks to return (default 50).
+        limit: Maximum number of tasks to return (default 50, max 500).
+
+    Requires admin privileges.
     """
     return get_failed_tasks(db, limit=limit)
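
Declaring the bounds with Query moves the validation into the framework: out-of-range values are rejected with a 422 response before the handler body runs. A minimal self-contained example of the same declaration:

from fastapi import FastAPI, Query

app = FastAPI()

@app.get("/items")
async def list_items(limit: int = Query(default=50, ge=1, le=500)):
    # GET /items?limit=0 or ?limit=9999 returns 422 Unprocessable
    # Entity; this body only ever sees 1 <= limit <= 500.
    return {"limit": limit}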
@@ -841,12 +849,15 @@ async def pypi_cache_failed(
 async def pypi_cache_retry(
     package_name: str,
     db: Session = Depends(get_db),
+    _current_user: User = Depends(require_admin),
 ):
     """
     Reset a failed cache task to retry.

     Args:
         package_name: The package name to retry.
+
+    Requires admin privileges.
     """
     task = retry_failed_task(db, package_name)
     if not task:
@@ -858,11 +869,15 @@ async def pypi_cache_retry(

 @router.post("/cache/retry-all")
-async def pypi_cache_retry_all(db: Session = Depends(get_db)):
+async def pypi_cache_retry_all(
+    db: Session = Depends(get_db),
+    _current_user: User = Depends(require_admin),
+):
     """
     Reset all failed cache tasks to retry.

     Returns the count of tasks that were reset.
+
+    Requires admin privileges.
     """
     count = retry_all_failed_tasks(db)
     return {"message": f"Queued {count} tasks for retry", "count": count}