""" HTTP client for fetching artifacts from upstream sources. Provides streaming downloads with SHA256 computation, authentication support, and automatic source matching based on URL prefixes. """ from __future__ import annotations import hashlib import logging import tempfile import time from dataclasses import dataclass, field from pathlib import Path from typing import BinaryIO, Optional, TYPE_CHECKING from urllib.parse import urlparse import httpx if TYPE_CHECKING: from .models import CacheSettings, UpstreamSource logger = logging.getLogger(__name__) class UpstreamError(Exception): """Base exception for upstream client errors.""" pass class UpstreamConnectionError(UpstreamError): """Connection to upstream failed (network error, DNS, etc.).""" pass class UpstreamTimeoutError(UpstreamError): """Request to upstream timed out.""" pass class UpstreamHTTPError(UpstreamError): """Upstream returned an HTTP error response.""" def __init__(self, message: str, status_code: int, response_headers: dict = None): super().__init__(message) self.status_code = status_code self.response_headers = response_headers or {} class UpstreamSSLError(UpstreamError): """SSL/TLS error when connecting to upstream.""" pass class FileSizeExceededError(UpstreamError): """File size exceeds the maximum allowed.""" def __init__(self, message: str, content_length: int, max_size: int): super().__init__(message) self.content_length = content_length self.max_size = max_size class SourceNotFoundError(UpstreamError): """No matching upstream source found for URL.""" pass class SourceDisabledError(UpstreamError): """The matching upstream source is disabled.""" pass @dataclass class FetchResult: """Result of fetching an artifact from upstream.""" content: BinaryIO # File-like object with content sha256: str # SHA256 hash of content size: int # Size in bytes content_type: Optional[str] # Content-Type header response_headers: dict # All response headers for provenance source_name: Optional[str] = None # Name of matched upstream source temp_path: Optional[Path] = None # Path to temp file (for cleanup) def close(self): """Close and clean up resources.""" if self.content: try: self.content.close() except Exception: pass if self.temp_path and self.temp_path.exists(): try: self.temp_path.unlink() except Exception: pass @dataclass class UpstreamClientConfig: """Configuration for the upstream client.""" connect_timeout: float = 30.0 # Connection timeout in seconds read_timeout: float = 300.0 # Read timeout in seconds (5 minutes for large files) max_retries: int = 3 # Maximum number of retry attempts retry_backoff_base: float = 1.0 # Base delay for exponential backoff retry_backoff_max: float = 30.0 # Maximum delay between retries follow_redirects: bool = True # Whether to follow redirects max_redirects: int = 5 # Maximum number of redirects to follow max_file_size: Optional[int] = None # Maximum file size (None = unlimited) verify_ssl: bool = True # Verify SSL certificates user_agent: str = "Orchard-UpstreamClient/1.0" class UpstreamClient: """ HTTP client for fetching artifacts from upstream sources. Supports streaming downloads, multiple authentication methods, automatic source matching, and air-gap mode enforcement. """ def __init__( self, sources: list[UpstreamSource] = None, cache_settings: CacheSettings = None, config: UpstreamClientConfig = None, ): """ Initialize the upstream client. Args: sources: List of upstream sources for URL matching and auth. Should be sorted by priority (lowest first). cache_settings: Global cache settings including air-gap mode. config: Client configuration options. """ self.sources = sources or [] self.cache_settings = cache_settings self.config = config or UpstreamClientConfig() # Sort sources by priority (lower = higher priority) self.sources = sorted(self.sources, key=lambda s: s.priority) def _match_source(self, url: str) -> Optional[UpstreamSource]: """ Find the upstream source that matches the given URL. Matches by URL prefix, returns the highest priority match. Args: url: The URL to match. Returns: The matching UpstreamSource or None if no match. """ for source in self.sources: # Check if URL starts with source URL (prefix match) if url.startswith(source.url.rstrip("/")): return source return None def _build_auth_headers(self, source: UpstreamSource) -> dict: """ Build authentication headers for the given source. Args: source: The upstream source with auth configuration. Returns: Dictionary of headers to add to the request. """ headers = {} if source.auth_type == "none": pass elif source.auth_type == "basic": # httpx handles basic auth via auth parameter, but we can also # do it manually if needed. We'll use the auth parameter instead. pass elif source.auth_type == "bearer": password = source.get_password() if password: headers["Authorization"] = f"Bearer {password}" elif source.auth_type == "api_key": # API key auth uses custom headers custom_headers = source.get_headers() if custom_headers: headers.update(custom_headers) return headers def _get_basic_auth(self, source: UpstreamSource) -> Optional[tuple[str, str]]: """ Get basic auth credentials if applicable. Args: source: The upstream source. Returns: Tuple of (username, password) or None. """ if source.auth_type == "basic" and source.username: password = source.get_password() or "" return (source.username, password) return None def _should_retry(self, error: Exception, attempt: int) -> bool: """ Determine if a request should be retried. Args: error: The exception that occurred. attempt: Current attempt number (0-indexed). Returns: True if the request should be retried. """ if attempt >= self.config.max_retries - 1: return False # Retry on connection errors and timeouts if isinstance(error, (httpx.ConnectError, httpx.ConnectTimeout)): return True # Retry on read timeouts if isinstance(error, httpx.ReadTimeout): return True # Retry on certain HTTP errors (502, 503, 504) if isinstance(error, httpx.HTTPStatusError): return error.response.status_code in (502, 503, 504) return False def _calculate_backoff(self, attempt: int) -> float: """ Calculate backoff delay for retry. Uses exponential backoff with jitter. Args: attempt: Current attempt number (0-indexed). Returns: Delay in seconds. """ import random delay = self.config.retry_backoff_base * (2**attempt) # Add jitter (±25%) delay *= 0.75 + random.random() * 0.5 return min(delay, self.config.retry_backoff_max) def fetch(self, url: str, expected_hash: Optional[str] = None) -> FetchResult: """ Fetch an artifact from the given URL. Streams the response to a temp file while computing the SHA256 hash. Handles authentication, retries, and error cases. Args: url: The URL to fetch. expected_hash: Optional expected SHA256 hash for verification. Returns: FetchResult with content, hash, size, and headers. Raises: SourceDisabledError: If the matching source is disabled. UpstreamConnectionError: On connection failures. UpstreamTimeoutError: On timeout. UpstreamHTTPError: On HTTP error responses. UpstreamSSLError: On SSL/TLS errors. FileSizeExceededError: If Content-Length exceeds max_file_size. """ start_time = time.time() # Match URL to source source = self._match_source(url) # Check if source is enabled (if we have a match) if source is not None and not source.enabled: raise SourceDisabledError( f"Upstream source '{source.name}' is disabled" ) source_name = source.name if source else None logger.info( f"Fetching URL: {url} (source: {source_name or 'none'})" ) # Build request parameters headers = {"User-Agent": self.config.user_agent} auth = None if source: headers.update(self._build_auth_headers(source)) auth = self._get_basic_auth(source) timeout = httpx.Timeout( connect=self.config.connect_timeout, read=self.config.read_timeout, write=30.0, pool=10.0, ) # Attempt fetch with retries last_error = None for attempt in range(self.config.max_retries): try: return self._do_fetch( url=url, headers=headers, auth=auth, timeout=timeout, source_name=source_name, start_time=start_time, expected_hash=expected_hash, ) except ( httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout, httpx.HTTPStatusError, ) as e: last_error = e if self._should_retry(e, attempt): delay = self._calculate_backoff(attempt) logger.warning( f"Fetch failed (attempt {attempt + 1}/{self.config.max_retries}), " f"retrying in {delay:.1f}s: {e}" ) time.sleep(delay) else: break # Convert final error to our exception types self._raise_upstream_error(last_error, url) def _do_fetch( self, url: str, headers: dict, auth: Optional[tuple[str, str]], timeout: httpx.Timeout, source_name: Optional[str], start_time: float, expected_hash: Optional[str] = None, ) -> FetchResult: """ Perform the actual fetch operation. Args: url: URL to fetch. headers: Request headers. auth: Basic auth credentials or None. timeout: Request timeout configuration. source_name: Name of matched source for logging. start_time: Request start time for timing. expected_hash: Optional expected hash for verification. Returns: FetchResult with content and metadata. """ with httpx.Client( timeout=timeout, follow_redirects=self.config.follow_redirects, max_redirects=self.config.max_redirects, verify=self.config.verify_ssl, ) as client: with client.stream("GET", url, headers=headers, auth=auth) as response: # Check for HTTP errors response.raise_for_status() # Check Content-Length against max size content_length = response.headers.get("content-length") if content_length: content_length = int(content_length) if ( self.config.max_file_size and content_length > self.config.max_file_size ): raise FileSizeExceededError( f"File size {content_length} exceeds maximum {self.config.max_file_size}", content_length, self.config.max_file_size, ) # Stream to temp file while computing hash hasher = hashlib.sha256() size = 0 # Create temp file temp_file = tempfile.NamedTemporaryFile( delete=False, prefix="orchard_upstream_" ) temp_path = Path(temp_file.name) try: for chunk in response.iter_bytes(chunk_size=65536): temp_file.write(chunk) hasher.update(chunk) size += len(chunk) # Check size while streaming if max_file_size is set if self.config.max_file_size and size > self.config.max_file_size: temp_file.close() temp_path.unlink() raise FileSizeExceededError( f"Downloaded size {size} exceeds maximum {self.config.max_file_size}", size, self.config.max_file_size, ) temp_file.close() sha256 = hasher.hexdigest() # Verify hash if expected if expected_hash and sha256 != expected_hash.lower(): temp_path.unlink() raise UpstreamError( f"Hash mismatch: expected {expected_hash}, got {sha256}" ) # Capture response headers response_headers = dict(response.headers) # Get content type content_type = response.headers.get("content-type") elapsed = time.time() - start_time logger.info( f"Fetched {url}: {size} bytes, sha256={sha256[:12]}..., " f"source={source_name}, time={elapsed:.2f}s" ) # Return file handle positioned at start content = open(temp_path, "rb") return FetchResult( content=content, sha256=sha256, size=size, content_type=content_type, response_headers=response_headers, source_name=source_name, temp_path=temp_path, ) except Exception: # Clean up on error try: temp_file.close() except Exception: pass if temp_path.exists(): temp_path.unlink() raise def _raise_upstream_error(self, error: Exception, url: str): """ Convert httpx exception to appropriate UpstreamError. Args: error: The httpx exception. url: The URL that was being fetched. Raises: Appropriate UpstreamError subclass. """ if error is None: raise UpstreamError(f"Unknown error fetching {url}") if isinstance(error, httpx.ConnectError): raise UpstreamConnectionError( f"Failed to connect to upstream: {error}" ) from error if isinstance(error, (httpx.ConnectTimeout, httpx.ReadTimeout)): raise UpstreamTimeoutError( f"Request timed out: {error}" ) from error if isinstance(error, httpx.HTTPStatusError): raise UpstreamHTTPError( f"HTTP {error.response.status_code}: {error}", error.response.status_code, dict(error.response.headers), ) from error # Check for SSL errors in the error chain if "ssl" in str(error).lower() or "certificate" in str(error).lower(): raise UpstreamSSLError(f"SSL/TLS error: {error}") from error raise UpstreamError(f"Error fetching {url}: {error}") from error def test_connection(self, source: UpstreamSource) -> tuple[bool, Optional[str], Optional[int]]: """ Test connectivity to an upstream source. Performs a HEAD request to the source URL to verify connectivity and authentication. Does not follow redirects - a 3xx response is considered successful since it proves the server is reachable. Args: source: The upstream source to test. Returns: Tuple of (success, error_message, status_code). """ headers = {"User-Agent": self.config.user_agent} headers.update(self._build_auth_headers(source)) auth = self._get_basic_auth(source) timeout = httpx.Timeout( connect=self.config.connect_timeout, read=30.0, write=30.0, pool=10.0, ) try: with httpx.Client( timeout=timeout, verify=self.config.verify_ssl, ) as client: response = client.head( source.url, headers=headers, auth=auth, follow_redirects=False, ) # Consider 2xx and 3xx as success, also 405 (Method Not Allowed) # since some servers don't support HEAD if response.status_code < 400 or response.status_code == 405: return (True, None, response.status_code) else: return ( False, f"HTTP {response.status_code}", response.status_code, ) except httpx.ConnectError as e: return (False, f"Connection failed: {e}", None) except httpx.ConnectTimeout as e: return (False, f"Connection timed out: {e}", None) except httpx.ReadTimeout as e: return (False, f"Read timed out: {e}", None) except httpx.TooManyRedirects as e: return (False, f"Too many redirects: {e}", None) except Exception as e: return (False, f"Error: {e}", None)