From 35fda65d381acc5ab59bc592ee3013f75906c197 Mon Sep 17 00:00:00 2001
From: Mondo Diaz
Date: Wed, 7 Jan 2026 13:36:46 -0600
Subject: [PATCH] Add download verification with SHA256 checksum support (#26,
 #27, #28, #29)

---
 CHANGELOG.md                                |  24 +
 backend/app/checksum.py                     | 477 +++++++++++++++
 backend/app/config.py                       |   4 +
 backend/app/logging_config.py               | 254 ++++++++
 backend/app/routes.py                       | 179 +++++-
 backend/app/storage.py                      |  96 +++
 backend/tests/test_checksum_verification.py | 675 +++++++++++++++++++++
 backend/tests/test_download_verification.py | 460 ++++++++++++++
 8 files changed, 2157 insertions(+), 12 deletions(-)
 create mode 100644 backend/app/checksum.py
 create mode 100644 backend/app/logging_config.py
 create mode 100644 backend/tests/test_checksum_verification.py
 create mode 100644 backend/tests/test_download_verification.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 41a8dec..0d057f7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 
 ### Added
+- Added download verification with `verify` and `verify_mode` query parameters (#26)
+  - `?verify=true&verify_mode=pre` - Pre-verification: verify before streaming (guaranteed no corrupt data)
+  - `?verify=true&verify_mode=stream` - Streaming verification: verify while streaming (logs error if mismatch)
+- Added checksum response headers to all download endpoints (#27)
+  - `X-Checksum-SHA256` - SHA256 hash of the artifact
+  - `X-Content-Length` - File size in bytes
+  - `X-Checksum-MD5` - MD5 hash (if available)
+  - `ETag` - Artifact ID (SHA256)
+  - `Digest` - RFC 3230 format sha-256 hash (base64)
+  - `X-Verified` - Verification status (true/false/pending)
+- Added `checksum.py` module with SHA256 utilities (#26)
+  - `compute_sha256()` and `compute_sha256_stream()` functions
+  - `HashingStreamWrapper` for incremental hash computation
+  - `VerifyingStreamWrapper` for stream verification
+  - `verify_checksum()` and `verify_checksum_strict()` functions
+  - `ChecksumMismatchError` exception with context
+- Added `get_verified()` and `get_stream_verified()` methods to storage layer (#26)
+- Added `logging_config.py` module with structured logging (#28)
+  - JSON logging format for production
+  - Request ID tracking via context variables
+  - Verification failure logging with full context
+- Added `log_level` and `log_format` settings to configuration (#28)
+- Added 62 unit tests for checksum utilities and verification (#29)
+- Added 17 integration tests for download verification API (#29)
 - Added global artifacts endpoint `GET /api/v1/artifacts` with project/package/tag/size/date filters (#18)
 - Added global tags endpoint `GET /api/v1/tags` with project/package/search/date filters (#18)
 - Added wildcard pattern matching (`*`) for tag filters across all endpoints (#18)
diff --git a/backend/app/checksum.py b/backend/app/checksum.py
new file mode 100644
index 0000000..89b80fc
--- /dev/null
+++ b/backend/app/checksum.py
@@ -0,0 +1,477 @@
+"""
+Checksum utilities for download verification.
+
+This module provides functions and classes for computing and verifying
+SHA256 checksums during artifact downloads.
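+
+Typical usage (illustrative sketch; `data` is a stand-in for real content):
+
+    data = b"example"
+    digest = compute_sha256(data)         # 64-char lowercase hex
+    assert verify_checksum(data, digest)  # True when content matches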
+
+Key components:
+- compute_sha256(): Compute SHA256 of bytes content
+- compute_sha256_stream(): Compute SHA256 from an iterable stream
+- HashingStreamWrapper: Wrapper that computes hash while streaming
+- VerifyingStreamWrapper: Wrapper that verifies hash after streaming
+- verify_checksum(): Verify content against expected hash
+- ChecksumMismatchError: Exception for verification failures
+"""
+
+import hashlib
+import logging
+import re
+import base64
+from typing import (
+    Generator,
+    Optional,
+    Any,
+    Callable,
+)
+
+logger = logging.getLogger(__name__)
+
+# Default chunk size for streaming operations (8KB)
+DEFAULT_CHUNK_SIZE = 8 * 1024
+
+# Regex pattern for valid SHA256 hash (64 hex characters)
+SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}$")
+
+
+class ChecksumError(Exception):
+    """Base exception for checksum operations."""
+
+    pass
+
+
+class ChecksumMismatchError(ChecksumError):
+    """
+    Raised when computed checksum does not match expected checksum.
+
+    Attributes:
+        expected: The expected SHA256 hash
+        actual: The actual computed SHA256 hash
+        artifact_id: Optional artifact ID for context
+        s3_key: Optional S3 key for debugging
+        size: Optional file size
+    """
+
+    def __init__(
+        self,
+        expected: str,
+        actual: str,
+        artifact_id: Optional[str] = None,
+        s3_key: Optional[str] = None,
+        size: Optional[int] = None,
+        message: Optional[str] = None,
+    ):
+        self.expected = expected
+        self.actual = actual
+        self.artifact_id = artifact_id
+        self.s3_key = s3_key
+        self.size = size
+
+        if message:
+            self.message = message
+        else:
+            self.message = (
+                f"Checksum verification failed: "
+                f"expected {expected[:16]}..., got {actual[:16]}..."
+            )
+        super().__init__(self.message)
+
+    def to_dict(self) -> dict:
+        """Convert to dictionary for logging/API responses."""
+        return {
+            "error": "checksum_mismatch",
+            "expected": self.expected,
+            "actual": self.actual,
+            "artifact_id": self.artifact_id,
+            "s3_key": self.s3_key,
+            "size": self.size,
+            "message": self.message,
+        }
+
+
+class InvalidHashFormatError(ChecksumError):
+    """Raised when a hash string is not valid SHA256 format."""
+
+    def __init__(self, hash_value: str):
+        self.hash_value = hash_value
+        message = f"Invalid SHA256 hash format: '{hash_value[:32]}...'"
+        super().__init__(message)
+
+
+def is_valid_sha256(hash_value: str) -> bool:
+    """
+    Check if a string is a valid SHA256 hash (64 hex characters).
+
+    Args:
+        hash_value: String to validate
+
+    Returns:
+        True if valid SHA256 format, False otherwise
+    """
+    if not hash_value:
+        return False
+    return bool(SHA256_PATTERN.match(hash_value))
+
+
+def compute_sha256(content: bytes) -> str:
+    """
+    Compute SHA256 hash of bytes content.
+
+    Args:
+        content: Bytes content to hash
+
+    Returns:
+        Lowercase hexadecimal SHA256 hash (64 characters)
+
+    Raises:
+        ChecksumError: If hash computation fails
+    """
+    if content is None:
+        raise ChecksumError("Cannot compute hash of None content")
+
+    try:
+        return hashlib.sha256(content).hexdigest().lower()
+    except Exception as e:
+        raise ChecksumError(f"Hash computation failed: {e}") from e
+
+
+def compute_sha256_stream(
+    stream: Any,
+    chunk_size: int = DEFAULT_CHUNK_SIZE,
+) -> str:
+    """
+    Compute SHA256 hash from a stream or file-like object.
+
+    Reads the stream in chunks to minimize memory usage for large files.
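+
+    Example (sketch; the file path is hypothetical):
+
+        with open("artifact.bin", "rb") as f:
+            digest = compute_sha256_stream(f, chunk_size=64 * 1024)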
+
+    Args:
+        stream: Iterator yielding bytes or file-like object with read()
+        chunk_size: Size of chunks to read (default 8KB)
+
+    Returns:
+        Lowercase hexadecimal SHA256 hash (64 characters)
+
+    Raises:
+        ChecksumError: If hash computation fails
+    """
+    try:
+        hasher = hashlib.sha256()
+
+        # Handle file-like objects with read()
+        if hasattr(stream, "read"):
+            while True:
+                chunk = stream.read(chunk_size)
+                if not chunk:
+                    break
+                hasher.update(chunk)
+        else:
+            # Handle iterators
+            for chunk in stream:
+                if chunk:
+                    hasher.update(chunk)
+
+        return hasher.hexdigest().lower()
+    except Exception as e:
+        raise ChecksumError(f"Stream hash computation failed: {e}") from e
+
+
+def verify_checksum(content: bytes, expected: str) -> bool:
+    """
+    Verify that content matches expected SHA256 hash.
+
+    Args:
+        content: Bytes content to verify
+        expected: Expected SHA256 hash (case-insensitive)
+
+    Returns:
+        True if hash matches, False otherwise
+
+    Raises:
+        InvalidHashFormatError: If expected hash is not valid format
+        ChecksumError: If hash computation fails
+    """
+    if not is_valid_sha256(expected):
+        raise InvalidHashFormatError(expected)
+
+    actual = compute_sha256(content)
+    return actual == expected.lower()
+
+
+def verify_checksum_strict(
+    content: bytes,
+    expected: str,
+    artifact_id: Optional[str] = None,
+    s3_key: Optional[str] = None,
+) -> None:
+    """
+    Verify content matches expected hash, raising exception on mismatch.
+
+    Args:
+        content: Bytes content to verify
+        expected: Expected SHA256 hash (case-insensitive)
+        artifact_id: Optional artifact ID for error context
+        s3_key: Optional S3 key for error context
+
+    Raises:
+        InvalidHashFormatError: If expected hash is not valid format
+        ChecksumMismatchError: If verification fails
+        ChecksumError: If hash computation fails
+    """
+    if not is_valid_sha256(expected):
+        raise InvalidHashFormatError(expected)
+
+    actual = compute_sha256(content)
+    if actual != expected.lower():
+        raise ChecksumMismatchError(
+            expected=expected.lower(),
+            actual=actual,
+            artifact_id=artifact_id,
+            s3_key=s3_key,
+            size=len(content),
+        )
+
+
+def sha256_to_base64(hex_hash: str) -> str:
+    """
+    Convert SHA256 hex string to base64 encoding (for RFC 3230 Digest header).
+
+    Args:
+        hex_hash: SHA256 hash as 64-character hex string
+
+    Returns:
+        Base64-encoded hash string
+    """
+    if not is_valid_sha256(hex_hash):
+        raise InvalidHashFormatError(hex_hash)
+
+    hash_bytes = bytes.fromhex(hex_hash)
+    return base64.b64encode(hash_bytes).decode("ascii")
+
+
+class HashingStreamWrapper:
+    """
+    Wrapper that computes SHA256 hash incrementally as chunks are read.
+
+    This allows computing the hash while streaming content to a client,
+    without buffering the entire content in memory.
+
+    Usage:
+        wrapper = HashingStreamWrapper(stream)
+        for chunk in wrapper:
+            send_to_client(chunk)
+        final_hash = wrapper.get_hash()
+
+    Attributes:
+        chunk_size: Size of chunks to yield
+        bytes_read: Total bytes processed so far
+    """
+
+    def __init__(
+        self,
+        stream: Any,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+    ):
+        """
+        Initialize the hashing stream wrapper.
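+
+        Example (sketch; assumes a boto3 client named `s3` and a bucket/key
+        that exist):
+
+            body = s3.get_object(Bucket="my-bucket", Key="my-key")["Body"]
+            wrapper = HashingStreamWrapper(body, chunk_size=64 * 1024)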
+
+        Args:
+            stream: Source stream (iterator, file-like, or S3 StreamingBody)
+            chunk_size: Size of chunks to yield (default 8KB)
+        """
+        self._stream = stream
+        self._hasher = hashlib.sha256()
+        self._chunk_size = chunk_size
+        self._bytes_read = 0
+        self._finalized = False
+        self._final_hash: Optional[str] = None
+
+    @property
+    def bytes_read(self) -> int:
+        """Total bytes read so far."""
+        return self._bytes_read
+
+    @property
+    def chunk_size(self) -> int:
+        """Chunk size for reading."""
+        return self._chunk_size
+
+    def __iter__(self) -> Generator[bytes, None, None]:
+        """Iterate over chunks, computing hash as we go."""
+        # Handle S3 StreamingBody (has iter_chunks)
+        if hasattr(self._stream, "iter_chunks"):
+            for chunk in self._stream.iter_chunks(chunk_size=self._chunk_size):
+                if chunk:
+                    self._hasher.update(chunk)
+                    self._bytes_read += len(chunk)
+                    yield chunk
+        # Handle file-like objects with read()
+        elif hasattr(self._stream, "read"):
+            while True:
+                chunk = self._stream.read(self._chunk_size)
+                if not chunk:
+                    break
+                self._hasher.update(chunk)
+                self._bytes_read += len(chunk)
+                yield chunk
+        # Handle iterators
+        else:
+            for chunk in self._stream:
+                if chunk:
+                    self._hasher.update(chunk)
+                    self._bytes_read += len(chunk)
+                    yield chunk
+
+        self._finalized = True
+        self._final_hash = self._hasher.hexdigest().lower()
+
+    def get_hash(self) -> str:
+        """
+        Get the computed SHA256 hash.
+
+        If stream hasn't been fully consumed, consumes remaining chunks.
+
+        Returns:
+            Lowercase hexadecimal SHA256 hash
+        """
+        if not self._finalized:
+            # Consume remaining stream
+            for _ in self:
+                pass
+
+        return self._final_hash or self._hasher.hexdigest().lower()
+
+    def get_hash_if_complete(self) -> Optional[str]:
+        """
+        Get hash only if stream has been fully consumed.
+
+        Returns:
+            Hash if complete, None otherwise
+        """
+        if self._finalized:
+            return self._final_hash
+        return None
+
+
+class VerifyingStreamWrapper:
+    """
+    Wrapper that yields chunks and verifies hash after streaming completes.
+
+    IMPORTANT: Because HTTP streams cannot be "un-sent", if verification
+    fails after streaming, the client has already received potentially
+    corrupt data. This wrapper logs an error but cannot prevent delivery.
+
+    For guaranteed verification before delivery, use pre-verification mode
+    which buffers the entire content first.
+
+    Usage:
+        wrapper = VerifyingStreamWrapper(stream, expected_hash)
+        for chunk in wrapper:
+            send_to_client(chunk)
+        wrapper.verify()  # Raises ChecksumMismatchError if failed
+    """
+
+    def __init__(
+        self,
+        stream: Any,
+        expected_hash: str,
+        artifact_id: Optional[str] = None,
+        s3_key: Optional[str] = None,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+        on_failure: Optional[Callable[[Any], None]] = None,
+    ):
+        """
+        Initialize the verifying stream wrapper.
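+
+        Example (sketch; `record_failure` is a hypothetical callback):
+
+            def record_failure(err):
+                print(err.to_dict())  # e.g. emit a metric or alert here
+
+            wrapper = VerifyingStreamWrapper(
+                stream, expected_hash, on_failure=record_failure
+            )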
+
+        Args:
+            stream: Source stream
+            expected_hash: Expected SHA256 hash to verify against
+            artifact_id: Optional artifact ID for error context
+            s3_key: Optional S3 key for error context
+            chunk_size: Size of chunks to yield
+            on_failure: Optional callback called on verification failure
+        """
+        if not is_valid_sha256(expected_hash):
+            raise InvalidHashFormatError(expected_hash)
+
+        self._hashing_wrapper = HashingStreamWrapper(stream, chunk_size)
+        self._expected_hash = expected_hash.lower()
+        self._artifact_id = artifact_id
+        self._s3_key = s3_key
+        self._on_failure = on_failure
+        self._verified: Optional[bool] = None
+
+    @property
+    def bytes_read(self) -> int:
+        """Total bytes read so far."""
+        return self._hashing_wrapper.bytes_read
+
+    @property
+    def is_verified(self) -> Optional[bool]:
+        """
+        Verification status.
+
+        Returns:
+            True if verified successfully, False if failed, None if not yet complete
+        """
+        return self._verified
+
+    def __iter__(self) -> Generator[bytes, None, None]:
+        """Iterate over chunks."""
+        yield from self._hashing_wrapper
+
+    def verify(self) -> bool:
+        """
+        Verify the hash after stream is complete.
+
+        Must be called after fully consuming the iterator.
+
+        Returns:
+            True if verification passed
+
+        Raises:
+            ChecksumMismatchError: If verification failed
+        """
+        actual_hash = self._hashing_wrapper.get_hash()
+
+        if actual_hash == self._expected_hash:
+            self._verified = True
+            logger.debug(
+                f"Verification passed for {self._artifact_id or 'unknown'}: {actual_hash[:16]}..."
+            )
+            return True
+
+        self._verified = False
+        error = ChecksumMismatchError(
+            expected=self._expected_hash,
+            actual=actual_hash,
+            artifact_id=self._artifact_id,
+            s3_key=self._s3_key,
+            size=self._hashing_wrapper.bytes_read,
+        )
+
+        # Log the failure
+        logger.error(f"Checksum verification FAILED after streaming: {error.to_dict()}")
+
+        # Call failure callback if provided
+        if self._on_failure:
+            try:
+                self._on_failure(error)
+            except Exception as e:
+                logger.warning(f"Verification failure callback raised exception: {e}")
+
+        raise error
+
+    def verify_silent(self) -> bool:
+        """
+        Verify the hash without raising exception.
+
+        Returns:
+            True if verification passed, False otherwise
+        """
+        try:
+            return self.verify()
+        except ChecksumMismatchError:
+            return False
+
+    def get_actual_hash(self) -> Optional[str]:
+        """Get the actual computed hash (only available after iteration)."""
+        return self._hashing_wrapper.get_hash_if_complete()
diff --git a/backend/app/config.py b/backend/app/config.py
index 2aa4469..fa78674 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -48,6 +48,10 @@ class Settings(BaseSettings):
         3600  # Presigned URL expiry in seconds (default: 1 hour)
     )
 
+    # Logging settings
+    log_level: str = "INFO"  # DEBUG, INFO, WARNING, ERROR, CRITICAL
+    log_format: str = "auto"  # "json", "standard", or "auto" (json in production)
+
     @property
     def database_url(self) -> str:
         sslmode = f"?sslmode={self.database_sslmode}" if self.database_sslmode else ""
diff --git a/backend/app/logging_config.py b/backend/app/logging_config.py
new file mode 100644
index 0000000..558e17b
--- /dev/null
+++ b/backend/app/logging_config.py
@@ -0,0 +1,254 @@
+"""
+Structured logging configuration for Orchard.
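+
+Request-scoped sketch (e.g. from a middleware; assumes a module-level `logger`):
+
+    rid = set_request_id()   # bind a fresh UUID to the current context
+    logger.info("handling")  # JSON output now carries "request_id"
+    clear_request_id()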
+
+This module provides:
+- Structured JSON logging for production environments
+- Request tracing via X-Request-ID header
+- Verification failure logging with context
+- Configurable log levels via environment
+
+Usage:
+    from app.logging_config import setup_logging, get_request_id
+
+    setup_logging()  # Call once at app startup
+    request_id = get_request_id()  # Get current request's ID
+"""
+
+import logging
+import json
+import sys
+import uuid
+from datetime import datetime, timezone
+from typing import Optional, Any, Dict
+from contextvars import ContextVar
+
+from .config import get_settings
+
+# Context variable for request ID (thread-safe)
+_request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
+
+
+def get_request_id() -> Optional[str]:
+    """Get the current request's ID from context."""
+    return _request_id_var.get()
+
+
+def set_request_id(request_id: Optional[str] = None) -> str:
+    """
+    Set the request ID for the current context.
+
+    If no ID provided, generates a new UUID.
+    Returns the request ID that was set.
+    """
+    if request_id is None:
+        request_id = str(uuid.uuid4())
+    _request_id_var.set(request_id)
+    return request_id
+
+
+def clear_request_id():
+    """Clear the request ID from context."""
+    _request_id_var.set(None)
+
+
+class JSONFormatter(logging.Formatter):
+    """
+    JSON log formatter for structured logging.
+
+    Output format:
+        {
+            "timestamp": "2025-01-01T00:00:00.000Z",
+            "level": "INFO",
+            "logger": "app.routes",
+            "message": "Request completed",
+            "request_id": "abc-123",
+            "extra": {...}
+        }
+    """
+
+    def format(self, record: logging.LogRecord) -> str:
+        log_entry: Dict[str, Any] = {
+            "timestamp": datetime.now(timezone.utc).isoformat(),
+            "level": record.levelname,
+            "logger": record.name,
+            "message": record.getMessage(),
+        }
+
+        # Add request ID if available
+        request_id = get_request_id()
+        if request_id:
+            log_entry["request_id"] = request_id
+
+        # Add exception info if present
+        if record.exc_info:
+            log_entry["exception"] = self.formatException(record.exc_info)
+
+        # Add extra fields from record
+        extra_fields: Dict[str, Any] = {}
+        for key, value in record.__dict__.items():
+            if key not in (
+                "name",
+                "msg",
+                "args",
+                "created",
+                "filename",
+                "funcName",
+                "levelname",
+                "levelno",
+                "lineno",
+                "module",
+                "msecs",
+                "pathname",
+                "process",
+                "processName",
+                "relativeCreated",
+                "stack_info",
+                "exc_info",
+                "exc_text",
+                "thread",
+                "threadName",
+                "message",
+                "asctime",
+            ):
+                try:
+                    json.dumps(value)  # Ensure serializable
+                    extra_fields[key] = value
+                except (TypeError, ValueError):
+                    extra_fields[key] = str(value)
+
+        if extra_fields:
+            log_entry["extra"] = extra_fields
+
+        return json.dumps(log_entry)
+
+
+class StandardFormatter(logging.Formatter):
+    """
+    Standard log formatter for development.
+
+    Output format:
+        [2025-01-01 00:00:00] INFO  [app.routes] [req-abc123] Request completed
+    """
+
+    def format(self, record: logging.LogRecord) -> str:
+        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        request_id = get_request_id()
+        req_str = f" [req-{request_id[:8]}]" if request_id else ""
+
+        base_msg = f"[{timestamp}] {record.levelname:5} [{record.name}]{req_str} {record.getMessage()}"
+
+        if record.exc_info:
+            base_msg += "\n" + self.formatException(record.exc_info)
+
+        return base_msg
+
+
+def setup_logging(log_level: Optional[str] = None, json_format: Optional[bool] = None):
+    """
+    Configure logging for the application.
+
+    Args:
+        log_level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL).
+            Defaults to ORCHARD_LOG_LEVEL env var or INFO.
+        json_format: Use JSON format. Defaults to True in production.
+    """
+    settings = get_settings()
+
+    # Determine log level
+    if log_level is None:
+        log_level = getattr(settings, "log_level", "INFO")
+    effective_level = log_level if log_level else "INFO"
+    level = getattr(logging, effective_level.upper(), logging.INFO)
+
+    # Determine format
+    if json_format is None:
+        json_format = settings.is_production
+
+    # Create handler
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setLevel(level)
+
+    # Set formatter
+    if json_format:
+        handler.setFormatter(JSONFormatter())
+    else:
+        handler.setFormatter(StandardFormatter())
+
+    # Configure root logger
+    root_logger = logging.getLogger()
+    root_logger.setLevel(level)
+
+    # Remove existing handlers
+    root_logger.handlers.clear()
+    root_logger.addHandler(handler)
+
+    # Configure specific loggers
+    for logger_name in ["app", "uvicorn", "uvicorn.access", "uvicorn.error"]:
+        logger = logging.getLogger(logger_name)
+        logger.setLevel(level)
+        logger.handlers.clear()
+        logger.addHandler(handler)
+        logger.propagate = False
+
+    # Quiet down noisy loggers
+    logging.getLogger("botocore").setLevel(logging.WARNING)
+    logging.getLogger("boto3").setLevel(logging.WARNING)
+    logging.getLogger("urllib3").setLevel(logging.WARNING)
+
+
+def log_verification_failure(
+    logger: logging.Logger,
+    expected_hash: str,
+    actual_hash: str,
+    artifact_id: Optional[str] = None,
+    s3_key: Optional[str] = None,
+    project: Optional[str] = None,
+    package: Optional[str] = None,
+    size: Optional[int] = None,
+    user_id: Optional[str] = None,
+    source_ip: Optional[str] = None,
+    verification_mode: Optional[str] = None,
+):
+    """
+    Log a verification failure with full context.
+
+    This creates a structured log entry with all relevant details
+    for debugging and alerting.
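+
+    Example (sketch with illustrative placeholder values):
+
+        log_verification_failure(
+            logger,
+            expected_hash="a" * 64,
+            actual_hash="b" * 64,
+            artifact_id="abc123",
+            verification_mode="stream",
+        )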
+ """ + logger.error( + "Checksum verification failed", + extra={ + "event": "verification_failure", + "expected_hash": expected_hash, + "actual_hash": actual_hash, + "artifact_id": artifact_id, + "s3_key": s3_key, + "project": project, + "package": package, + "size": size, + "user_id": user_id, + "source_ip": source_ip, + "verification_mode": verification_mode, + "hash_match": expected_hash == actual_hash, + }, + ) + + +def log_verification_success( + logger: logging.Logger, + artifact_id: str, + size: Optional[int] = None, + verification_mode: Optional[str] = None, + duration_ms: Optional[float] = None, +): + """Log a successful verification.""" + logger.info( + f"Verification passed for artifact {artifact_id[:16]}...", + extra={ + "event": "verification_success", + "artifact_id": artifact_id, + "size": size, + "verification_mode": verification_mode, + "duration_ms": duration_ms, + }, + ) diff --git a/backend/app/routes.py b/backend/app/routes.py index 7975746..ff6603f 100644 --- a/backend/app/routes.py +++ b/backend/app/routes.py @@ -97,6 +97,11 @@ from .schemas import ( ) from .metadata import extract_metadata from .config import get_settings +from .checksum import ( + ChecksumMismatchError, + VerifyingStreamWrapper, + sha256_to_base64, +) router = APIRouter() @@ -1777,7 +1782,7 @@ def _resolve_artifact_ref( return artifact -# Download artifact with range request support and download modes +# Download artifact with range request support, download modes, and verification @router.get("/api/v1/project/{project_name}/{package_name}/+/{ref}") def download_artifact( project_name: str, @@ -1791,7 +1796,34 @@ def download_artifact( default=None, description="Download mode: proxy (stream through backend), redirect (302 to presigned URL), presigned (return JSON with URL)", ), + verify: bool = Query( + default=False, + description="Enable checksum verification during download", + ), + verify_mode: Optional[Literal["stream", "pre"]] = Query( + default="stream", + description="Verification mode: 'stream' (verify after streaming, logs error if mismatch), 'pre' (verify before streaming, returns 500 if mismatch)", + ), ): + """ + Download an artifact by reference (tag name, artifact:hash, tag:name). + + Verification modes: + - verify=false (default): No verification, maximum performance + - verify=true&verify_mode=stream: Compute hash while streaming, verify after completion. + If mismatch, logs error but content already sent. + - verify=true&verify_mode=pre: Download and verify BEFORE streaming to client. + Higher latency but guarantees no corrupt data sent. 
+
+    Response headers always include:
+    - X-Checksum-SHA256: The expected SHA256 hash
+    - X-Content-Length: File size in bytes
+    - ETag: Artifact ID (SHA256)
+    - Digest: RFC 3230 format sha-256 hash
+
+    When verify=true:
+    - X-Verified: 'true' if verified, 'false' if verification failed
+    """
     settings = get_settings()
 
     # Get project and package
@@ -1831,6 +1863,23 @@
     )
     db.commit()
 
+    # Build common checksum headers (always included)
+    checksum_headers = {
+        "X-Checksum-SHA256": artifact.id,
+        "X-Content-Length": str(artifact.size),
+        "ETag": f'"{artifact.id}"',
+    }
+    # Add RFC 3230 Digest header
+    try:
+        digest_base64 = sha256_to_base64(artifact.id)
+        checksum_headers["Digest"] = f"sha-256={digest_base64}"
+    except Exception:
+        pass  # Skip if conversion fails
+
+    # Add MD5 checksum if available
+    if artifact.checksum_md5:
+        checksum_headers["X-Checksum-MD5"] = artifact.checksum_md5
+
     # Determine download mode (query param overrides server default)
     download_mode = mode or settings.download_mode
 
@@ -1867,7 +1916,7 @@
         return RedirectResponse(url=presigned_url, status_code=302)
 
     # Proxy mode (default fallback) - stream through backend
-    # Handle range requests
+    # Handle range requests (verification not supported for partial downloads)
     if range:
         stream, content_length, content_range = storage.get_stream(
            artifact.s3_key, range
         )
 
         headers = {
             "Content-Disposition": f'attachment; filename="{filename}"',
             "Accept-Ranges": "bytes",
             "Content-Length": str(content_length),
+            **checksum_headers,
         }
         if content_range:
             headers["Content-Range"] = content_range
+        # Note: X-Verified not set for range requests (cannot verify partial content)
 
         return StreamingResponse(
             stream,
@@ -1888,16 +1939,88 @@
             headers=headers,
         )
 
-    # Full download
+    # Full download with optional verification
+    base_headers = {
+        "Content-Disposition": f'attachment; filename="{filename}"',
+        "Accept-Ranges": "bytes",
+        **checksum_headers,
+    }
+
+    # Pre-verification mode: verify before streaming
+    if verify and verify_mode == "pre":
+        try:
+            content = storage.get_verified(artifact.s3_key, artifact.id)
+            return Response(
+                content=content,
+                media_type=artifact.content_type or "application/octet-stream",
+                headers={
+                    **base_headers,
+                    "Content-Length": str(len(content)),
+                    "X-Verified": "true",
+                },
+            )
+        except ChecksumMismatchError as e:
+            logger.error(
+                f"Pre-verification failed for artifact {artifact.id[:16]}...: {e.to_dict()}"
+            )
+            raise HTTPException(
+                status_code=500,
+                detail={
+                    "error": "checksum_verification_failed",
+                    "message": "Downloaded content does not match expected checksum",
+                    "expected": e.expected,
+                    "actual": e.actual,
+                    "artifact_id": artifact.id,
+                },
+            )
+
+    # Streaming verification mode: verify while/after streaming
+    if verify and verify_mode == "stream":
+        verifying_wrapper, content_length, _ = storage.get_stream_verified(
+            artifact.s3_key, artifact.id
+        )
+
+        def verified_stream():
+            """Generator that yields chunks and verifies after completion."""
+            try:
+                for chunk in verifying_wrapper:
+                    yield chunk
+                # After all chunks yielded, verify
+                try:
+                    verifying_wrapper.verify()
+                    logger.info(
+                        f"Streaming verification passed for artifact {artifact.id[:16]}..."
+                    )
+                except ChecksumMismatchError as e:
+                    # Content already sent - log error but cannot reject
+                    logger.error(
+                        f"Streaming verification FAILED for artifact {artifact.id[:16]}...: "
+                        f"expected {e.expected[:16]}..., got {e.actual[:16]}..."
+ ) + except Exception as e: + logger.error(f"Error during streaming download: {e}") + raise + + return StreamingResponse( + verified_stream(), + media_type=artifact.content_type or "application/octet-stream", + headers={ + **base_headers, + "Content-Length": str(content_length), + "X-Verified": "pending", # Verification happens after streaming + }, + ) + + # No verification - direct streaming stream, content_length, _ = storage.get_stream(artifact.s3_key) return StreamingResponse( stream, media_type=artifact.content_type or "application/octet-stream", headers={ - "Content-Disposition": f'attachment; filename="{filename}"', - "Accept-Ranges": "bytes", + **base_headers, "Content-Length": str(content_length), + "X-Verified": "false", }, ) @@ -1975,6 +2098,11 @@ def head_artifact( db: Session = Depends(get_db), storage: S3Storage = Depends(get_storage), ): + """ + Get artifact metadata without downloading content. + + Returns headers with checksum information for client-side verification. + """ # Get project and package project = db.query(Project).filter(Project.name == project_name).first() if not project: @@ -1995,15 +2123,32 @@ def head_artifact( filename = sanitize_filename(artifact.original_name or f"{artifact.id}") + # Build headers with checksum information + headers = { + "Content-Disposition": f'attachment; filename="{filename}"', + "Accept-Ranges": "bytes", + "Content-Length": str(artifact.size), + "X-Artifact-Id": artifact.id, + "X-Checksum-SHA256": artifact.id, + "X-Content-Length": str(artifact.size), + "ETag": f'"{artifact.id}"', + } + + # Add RFC 3230 Digest header + try: + digest_base64 = sha256_to_base64(artifact.id) + headers["Digest"] = f"sha-256={digest_base64}" + except Exception: + pass # Skip if conversion fails + + # Add MD5 checksum if available + if artifact.checksum_md5: + headers["X-Checksum-MD5"] = artifact.checksum_md5 + return Response( content=b"", media_type=artifact.content_type or "application/octet-stream", - headers={ - "Content-Disposition": f'attachment; filename="{filename}"', - "Accept-Ranges": "bytes", - "Content-Length": str(artifact.size), - "X-Artifact-Id": artifact.id, - }, + headers=headers, ) @@ -2017,9 +2162,19 @@ def download_artifact_compat( db: Session = Depends(get_db), storage: S3Storage = Depends(get_storage), range: Optional[str] = Header(None), + verify: bool = Query(default=False), + verify_mode: Optional[Literal["stream", "pre"]] = Query(default="stream"), ): return download_artifact( - project_name, package_name, ref, request, db, storage, range + project_name, + package_name, + ref, + request, + db, + storage, + range, + verify=verify, + verify_mode=verify_mode, ) diff --git a/backend/app/storage.py b/backend/app/storage.py index 440dbaf..672d841 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -22,6 +22,13 @@ from botocore.exceptions import ( ) from .config import get_settings +from .checksum import ( + ChecksumMismatchError, + HashingStreamWrapper, + VerifyingStreamWrapper, + compute_sha256, + is_valid_sha256, +) settings = get_settings() logger = logging.getLogger(__name__) @@ -876,6 +883,95 @@ class S3Storage: logger.error(f"Unexpected error during storage health check: {e}") return False + def get_verified(self, s3_key: str, expected_hash: str) -> bytes: + """ + Download and verify content matches expected SHA256 hash. + + This method downloads the entire content, computes its hash, and + verifies it matches the expected hash before returning. 
+
+        Args:
+            s3_key: The S3 storage key of the file
+            expected_hash: Expected SHA256 hash (64 hex characters)
+
+        Returns:
+            File content as bytes (only if verification passes)
+
+        Raises:
+            ChecksumMismatchError: If computed hash doesn't match expected
+            ClientError: If S3 operation fails
+        """
+        if not is_valid_sha256(expected_hash):
+            raise ValueError(f"Invalid SHA256 hash format: {expected_hash}")
+
+        content = self.get(s3_key)
+        actual_hash = compute_sha256(content)
+
+        if actual_hash != expected_hash.lower():
+            raise ChecksumMismatchError(
+                expected=expected_hash.lower(),
+                actual=actual_hash,
+                s3_key=s3_key,
+                size=len(content),
+            )
+
+        logger.debug(f"Verification passed for {s3_key}: {actual_hash[:16]}...")
+        return content
+
+    def get_stream_verified(
+        self,
+        s3_key: str,
+        expected_hash: str,
+        range_header: Optional[str] = None,
+    ) -> Tuple[VerifyingStreamWrapper, int, Optional[str]]:
+        """
+        Get a verifying stream wrapper for an object.
+
+        Returns a wrapper that computes the hash as chunks are read and
+        can verify after streaming completes. Note that verification happens
+        AFTER content has been streamed to the client.
+
+        IMPORTANT: For range requests, verification is not supported because
+        we cannot verify a partial download against the full file hash.
+
+        Args:
+            s3_key: The S3 storage key of the file
+            expected_hash: Expected SHA256 hash (64 hex characters)
+            range_header: Optional HTTP Range header (verification disabled if set)
+
+        Returns:
+            Tuple of (VerifyingStreamWrapper, content_length, content_range)
+            The wrapper has a verify() method to call after streaming.
+
+        Raises:
+            ValueError: If expected_hash is invalid format
+            ClientError: If S3 operation fails
+        """
+        if not is_valid_sha256(expected_hash):
+            raise ValueError(f"Invalid SHA256 hash format: {expected_hash}")
+
+        # Get the S3 stream
+        stream, content_length, content_range = self.get_stream(s3_key, range_header)
+
+        # For range requests, we cannot verify (partial content)
+        # Return a HashingStreamWrapper that just tracks bytes without verification
+        if range_header or content_range:
+            logger.debug(
+                f"Range request for {s3_key} - verification disabled (partial content)"
+            )
+            # Return a basic hashing wrapper (caller should not verify)
+            hashing_wrapper = HashingStreamWrapper(stream)
+            return hashing_wrapper, content_length, content_range
+
+        # Create verifying wrapper
+        verifying_wrapper = VerifyingStreamWrapper(
+            stream=stream,
+            expected_hash=expected_hash,
+            s3_key=s3_key,
+        )
+
+        return verifying_wrapper, content_length, content_range
+
     def verify_integrity(self, s3_key: str, expected_sha256: str) -> bool:
         """
         Verify the integrity of a stored object by downloading and re-hashing.
diff --git a/backend/tests/test_checksum_verification.py b/backend/tests/test_checksum_verification.py
new file mode 100644
index 0000000..049508b
--- /dev/null
+++ b/backend/tests/test_checksum_verification.py
@@ -0,0 +1,675 @@
+"""
+Tests for checksum calculation, verification, and download verification.
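+
+Run with (assuming pytest is configured at the repo root):
+
+    pytest backend/tests/test_checksum_verification.py -v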
+
+This module tests:
+- SHA256 hash computation (bytes and streams)
+- HashingStreamWrapper incremental hashing
+- VerifyingStreamWrapper with verification
+- ChecksumMismatchError exception handling
+- Download verification API endpoints
+"""
+
+import pytest
+import hashlib
+import io
+from typing import Generator
+
+from app.checksum import (
+    compute_sha256,
+    compute_sha256_stream,
+    verify_checksum,
+    verify_checksum_strict,
+    is_valid_sha256,
+    sha256_to_base64,
+    HashingStreamWrapper,
+    VerifyingStreamWrapper,
+    ChecksumMismatchError,
+    ChecksumError,
+    InvalidHashFormatError,
+    DEFAULT_CHUNK_SIZE,
+)
+
+
+# =============================================================================
+# Test Data
+# =============================================================================
+
+# Known test vectors
+TEST_CONTENT_HELLO = b"Hello, World!"
+TEST_HASH_HELLO = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
+
+TEST_CONTENT_EMPTY = b""
+TEST_HASH_EMPTY = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
+
+TEST_CONTENT_BINARY = bytes(range(256))
+TEST_HASH_BINARY = hashlib.sha256(TEST_CONTENT_BINARY).hexdigest()
+
+# Invalid hashes for testing
+INVALID_HASH_TOO_SHORT = "abcd1234"
+INVALID_HASH_TOO_LONG = "a" * 65
+INVALID_HASH_NON_HEX = "zzzz" + "a" * 60
+INVALID_HASH_EMPTY = ""
+
+
+# =============================================================================
+# Unit Tests - SHA256 Computation
+# =============================================================================
+
+
+class TestComputeSHA256:
+    """Tests for compute_sha256 function."""
+
+    def test_known_content_matches_expected_hash(self):
+        """Test SHA256 of known content matches pre-computed hash."""
+        result = compute_sha256(TEST_CONTENT_HELLO)
+        assert result == TEST_HASH_HELLO
+
+    def test_returns_64_character_hex_string(self):
+        """Test result is exactly 64 hex characters."""
+        result = compute_sha256(TEST_CONTENT_HELLO)
+        assert len(result) == 64
+        assert all(c in "0123456789abcdef" for c in result)
+
+    def test_returns_lowercase_hex(self):
+        """Test result is lowercase."""
+        result = compute_sha256(TEST_CONTENT_HELLO)
+        assert result == result.lower()
+
+    def test_empty_content_returns_empty_hash(self):
+        """Test empty bytes returns SHA256 of empty content."""
+        result = compute_sha256(TEST_CONTENT_EMPTY)
+        assert result == TEST_HASH_EMPTY
+
+    def test_deterministic_same_input_same_output(self):
+        """Test same input always produces same output."""
+        content = b"test content for determinism"
+        result1 = compute_sha256(content)
+        result2 = compute_sha256(content)
+        assert result1 == result2
+
+    def test_different_content_different_hash(self):
+        """Test different content produces different hash."""
+        hash1 = compute_sha256(b"content A")
+        hash2 = compute_sha256(b"content B")
+        assert hash1 != hash2
+
+    def test_single_bit_change_different_hash(self):
+        """Test single bit change produces completely different hash."""
+        content1 = b"\x00" * 100
+        content2 = b"\x00" * 99 + b"\x01"
+        hash1 = compute_sha256(content1)
+        hash2 = compute_sha256(content2)
+        assert hash1 != hash2
+
+    def test_binary_content(self):
+        """Test hashing binary content with all byte values."""
+        result = compute_sha256(TEST_CONTENT_BINARY)
+        assert result == TEST_HASH_BINARY
+        assert len(result) == 64
+
+    def test_large_content(self):
+        """Test hashing larger content (1MB)."""
+        large_content = b"x" * (1024 * 1024)
+        result = compute_sha256(large_content)
+        expected = hashlib.sha256(large_content).hexdigest()
+        assert result == expected
+
+    def test_none_content_raises_error(self):
+        """Test None content raises ChecksumError."""
+        with pytest.raises(ChecksumError, match="Cannot compute hash of None"):
+            compute_sha256(None)
+
+
+class TestComputeSHA256Stream:
+    """Tests for compute_sha256_stream function."""
+
+    def test_file_like_object(self):
+        """Test hashing from file-like object."""
+        file_obj = io.BytesIO(TEST_CONTENT_HELLO)
+        result = compute_sha256_stream(file_obj)
+        assert result == TEST_HASH_HELLO
+
+    def test_iterator(self):
+        """Test hashing from iterator of chunks."""
+
+        def chunk_iterator():
+            yield b"Hello, "
+            yield b"World!"
+
+        result = compute_sha256_stream(chunk_iterator())
+        assert result == TEST_HASH_HELLO
+
+    def test_various_chunk_sizes_same_result(self):
+        """Test different chunk sizes produce same hash."""
+        content = b"x" * 10000
+        expected = hashlib.sha256(content).hexdigest()
+
+        for chunk_size in [1, 10, 100, 1000, 8192]:
+            file_obj = io.BytesIO(content)
+            result = compute_sha256_stream(file_obj, chunk_size=chunk_size)
+            assert result == expected, f"Failed for chunk_size={chunk_size}"
+
+    def test_single_byte_chunks(self):
+        """Test with 1-byte chunks (edge case)."""
+        content = b"ABC"
+        file_obj = io.BytesIO(content)
+        result = compute_sha256_stream(file_obj, chunk_size=1)
+        expected = hashlib.sha256(content).hexdigest()
+        assert result == expected
+
+    def test_empty_stream(self):
+        """Test empty stream returns empty content hash."""
+        file_obj = io.BytesIO(b"")
+        result = compute_sha256_stream(file_obj)
+        assert result == TEST_HASH_EMPTY
+
+
+# =============================================================================
+# Unit Tests - Hash Validation
+# =============================================================================
+
+
+class TestIsValidSHA256:
+    """Tests for is_valid_sha256 function."""
+
+    def test_valid_hash_lowercase(self):
+        """Test valid lowercase hash."""
+        assert is_valid_sha256(TEST_HASH_HELLO) is True
+
+    def test_valid_hash_uppercase(self):
+        """Test valid uppercase hash."""
+        assert is_valid_sha256(TEST_HASH_HELLO.upper()) is True
+
+    def test_valid_hash_mixed_case(self):
+        """Test valid mixed case hash."""
+        mixed = TEST_HASH_HELLO[:32].upper() + TEST_HASH_HELLO[32:].lower()
+        assert is_valid_sha256(mixed) is True
+
+    def test_invalid_too_short(self):
+        """Test hash that's too short."""
+        assert is_valid_sha256(INVALID_HASH_TOO_SHORT) is False
+
+    def test_invalid_too_long(self):
+        """Test hash that's too long."""
+        assert is_valid_sha256(INVALID_HASH_TOO_LONG) is False
+
+    def test_invalid_non_hex(self):
+        """Test hash with non-hex characters."""
+        assert is_valid_sha256(INVALID_HASH_NON_HEX) is False
+
+    def test_invalid_empty(self):
+        """Test empty string."""
+        assert is_valid_sha256(INVALID_HASH_EMPTY) is False
+
+    def test_invalid_none(self):
+        """Test None value."""
+        assert is_valid_sha256(None) is False
+
+
+class TestSHA256ToBase64:
+    """Tests for sha256_to_base64 function."""
+
+    def test_converts_to_base64(self):
+        """Test conversion to base64."""
+        result = sha256_to_base64(TEST_HASH_HELLO)
+        # Verify it's valid base64
+        import base64
+
+        decoded = base64.b64decode(result)
+        assert len(decoded) == 32  # SHA256 is 32 bytes
+
+    def test_invalid_hash_raises_error(self):
+        """Test invalid hash raises InvalidHashFormatError."""
+        with pytest.raises(InvalidHashFormatError):
+            sha256_to_base64(INVALID_HASH_TOO_SHORT)
+
+
+# =============================================================================
+# Unit Tests - Verification Functions
+# =============================================================================
+
+
+class TestVerifyChecksum:
+    """Tests for verify_checksum function."""
+
+    def test_matching_checksum_returns_true(self):
+        """Test matching checksum returns True."""
+        result = verify_checksum(TEST_CONTENT_HELLO, TEST_HASH_HELLO)
+        assert result is True
+
+    def test_mismatched_checksum_returns_false(self):
+        """Test mismatched checksum returns False."""
+        wrong_hash = "a" * 64
+        result = verify_checksum(TEST_CONTENT_HELLO, wrong_hash)
+        assert result is False
+
+    def test_case_insensitive_comparison(self):
+        """Test comparison is case-insensitive."""
+        result = verify_checksum(TEST_CONTENT_HELLO, TEST_HASH_HELLO.upper())
+        assert result is True
+
+    def test_invalid_hash_format_raises_error(self):
+        """Test invalid hash format raises error."""
+        with pytest.raises(InvalidHashFormatError):
+            verify_checksum(TEST_CONTENT_HELLO, INVALID_HASH_TOO_SHORT)
+
+
+class TestVerifyChecksumStrict:
+    """Tests for verify_checksum_strict function."""
+
+    def test_matching_checksum_returns_none(self):
+        """Test matching checksum doesn't raise."""
+        # Should not raise
+        verify_checksum_strict(TEST_CONTENT_HELLO, TEST_HASH_HELLO)
+
+    def test_mismatched_checksum_raises_error(self):
+        """Test mismatched checksum raises ChecksumMismatchError."""
+        wrong_hash = "a" * 64
+        with pytest.raises(ChecksumMismatchError) as exc_info:
+            verify_checksum_strict(TEST_CONTENT_HELLO, wrong_hash)
+
+        error = exc_info.value
+        assert error.expected == wrong_hash.lower()
+        assert error.actual == TEST_HASH_HELLO
+        assert error.size == len(TEST_CONTENT_HELLO)
+
+    def test_error_includes_context(self):
+        """Test error includes artifact_id and s3_key context."""
+        wrong_hash = "a" * 64
+        with pytest.raises(ChecksumMismatchError) as exc_info:
+            verify_checksum_strict(
+                TEST_CONTENT_HELLO,
+                wrong_hash,
+                artifact_id="test-artifact-123",
+                s3_key="fruits/ab/cd/abcd1234...",
+            )
+
+        error = exc_info.value
+        assert error.artifact_id == "test-artifact-123"
+        assert error.s3_key == "fruits/ab/cd/abcd1234..."
+
+
+# =============================================================================
+# Unit Tests - HashingStreamWrapper
+# =============================================================================
+
+
+class TestHashingStreamWrapper:
+    """Tests for HashingStreamWrapper class."""
+
+    def test_computes_correct_hash(self):
+        """Test wrapper computes correct hash."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = HashingStreamWrapper(stream)
+
+        # Consume the stream
+        chunks = list(wrapper)
+
+        # Verify hash
+        assert wrapper.get_hash() == TEST_HASH_HELLO
+
+    def test_yields_correct_chunks(self):
+        """Test wrapper yields all content."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = HashingStreamWrapper(stream)
+
+        chunks = list(wrapper)
+        content = b"".join(chunks)
+
+        assert content == TEST_CONTENT_HELLO
+
+    def test_tracks_bytes_read(self):
+        """Test bytes_read property tracks correctly."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = HashingStreamWrapper(stream)
+
+        assert wrapper.bytes_read == 0
+        list(wrapper)  # Consume
+        assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)
+
+    def test_get_hash_before_iteration_consumes_stream(self):
+        """Test get_hash() consumes stream if not already done."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = HashingStreamWrapper(stream)
+
+        # Call get_hash without iterating
+        hash_result = wrapper.get_hash()
+
+        assert hash_result == TEST_HASH_HELLO
+        assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)
+
+    def test_get_hash_if_complete_before_iteration_returns_none(self):
+        """Test get_hash_if_complete returns None before iteration."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = HashingStreamWrapper(stream)
+
+        assert wrapper.get_hash_if_complete() is None
+
+    def test_get_hash_if_complete_after_iteration_returns_hash(self):
+        """Test get_hash_if_complete returns hash after iteration."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = HashingStreamWrapper(stream)
+
+        list(wrapper)  # Consume
+        assert wrapper.get_hash_if_complete() == TEST_HASH_HELLO
+
+    def test_custom_chunk_size(self):
+        """Test custom chunk size is respected."""
+        content = b"x" * 1000
+        stream = io.BytesIO(content)
+        wrapper = HashingStreamWrapper(stream, chunk_size=100)
+
+        chunks = list(wrapper)
+
+        # Each chunk should be at most 100 bytes
+        for chunk in chunks[:-1]:  # All but last
+            assert len(chunk) == 100
+
+        # Total content should match
+        assert b"".join(chunks) == content
+
+    def test_iterator_interface(self):
+        """Test wrapper supports iterator interface."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = HashingStreamWrapper(stream)
+
+        # Should be able to use for loop
+        result = b""
+        for chunk in wrapper:
+            result += chunk
+
+        assert result == TEST_CONTENT_HELLO
+
+
+# =============================================================================
+# Unit Tests - VerifyingStreamWrapper
+# =============================================================================
+
+
+class TestVerifyingStreamWrapper:
+    """Tests for VerifyingStreamWrapper class."""
+
+    def test_verify_success(self):
+        """Test verification succeeds for matching content."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)
+
+        # Consume stream
+        list(wrapper)
+
+        # Verify should succeed
+        result = wrapper.verify()
+        assert result is True
+        assert wrapper.is_verified is True
+
+    def test_verify_failure_raises_error(self):
+        """Test verification failure raises ChecksumMismatchError."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrong_hash = "a" * 64
+        wrapper = VerifyingStreamWrapper(stream, wrong_hash)
+
+        # Consume stream
+        list(wrapper)
+
+        # Verify should fail
+        with pytest.raises(ChecksumMismatchError):
+            wrapper.verify()
+
+        assert wrapper.is_verified is False
+
+    def test_verify_silent_success(self):
+        """Test verify_silent returns True on success."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)
+
+        list(wrapper)
+
+        result = wrapper.verify_silent()
+        assert result is True
+
+    def test_verify_silent_failure(self):
+        """Test verify_silent returns False on failure."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrong_hash = "a" * 64
+        wrapper = VerifyingStreamWrapper(stream, wrong_hash)
+
+        list(wrapper)
+
+        result = wrapper.verify_silent()
+        assert result is False
+
+    def test_invalid_expected_hash_raises_error(self):
+        """Test invalid expected hash raises error at construction."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+
+        with pytest.raises(InvalidHashFormatError):
+            VerifyingStreamWrapper(stream, INVALID_HASH_TOO_SHORT)
+
+    def test_on_failure_callback(self):
+        """Test on_failure callback is called on verification failure."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrong_hash = "a" * 64
+
+        callback_called = []
+
+        def on_failure(error):
+            callback_called.append(error)
+
+        wrapper = VerifyingStreamWrapper(stream, wrong_hash, on_failure=on_failure)
+
+        list(wrapper)
+
+        with pytest.raises(ChecksumMismatchError):
+            wrapper.verify()
+
+        assert len(callback_called) == 1
+        assert isinstance(callback_called[0], ChecksumMismatchError)
+
+    def test_get_actual_hash_after_iteration(self):
+        """Test get_actual_hash returns hash after iteration."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)
+
+        # Before iteration
+        assert wrapper.get_actual_hash() is None
+
+        list(wrapper)
+
+        # After iteration
+        assert wrapper.get_actual_hash() == TEST_HASH_HELLO
+
+    def test_includes_context_in_error(self):
+        """Test error includes artifact_id and s3_key."""
+        stream = io.BytesIO(TEST_CONTENT_HELLO)
+        wrong_hash = "a" * 64
+        wrapper = VerifyingStreamWrapper(
+            stream,
+            wrong_hash,
+            artifact_id="test-artifact",
+            s3_key="test/key",
+        )
+
+        list(wrapper)
+
+        with pytest.raises(ChecksumMismatchError) as exc_info:
+            wrapper.verify()
+
+        error = exc_info.value
+        assert error.artifact_id == "test-artifact"
+        assert error.s3_key == "test/key"
+
+
+# =============================================================================
+# Unit Tests - ChecksumMismatchError
+# =============================================================================
+
+
+class TestChecksumMismatchError:
+    """Tests for ChecksumMismatchError class."""
+
+    def test_to_dict(self):
+        """Test to_dict returns proper dictionary."""
+        error = ChecksumMismatchError(
+            expected="a" * 64,
+            actual="b" * 64,
+            artifact_id="test-123",
+            s3_key="test/key",
+            size=1024,
+        )
+
+        result = error.to_dict()
+
+        assert result["error"] == "checksum_mismatch"
+        assert result["expected"] == "a" * 64
+        assert result["actual"] == "b" * 64
+        assert result["artifact_id"] == "test-123"
+        assert result["s3_key"] == "test/key"
+        assert result["size"] == 1024
+
+    def test_message_format(self):
+        """Test error message format."""
+        error = ChecksumMismatchError(
+            expected="a" * 64,
+            actual="b" * 64,
+        )
+
+        assert "verification failed" in str(error).lower()
+        assert "expected" in str(error).lower()
+
+    def test_custom_message(self):
"""Test custom message is used.""" + error = ChecksumMismatchError( + expected="a" * 64, + actual="b" * 64, + message="Custom error message", + ) + + assert str(error) == "Custom error message" + + +# ============================================================================= +# Corruption Simulation Tests +# ============================================================================= + + +class TestCorruptionDetection: + """Tests for detecting corrupted content.""" + + def test_detect_truncated_content(self): + """Test detection of truncated content.""" + original = TEST_CONTENT_HELLO + truncated = original[:-1] # Remove last byte + + original_hash = compute_sha256(original) + truncated_hash = compute_sha256(truncated) + + assert original_hash != truncated_hash + assert verify_checksum(truncated, original_hash) is False + + def test_detect_extra_bytes(self): + """Test detection of content with extra bytes.""" + original = TEST_CONTENT_HELLO + extended = original + b"\x00" # Add null byte + + original_hash = compute_sha256(original) + + assert verify_checksum(extended, original_hash) is False + + def test_detect_single_bit_flip(self): + """Test detection of single bit flip.""" + original = TEST_CONTENT_HELLO + # Flip first bit of first byte + corrupted = bytes([original[0] ^ 0x01]) + original[1:] + + original_hash = compute_sha256(original) + + assert verify_checksum(corrupted, original_hash) is False + + def test_detect_wrong_content(self): + """Test detection of completely different content.""" + original = TEST_CONTENT_HELLO + different = b"Something completely different" + + original_hash = compute_sha256(original) + + assert verify_checksum(different, original_hash) is False + + def test_detect_empty_vs_nonempty(self): + """Test detection of empty content vs non-empty.""" + original = TEST_CONTENT_HELLO + empty = b"" + + original_hash = compute_sha256(original) + + assert verify_checksum(empty, original_hash) is False + + def test_streaming_detection_of_corruption(self): + """Test VerifyingStreamWrapper detects corruption.""" + original = b"Original content that will be corrupted" + original_hash = compute_sha256(original) + + # Corrupt the content + corrupted = b"Corrupted content that is different" + stream = io.BytesIO(corrupted) + + wrapper = VerifyingStreamWrapper(stream, original_hash) + list(wrapper) # Consume + + with pytest.raises(ChecksumMismatchError): + wrapper.verify() + + +# ============================================================================= +# Edge Case Tests +# ============================================================================= + + +class TestEdgeCases: + """Tests for edge cases and boundary conditions.""" + + def test_null_bytes_in_content(self): + """Test content with null bytes.""" + content = b"\x00\x00\x00" + hash_result = compute_sha256(content) + + assert verify_checksum(content, hash_result) is True + + def test_whitespace_only_content(self): + """Test content with only whitespace.""" + content = b" \t\n\r " + hash_result = compute_sha256(content) + + assert verify_checksum(content, hash_result) is True + + def test_large_content_streaming(self): + """Test streaming verification of large content.""" + # 1MB of content + large_content = b"x" * (1024 * 1024) + expected_hash = compute_sha256(large_content) + + stream = io.BytesIO(large_content) + wrapper = VerifyingStreamWrapper(stream, expected_hash) + + # Consume and verify + chunks = list(wrapper) + assert wrapper.verify() is True + assert b"".join(chunks) == large_content + + def 
+        """Test content with unicode bytes."""
+        content = "Hello, δΈ–η•Œ! 🌍".encode("utf-8")
+        hash_result = compute_sha256(content)
+
+        assert verify_checksum(content, hash_result) is True
+
+    def test_maximum_chunk_size_larger_than_content(self):
+        """Test chunk size larger than content."""
+        content = b"small"
+        stream = io.BytesIO(content)
+        wrapper = HashingStreamWrapper(stream, chunk_size=1024 * 1024)
+
+        chunks = list(wrapper)
+
+        assert len(chunks) == 1
+        assert chunks[0] == content
+        assert wrapper.get_hash() == compute_sha256(content)
diff --git a/backend/tests/test_download_verification.py b/backend/tests/test_download_verification.py
new file mode 100644
index 0000000..ddec899
--- /dev/null
+++ b/backend/tests/test_download_verification.py
@@ -0,0 +1,460 @@
+"""
+Integration tests for download verification API endpoints.
+
+These tests verify:
+- Checksum headers in download responses
+- Pre-verification mode
+- Streaming verification mode
+- HEAD request headers
+- Verification failure handling
+"""
+
+import pytest
+import hashlib
+import base64
+import io
+
+
+# =============================================================================
+# Test Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def upload_test_file(integration_client):
+    """
+    Factory fixture to upload a test file and return its artifact ID.
+
+    Usage:
+        artifact_id = upload_test_file(project, package, content, tag="v1.0")
+    """
+
+    def _upload(project_name: str, package_name: str, content: bytes, tag: str = None):
+        files = {
+            "file": ("test-file.bin", io.BytesIO(content), "application/octet-stream")
+        }
+        data = {}
+        if tag:
+            data["tag"] = tag
+
+        response = integration_client.post(
+            f"/api/v1/project/{project_name}/{package_name}/upload",
+            files=files,
+            data=data,
+        )
+        assert response.status_code == 200, f"Upload failed: {response.text}"
+        return response.json()["artifact_id"]
+
+    return _upload
+
+
+# =============================================================================
+# Integration Tests - Download Headers
+# =============================================================================
+
+
+class TestDownloadChecksumHeaders:
+    """Tests for checksum headers in download responses."""
+
+    @pytest.mark.integration
+    def test_download_includes_sha256_header(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test download response includes X-Checksum-SHA256 header."""
+        project_name, package_name = test_package
+        content = b"Content for SHA256 header test"
+
+        # Upload file
+        artifact_id = upload_test_file(
+            project_name, package_name, content, tag="sha256-header-test"
+        )
+
+        # Download with proxy mode
+        response = integration_client.get(
+            f"/api/v1/project/{project_name}/{package_name}/+/sha256-header-test",
+            params={"mode": "proxy"},
+        )
+
+        assert response.status_code == 200
+        assert "X-Checksum-SHA256" in response.headers
+        assert response.headers["X-Checksum-SHA256"] == artifact_id
+
+    @pytest.mark.integration
+    def test_download_includes_etag_header(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test download response includes ETag header."""
+        project_name, package_name = test_package
+        content = b"Content for ETag header test"
+
+        artifact_id = upload_test_file(
+            project_name, package_name, content, tag="etag-test"
+        )
+
+        response = integration_client.get(
+            f"/api/v1/project/{project_name}/{package_name}/+/etag-test",
+            params={"mode": "proxy"},
+
+
+# =============================================================================
+# Integration Tests - Pre-Verification Mode
+# =============================================================================
+
+
+class TestPreVerificationMode:
+    """Tests for pre-verification download mode."""
+
+    @pytest.mark.integration
+    def test_pre_verify_success(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test pre-verification mode succeeds for valid content."""
+        project_name, package_name = test_package
+        content = b"Content for pre-verification success test"
+
+        upload_test_file(project_name, package_name, content, tag="pre-verify-success")
+
+        response = integration_client.get(
+            f"/api/v1/project/{project_name}/{package_name}/+/pre-verify-success",
+            params={"mode": "proxy", "verify": "true", "verify_mode": "pre"},
+        )
+
+        assert response.status_code == 200
+        assert response.content == content
+        assert "X-Verified" in response.headers
+        assert response.headers["X-Verified"] == "true"
+
+    @pytest.mark.integration
+    def test_pre_verify_returns_complete_content(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test pre-verification returns complete content."""
+        project_name, package_name = test_package
+        # Use binary content to verify no corruption
+        content = bytes(range(256)) * 10  # 2560 bytes of all byte values
+
+        upload_test_file(project_name, package_name, content, tag="pre-verify-content")
+
+        response = integration_client.get(
+            f"/api/v1/project/{project_name}/{package_name}/+/pre-verify-content",
+            params={"mode": "proxy", "verify": "true", "verify_mode": "pre"},
+        )
+
+        assert response.status_code == 200
+        assert response.content == content
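+
+# NOTE (editor's summary of the two modes, per the CHANGELOG): "pre" hashes
+# the stored object before the first byte is sent, so a mismatch can still be
+# turned into an error response; "stream" hashes while sending, so a mismatch
+# detected at the end can only be logged and reflected in X-Verified, because
+# the 200 status line has already been transmitted.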
+
+
+# =============================================================================
+# Integration Tests - Streaming Verification Mode
+# =============================================================================
+
+
+class TestStreamingVerificationMode:
+    """Tests for streaming verification download mode."""
+
+    @pytest.mark.integration
+    def test_stream_verify_success(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test streaming verification mode succeeds for valid content."""
+        project_name, package_name = test_package
+        content = b"Content for streaming verification success test"
+
+        upload_test_file(
+            project_name, package_name, content, tag="stream-verify-success"
+        )
+
+        response = integration_client.get(
+            f"/api/v1/project/{project_name}/{package_name}/+/stream-verify-success",
+            params={"mode": "proxy", "verify": "true", "verify_mode": "stream"},
+        )
+
+        assert response.status_code == 200
+        assert response.content == content
+        # X-Verified is "pending" for streaming mode (verified after transfer)
+        assert "X-Verified" in response.headers
+
+    @pytest.mark.integration
+    def test_stream_verify_large_content(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test streaming verification with larger content."""
+        project_name, package_name = test_package
+        # 100KB of content
+        content = b"x" * (100 * 1024)
+
+        upload_test_file(project_name, package_name, content, tag="stream-verify-large")
+
+        response = integration_client.get(
+            f"/api/v1/project/{project_name}/{package_name}/+/stream-verify-large",
+            params={"mode": "proxy", "verify": "true", "verify_mode": "stream"},
+        )
+
+        assert response.status_code == 200
+        assert response.content == content
+
+
+# =============================================================================
+# Integration Tests - HEAD Request Headers
+# =============================================================================
+
+
+class TestHeadRequestHeaders:
+    """Tests for HEAD request checksum headers."""
+
+    @pytest.mark.integration
+    def test_head_includes_sha256_header(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test HEAD request includes X-Checksum-SHA256 header."""
+        project_name, package_name = test_package
+        content = b"Content for HEAD SHA256 test"
+
+        artifact_id = upload_test_file(
+            project_name, package_name, content, tag="head-sha256-test"
+        )
+
+        response = integration_client.head(
+            f"/api/v1/project/{project_name}/{package_name}/+/head-sha256-test"
+        )
+
+        assert response.status_code == 200
+        assert "X-Checksum-SHA256" in response.headers
+        assert response.headers["X-Checksum-SHA256"] == artifact_id
+
+    @pytest.mark.integration
+    def test_head_includes_etag(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test HEAD request includes ETag header."""
+        project_name, package_name = test_package
+        content = b"Content for HEAD ETag test"
+
+        artifact_id = upload_test_file(
+            project_name, package_name, content, tag="head-etag-test"
+        )
+
+        response = integration_client.head(
+            f"/api/v1/project/{project_name}/{package_name}/+/head-etag-test"
+        )
+
+        assert response.status_code == 200
+        assert "ETag" in response.headers
+        assert response.headers["ETag"] == f'"{artifact_id}"'
+
+    @pytest.mark.integration
+    def test_head_includes_digest(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test HEAD request includes Digest header."""
+        project_name, package_name = test_package
+        content = b"Content for HEAD Digest test"
+
+        upload_test_file(project_name, package_name, content, tag="head-digest-test")
+
+        response = integration_client.head(
+            f"/api/v1/project/{project_name}/{package_name}/+/head-digest-test"
+        )
+
+        assert response.status_code == 200
+        assert "Digest" in response.headers
+        assert response.headers["Digest"].startswith("sha-256=")
+
+    @pytest.mark.integration
+    def test_head_includes_content_length(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test HEAD request includes X-Content-Length header."""
+        project_name, package_name = test_package
+        content = b"Content for HEAD Content-Length test"
+
+        upload_test_file(project_name, package_name, content, tag="head-length-test")
+
+        response = integration_client.head(
+            f"/api/v1/project/{project_name}/{package_name}/+/head-length-test"
+        )
+
+        assert response.status_code == 200
+        assert "X-Content-Length" in response.headers
+        assert response.headers["X-Content-Length"] == str(len(content))
+
+    @pytest.mark.integration
+    def test_head_no_body(self, integration_client, test_package, upload_test_file):
+        """Test HEAD request returns no body."""
+        project_name, package_name = test_package
+        content = b"Content for HEAD no-body test"
+
+        upload_test_file(project_name, package_name, content, tag="head-no-body-test")
+
+        response = integration_client.head(
+            f"/api/v1/project/{project_name}/{package_name}/+/head-no-body-test"
+        )
+
+        assert response.status_code == 200
+        assert response.content == b""
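+
+# Illustrative client flow (an editor's sketch, not exercised by this suite):
+# because HEAD exposes the same checksum headers as GET, a client can fetch
+# the expected hash cheaply first and verify the body afterwards, e.g.:
+#
+#     head = client.head(url)
+#     expected = head.headers["X-Checksum-SHA256"]
+#     body = client.get(url, params={"mode": "proxy"}).content
+#     assert hashlib.sha256(body).hexdigest() == expected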
+
+
+# =============================================================================
+# Integration Tests - Range Requests
+# =============================================================================
+
+
+class TestRangeRequestHeaders:
+    """Tests for range request handling with checksum headers."""
+
+    @pytest.mark.integration
+    def test_range_request_includes_checksum_headers(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test range request includes checksum headers."""
+        project_name, package_name = test_package
+        content = b"Content for range request checksum header test"
+
+        upload_test_file(project_name, package_name, content, tag="range-checksum-test")
+
+        response = integration_client.get(
+            f"/api/v1/project/{project_name}/{package_name}/+/range-checksum-test",
+            headers={"Range": "bytes=0-9"},
+            params={"mode": "proxy"},
+        )
+
+        assert response.status_code == 206
+        assert "X-Checksum-SHA256" in response.headers
+        # Checksum is for the FULL file, not the range
+        assert len(response.headers["X-Checksum-SHA256"]) == 64
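+
+# NOTE: on a 206 the checksum headers still describe the complete artifact,
+# so a client downloading by ranges must reassemble the full byte sequence
+# before comparing against X-Checksum-SHA256.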
+
+
+# =============================================================================
+# Integration Tests - Client-Side Verification
+# =============================================================================
+
+
+class TestClientSideVerification:
+    """Tests demonstrating client-side verification using headers."""
+
+    @pytest.mark.integration
+    def test_client_can_verify_downloaded_content(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test client can verify downloaded content using header."""
+        project_name, package_name = test_package
+        content = b"Content for client-side verification test"
+
+        upload_test_file(project_name, package_name, content, tag="client-verify-test")
+
+        response = integration_client.get(
+            f"/api/v1/project/{project_name}/{package_name}/+/client-verify-test",
+            params={"mode": "proxy"},
+        )
+
+        assert response.status_code == 200
+
+        # Get expected hash from header
+        expected_hash = response.headers["X-Checksum-SHA256"]
+
+        # Compute actual hash of downloaded content
+        actual_hash = hashlib.sha256(response.content).hexdigest()
+
+        # Verify match
+        assert actual_hash == expected_hash
+
+    @pytest.mark.integration
+    def test_client_can_verify_using_digest_header(
+        self, integration_client, test_package, upload_test_file
+    ):
+        """Test client can verify using RFC 3230 Digest header."""
+        project_name, package_name = test_package
+        content = b"Content for Digest header verification"
+
+        upload_test_file(project_name, package_name, content, tag="digest-verify-test")
+
+        response = integration_client.get(
+            f"/api/v1/project/{project_name}/{package_name}/+/digest-verify-test",
+            params={"mode": "proxy"},
+        )
+
+        assert response.status_code == 200
+
+        # Parse Digest header
+        digest_header = response.headers["Digest"]
+        assert digest_header.startswith("sha-256=")
+        b64_hash = digest_header.split("=", 1)[1]
+        expected_hash_bytes = base64.b64decode(b64_hash)
+
+        # Compute actual hash of downloaded content
+        actual_hash_bytes = hashlib.sha256(response.content).digest()
+
+        # Verify match
+        assert actual_hash_bytes == expected_hash_bytes
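+
+
+# =============================================================================
+# Reference sketch - end-to-end client verification
+# =============================================================================
+
+
+def _download_and_verify(client, url: str) -> bytes:
+    """
+    Illustrative helper (an editor's sketch, not exercised by this suite;
+    assumes a requests/httpx-style client such as integration_client):
+    download via proxy mode and fail loudly on checksum mismatch, mirroring
+    what the two tests above assert step by step.
+    """
+    response = client.get(url, params={"mode": "proxy"})
+    assert response.status_code == 200
+    expected = response.headers["X-Checksum-SHA256"]
+    actual = hashlib.sha256(response.content).hexdigest()
+    if actual != expected:
+        raise AssertionError(f"checksum mismatch: expected {expected}, got {actual}")
+    return response.content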