Add upstream caching infrastructure and refactor CI pipeline

This commit is contained in:
Mondo Diaz
2026-01-29 11:55:15 -06:00
parent c92895ffe9
commit 1d51c856b0
24 changed files with 7285 additions and 117 deletions

586
backend/app/upstream.py Normal file
View File

@@ -0,0 +1,586 @@
"""
HTTP client for fetching artifacts from upstream sources.
Provides streaming downloads with SHA256 computation, authentication support,
and automatic source matching based on URL prefixes.
"""
from __future__ import annotations
import hashlib
import logging
import tempfile
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import BinaryIO, Optional, TYPE_CHECKING
from urllib.parse import urlparse
import httpx
if TYPE_CHECKING:
from .models import CacheSettings, UpstreamSource
logger = logging.getLogger(__name__)
class UpstreamError(Exception):
"""Base exception for upstream client errors."""
pass
class UpstreamConnectionError(UpstreamError):
"""Connection to upstream failed (network error, DNS, etc.)."""
pass
class UpstreamTimeoutError(UpstreamError):
"""Request to upstream timed out."""
pass
class UpstreamHTTPError(UpstreamError):
"""Upstream returned an HTTP error response."""
def __init__(self, message: str, status_code: int, response_headers: dict = None):
super().__init__(message)
self.status_code = status_code
self.response_headers = response_headers or {}
class UpstreamSSLError(UpstreamError):
"""SSL/TLS error when connecting to upstream."""
pass
class AirGapError(UpstreamError):
"""Request blocked due to air-gap mode."""
pass
class FileSizeExceededError(UpstreamError):
"""File size exceeds the maximum allowed."""
def __init__(self, message: str, content_length: int, max_size: int):
super().__init__(message)
self.content_length = content_length
self.max_size = max_size
class SourceNotFoundError(UpstreamError):
"""No matching upstream source found for URL."""
pass
class SourceDisabledError(UpstreamError):
"""The matching upstream source is disabled."""
pass
@dataclass
class FetchResult:
"""Result of fetching an artifact from upstream."""
content: BinaryIO # File-like object with content
sha256: str # SHA256 hash of content
size: int # Size in bytes
content_type: Optional[str] # Content-Type header
response_headers: dict # All response headers for provenance
source_name: Optional[str] = None # Name of matched upstream source
temp_path: Optional[Path] = None # Path to temp file (for cleanup)
def close(self):
"""Close and clean up resources."""
if self.content:
try:
self.content.close()
except Exception:
pass
if self.temp_path and self.temp_path.exists():
try:
self.temp_path.unlink()
except Exception:
pass
@dataclass
class UpstreamClientConfig:
"""Configuration for the upstream client."""
connect_timeout: float = 30.0 # Connection timeout in seconds
read_timeout: float = 300.0 # Read timeout in seconds (5 minutes for large files)
max_retries: int = 3 # Maximum number of retry attempts
retry_backoff_base: float = 1.0 # Base delay for exponential backoff
retry_backoff_max: float = 30.0 # Maximum delay between retries
follow_redirects: bool = True # Whether to follow redirects
max_redirects: int = 5 # Maximum number of redirects to follow
max_file_size: Optional[int] = None # Maximum file size (None = unlimited)
verify_ssl: bool = True # Verify SSL certificates
user_agent: str = "Orchard-UpstreamClient/1.0"
class UpstreamClient:
"""
HTTP client for fetching artifacts from upstream sources.
Supports streaming downloads, multiple authentication methods,
automatic source matching, and air-gap mode enforcement.
"""
def __init__(
self,
sources: list[UpstreamSource] = None,
cache_settings: CacheSettings = None,
config: UpstreamClientConfig = None,
):
"""
Initialize the upstream client.
Args:
sources: List of upstream sources for URL matching and auth.
Should be sorted by priority (lowest first).
cache_settings: Global cache settings including air-gap mode.
config: Client configuration options.
"""
self.sources = sources or []
self.cache_settings = cache_settings
self.config = config or UpstreamClientConfig()
# Sort sources by priority (lower = higher priority)
self.sources = sorted(self.sources, key=lambda s: s.priority)
def _get_allow_public_internet(self) -> bool:
"""Get the allow_public_internet setting."""
if self.cache_settings is None:
return True # Default to allowing if no settings provided
return self.cache_settings.allow_public_internet
def _match_source(self, url: str) -> Optional[UpstreamSource]:
"""
Find the upstream source that matches the given URL.
Matches by URL prefix, returns the highest priority match.
Args:
url: The URL to match.
Returns:
The matching UpstreamSource or None if no match.
"""
for source in self.sources:
# Check if URL starts with source URL (prefix match)
if url.startswith(source.url.rstrip("/")):
return source
return None
def _build_auth_headers(self, source: UpstreamSource) -> dict:
"""
Build authentication headers for the given source.
Args:
source: The upstream source with auth configuration.
Returns:
Dictionary of headers to add to the request.
"""
headers = {}
if source.auth_type == "none":
pass
elif source.auth_type == "basic":
# httpx handles basic auth via auth parameter, but we can also
# do it manually if needed. We'll use the auth parameter instead.
pass
elif source.auth_type == "bearer":
password = source.get_password()
if password:
headers["Authorization"] = f"Bearer {password}"
elif source.auth_type == "api_key":
# API key auth uses custom headers
custom_headers = source.get_headers()
if custom_headers:
headers.update(custom_headers)
return headers
def _get_basic_auth(self, source: UpstreamSource) -> Optional[tuple[str, str]]:
"""
Get basic auth credentials if applicable.
Args:
source: The upstream source.
Returns:
Tuple of (username, password) or None.
"""
if source.auth_type == "basic" and source.username:
password = source.get_password() or ""
return (source.username, password)
return None
def _should_retry(self, error: Exception, attempt: int) -> bool:
"""
Determine if a request should be retried.
Args:
error: The exception that occurred.
attempt: Current attempt number (0-indexed).
Returns:
True if the request should be retried.
"""
if attempt >= self.config.max_retries - 1:
return False
# Retry on connection errors and timeouts
if isinstance(error, (httpx.ConnectError, httpx.ConnectTimeout)):
return True
# Retry on read timeouts
if isinstance(error, httpx.ReadTimeout):
return True
# Retry on certain HTTP errors (502, 503, 504)
if isinstance(error, httpx.HTTPStatusError):
return error.response.status_code in (502, 503, 504)
return False
def _calculate_backoff(self, attempt: int) -> float:
"""
Calculate backoff delay for retry.
Uses exponential backoff with jitter.
Args:
attempt: Current attempt number (0-indexed).
Returns:
Delay in seconds.
"""
import random
delay = self.config.retry_backoff_base * (2**attempt)
# Add jitter (±25%)
delay *= 0.75 + random.random() * 0.5
return min(delay, self.config.retry_backoff_max)
def fetch(self, url: str, expected_hash: Optional[str] = None) -> FetchResult:
"""
Fetch an artifact from the given URL.
Streams the response to a temp file while computing the SHA256 hash.
Handles authentication, retries, and error cases.
Args:
url: The URL to fetch.
expected_hash: Optional expected SHA256 hash for verification.
Returns:
FetchResult with content, hash, size, and headers.
Raises:
AirGapError: If air-gap mode blocks the request.
SourceDisabledError: If the matching source is disabled.
UpstreamConnectionError: On connection failures.
UpstreamTimeoutError: On timeout.
UpstreamHTTPError: On HTTP error responses.
UpstreamSSLError: On SSL/TLS errors.
FileSizeExceededError: If Content-Length exceeds max_file_size.
"""
start_time = time.time()
# Match URL to source
source = self._match_source(url)
# Check air-gap mode
allow_public = self._get_allow_public_internet()
if not allow_public:
if source is None:
raise AirGapError(
f"Air-gap mode enabled: URL does not match any configured upstream source: {url}"
)
if source.is_public:
raise AirGapError(
f"Air-gap mode enabled: Cannot fetch from public source '{source.name}'"
)
# Check if source is enabled (if we have a match)
if source is not None and not source.enabled:
raise SourceDisabledError(
f"Upstream source '{source.name}' is disabled"
)
source_name = source.name if source else None
logger.info(
f"Fetching URL: {url} (source: {source_name or 'none'})"
)
# Build request parameters
headers = {"User-Agent": self.config.user_agent}
auth = None
if source:
headers.update(self._build_auth_headers(source))
auth = self._get_basic_auth(source)
timeout = httpx.Timeout(
connect=self.config.connect_timeout,
read=self.config.read_timeout,
write=30.0,
pool=10.0,
)
# Attempt fetch with retries
last_error = None
for attempt in range(self.config.max_retries):
try:
return self._do_fetch(
url=url,
headers=headers,
auth=auth,
timeout=timeout,
source_name=source_name,
start_time=start_time,
expected_hash=expected_hash,
)
except (
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadTimeout,
httpx.HTTPStatusError,
) as e:
last_error = e
if self._should_retry(e, attempt):
delay = self._calculate_backoff(attempt)
logger.warning(
f"Fetch failed (attempt {attempt + 1}/{self.config.max_retries}), "
f"retrying in {delay:.1f}s: {e}"
)
time.sleep(delay)
else:
break
# Convert final error to our exception types
self._raise_upstream_error(last_error, url)
def _do_fetch(
self,
url: str,
headers: dict,
auth: Optional[tuple[str, str]],
timeout: httpx.Timeout,
source_name: Optional[str],
start_time: float,
expected_hash: Optional[str] = None,
) -> FetchResult:
"""
Perform the actual fetch operation.
Args:
url: URL to fetch.
headers: Request headers.
auth: Basic auth credentials or None.
timeout: Request timeout configuration.
source_name: Name of matched source for logging.
start_time: Request start time for timing.
expected_hash: Optional expected hash for verification.
Returns:
FetchResult with content and metadata.
"""
with httpx.Client(
timeout=timeout,
follow_redirects=self.config.follow_redirects,
max_redirects=self.config.max_redirects,
verify=self.config.verify_ssl,
) as client:
with client.stream("GET", url, headers=headers, auth=auth) as response:
# Check for HTTP errors
response.raise_for_status()
# Check Content-Length against max size
content_length = response.headers.get("content-length")
if content_length:
content_length = int(content_length)
if (
self.config.max_file_size
and content_length > self.config.max_file_size
):
raise FileSizeExceededError(
f"File size {content_length} exceeds maximum {self.config.max_file_size}",
content_length,
self.config.max_file_size,
)
# Stream to temp file while computing hash
hasher = hashlib.sha256()
size = 0
# Create temp file
temp_file = tempfile.NamedTemporaryFile(
delete=False, prefix="orchard_upstream_"
)
temp_path = Path(temp_file.name)
try:
for chunk in response.iter_bytes(chunk_size=65536):
temp_file.write(chunk)
hasher.update(chunk)
size += len(chunk)
# Check size while streaming if max_file_size is set
if self.config.max_file_size and size > self.config.max_file_size:
temp_file.close()
temp_path.unlink()
raise FileSizeExceededError(
f"Downloaded size {size} exceeds maximum {self.config.max_file_size}",
size,
self.config.max_file_size,
)
temp_file.close()
sha256 = hasher.hexdigest()
# Verify hash if expected
if expected_hash and sha256 != expected_hash.lower():
temp_path.unlink()
raise UpstreamError(
f"Hash mismatch: expected {expected_hash}, got {sha256}"
)
# Capture response headers
response_headers = dict(response.headers)
# Get content type
content_type = response.headers.get("content-type")
elapsed = time.time() - start_time
logger.info(
f"Fetched {url}: {size} bytes, sha256={sha256[:12]}..., "
f"source={source_name}, time={elapsed:.2f}s"
)
# Return file handle positioned at start
content = open(temp_path, "rb")
return FetchResult(
content=content,
sha256=sha256,
size=size,
content_type=content_type,
response_headers=response_headers,
source_name=source_name,
temp_path=temp_path,
)
except Exception:
# Clean up on error
try:
temp_file.close()
except Exception:
pass
if temp_path.exists():
temp_path.unlink()
raise
def _raise_upstream_error(self, error: Exception, url: str):
"""
Convert httpx exception to appropriate UpstreamError.
Args:
error: The httpx exception.
url: The URL that was being fetched.
Raises:
Appropriate UpstreamError subclass.
"""
if error is None:
raise UpstreamError(f"Unknown error fetching {url}")
if isinstance(error, httpx.ConnectError):
raise UpstreamConnectionError(
f"Failed to connect to upstream: {error}"
) from error
if isinstance(error, (httpx.ConnectTimeout, httpx.ReadTimeout)):
raise UpstreamTimeoutError(
f"Request timed out: {error}"
) from error
if isinstance(error, httpx.HTTPStatusError):
raise UpstreamHTTPError(
f"HTTP {error.response.status_code}: {error}",
error.response.status_code,
dict(error.response.headers),
) from error
# Check for SSL errors in the error chain
if "ssl" in str(error).lower() or "certificate" in str(error).lower():
raise UpstreamSSLError(f"SSL/TLS error: {error}") from error
raise UpstreamError(f"Error fetching {url}: {error}") from error
def test_connection(self, source: UpstreamSource) -> tuple[bool, Optional[str], Optional[int]]:
"""
Test connectivity to an upstream source.
Performs a HEAD request to the source URL to verify connectivity
and authentication.
Args:
source: The upstream source to test.
Returns:
Tuple of (success, error_message, status_code).
"""
headers = {"User-Agent": self.config.user_agent}
headers.update(self._build_auth_headers(source))
auth = self._get_basic_auth(source)
timeout = httpx.Timeout(
connect=self.config.connect_timeout,
read=30.0,
write=30.0,
pool=10.0,
)
try:
with httpx.Client(
timeout=timeout,
verify=self.config.verify_ssl,
) as client:
response = client.head(
source.url,
headers=headers,
auth=auth,
follow_redirects=True,
)
# Consider 2xx and 3xx as success, also 405 (Method Not Allowed)
# since some servers don't support HEAD
if response.status_code < 400 or response.status_code == 405:
return (True, None, response.status_code)
else:
return (
False,
f"HTTP {response.status_code}",
response.status_code,
)
except httpx.ConnectError as e:
return (False, f"Connection failed: {e}", None)
except httpx.ConnectTimeout as e:
return (False, f"Connection timed out: {e}", None)
except httpx.ReadTimeout as e:
return (False, f"Read timed out: {e}", None)
except Exception as e:
return (False, f"Error: {e}", None)