Merge branch 'feature/download-verification' into 'main'

Add download verification with SHA256 checksum support (#26, #27, #28, #29)

See merge request esv/bsf/bsf-integration/orchard/orchard-mvp!22
Mondo Diaz · 2026-01-07 13:36:46 -06:00
8 changed files with 2157 additions and 12 deletions


@@ -7,6 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- Added download verification with `verify` and `verify_mode` query parameters (#26); a client usage sketch follows this list
  - `?verify=true&verify_mode=pre` - Pre-verification: verify before streaming (guarantees no corrupt data is sent)
  - `?verify=true&verify_mode=stream` - Streaming verification: verify while streaming (logs an error on mismatch)
- Added checksum response headers to all download endpoints (#27)
  - `X-Checksum-SHA256` - SHA256 hash of the artifact
  - `X-Content-Length` - File size in bytes
  - `X-Checksum-MD5` - MD5 hash (if available)
  - `ETag` - Artifact ID (SHA256)
  - `Digest` - RFC 3230 `sha-256` digest (base64)
  - `X-Verified` - Verification status (true/false/pending)
- Added `checksum.py` module with SHA256 utilities (#26)
  - `compute_sha256()` and `compute_sha256_stream()` functions
  - `HashingStreamWrapper` for incremental hash computation
  - `VerifyingStreamWrapper` for stream verification
  - `verify_checksum()` and `verify_checksum_strict()` functions
  - `ChecksumMismatchError` exception with context
- Added `get_verified()` and `get_stream_verified()` methods to the storage layer (#26)
- Added `logging_config.py` module with structured logging (#28)
  - JSON logging format for production
  - Request ID tracking via context variables
  - Verification failure logging with full context
- Added `log_level` and `log_format` settings to configuration (#28)
- Added 62 unit tests for checksum utilities and verification (#29)
- Added 17 integration tests for the download verification API (#29)
- Added global artifacts endpoint `GET /api/v1/artifacts` with project/package/tag/size/date filters (#18)
- Added global tags endpoint `GET /api/v1/tags` with project/package/search/date filters (#18)
- Added wildcard pattern matching (`*`) for tag filters across all endpoints (#18)
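A minimal client-side sketch of the new parameters and headers (not part of this merge request; the base URL, project, package, and tag names are placeholders):

```python
# Hedged sketch: base URL, project, package, and tag are placeholders; the
# query parameters and X-Checksum-SHA256 header are the ones documented above.
import hashlib
import requests

BASE = "http://localhost:8000"  # assumed local Orchard instance

resp = requests.get(
    f"{BASE}/api/v1/project/myproj/mypkg/+/v1.0",
    params={"verify": "true", "verify_mode": "pre"},
)
resp.raise_for_status()

# Pre-verification mode: the server has already checked the hash, but a
# cautious client can still re-hash the body against the response header.
expected = resp.headers["X-Checksum-SHA256"]
actual = hashlib.sha256(resp.content).hexdigest()
assert actual == expected, f"checksum mismatch: {actual} != {expected}"
```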

backend/app/checksum.py (new file)

@@ -0,0 +1,477 @@
"""
Checksum utilities for download verification.
This module provides functions and classes for computing and verifying
SHA256 checksums during artifact downloads.
Key components:
- compute_sha256(): Compute SHA256 of bytes content
- compute_sha256_stream(): Compute SHA256 from an iterable stream
- HashingStreamWrapper: Wrapper that computes hash while streaming
- VerifyingStreamWrapper: Wrapper that verifies hash after streaming
- verify_checksum(): Verify content against expected hash
- ChecksumMismatchError: Exception for verification failures
"""
import hashlib
import logging
import re
import base64
from typing import (
Generator,
Optional,
Any,
Callable,
)
logger = logging.getLogger(__name__)
# Default chunk size for streaming operations (8KB)
DEFAULT_CHUNK_SIZE = 8 * 1024
# Regex pattern for valid SHA256 hash (64 hex characters)
SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}$")
class ChecksumError(Exception):
"""Base exception for checksum operations."""
pass
class ChecksumMismatchError(ChecksumError):
"""
Raised when computed checksum does not match expected checksum.
Attributes:
expected: The expected SHA256 hash
actual: The actual computed SHA256 hash
artifact_id: Optional artifact ID for context
s3_key: Optional S3 key for debugging
size: Optional file size
"""
def __init__(
self,
expected: str,
actual: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
size: Optional[int] = None,
message: Optional[str] = None,
):
self.expected = expected
self.actual = actual
self.artifact_id = artifact_id
self.s3_key = s3_key
self.size = size
if message:
self.message = message
else:
self.message = (
f"Checksum verification failed: "
f"expected {expected[:16]}..., got {actual[:16]}..."
)
super().__init__(self.message)
def to_dict(self) -> dict:
"""Convert to dictionary for logging/API responses."""
return {
"error": "checksum_mismatch",
"expected": self.expected,
"actual": self.actual,
"artifact_id": self.artifact_id,
"s3_key": self.s3_key,
"size": self.size,
"message": self.message,
}
class InvalidHashFormatError(ChecksumError):
"""Raised when a hash string is not valid SHA256 format."""
def __init__(self, hash_value: str):
self.hash_value = hash_value
message = f"Invalid SHA256 hash format: '{hash_value[:32]}...'"
super().__init__(message)
def is_valid_sha256(hash_value: str) -> bool:
"""
Check if a string is a valid SHA256 hash (64 hex characters).
Args:
hash_value: String to validate
Returns:
True if valid SHA256 format, False otherwise
"""
if not hash_value:
return False
return bool(SHA256_PATTERN.match(hash_value))
def compute_sha256(content: bytes) -> str:
"""
Compute SHA256 hash of bytes content.
Args:
content: Bytes content to hash
Returns:
Lowercase hexadecimal SHA256 hash (64 characters)
Raises:
ChecksumError: If hash computation fails
"""
if content is None:
raise ChecksumError("Cannot compute hash of None content")
try:
return hashlib.sha256(content).hexdigest().lower()
except Exception as e:
raise ChecksumError(f"Hash computation failed: {e}") from e
def compute_sha256_stream(
stream: Any,
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> str:
"""
Compute SHA256 hash from a stream or file-like object.
Reads the stream in chunks to minimize memory usage for large files.
Args:
stream: Iterator yielding bytes or file-like object with read()
chunk_size: Size of chunks to read (default 8KB)
Returns:
Lowercase hexadecimal SHA256 hash (64 characters)
Raises:
ChecksumError: If hash computation fails
"""
try:
hasher = hashlib.sha256()
# Handle file-like objects with read()
if hasattr(stream, "read"):
while True:
chunk = stream.read(chunk_size)
if not chunk:
break
hasher.update(chunk)
else:
# Handle iterators
for chunk in stream:
if chunk:
hasher.update(chunk)
return hasher.hexdigest().lower()
except Exception as e:
raise ChecksumError(f"Stream hash computation failed: {e}") from e
def verify_checksum(content: bytes, expected: str) -> bool:
"""
Verify that content matches expected SHA256 hash.
Args:
content: Bytes content to verify
expected: Expected SHA256 hash (case-insensitive)
Returns:
True if hash matches, False otherwise
Raises:
InvalidHashFormatError: If expected hash is not valid format
ChecksumError: If hash computation fails
"""
if not is_valid_sha256(expected):
raise InvalidHashFormatError(expected)
actual = compute_sha256(content)
return actual == expected.lower()
def verify_checksum_strict(
content: bytes,
expected: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
) -> None:
"""
Verify content matches expected hash, raising exception on mismatch.
Args:
content: Bytes content to verify
expected: Expected SHA256 hash (case-insensitive)
artifact_id: Optional artifact ID for error context
s3_key: Optional S3 key for error context
Raises:
InvalidHashFormatError: If expected hash is not valid format
ChecksumMismatchError: If verification fails
ChecksumError: If hash computation fails
"""
if not is_valid_sha256(expected):
raise InvalidHashFormatError(expected)
actual = compute_sha256(content)
if actual != expected.lower():
raise ChecksumMismatchError(
expected=expected.lower(),
actual=actual,
artifact_id=artifact_id,
s3_key=s3_key,
size=len(content),
)
def sha256_to_base64(hex_hash: str) -> str:
"""
Convert SHA256 hex string to base64 encoding (for RFC 3230 Digest header).
Args:
hex_hash: SHA256 hash as 64-character hex string
Returns:
Base64-encoded hash string
"""
if not is_valid_sha256(hex_hash):
raise InvalidHashFormatError(hex_hash)
hash_bytes = bytes.fromhex(hex_hash)
return base64.b64encode(hash_bytes).decode("ascii")
class HashingStreamWrapper:
"""
Wrapper that computes SHA256 hash incrementally as chunks are read.
This allows computing the hash while streaming content to a client,
without buffering the entire content in memory.
Usage:
wrapper = HashingStreamWrapper(stream)
for chunk in wrapper:
send_to_client(chunk)
final_hash = wrapper.get_hash()
Attributes:
chunk_size: Size of chunks to yield
bytes_read: Total bytes processed so far
"""
def __init__(
self,
stream: Any,
chunk_size: int = DEFAULT_CHUNK_SIZE,
):
"""
Initialize the hashing stream wrapper.
Args:
stream: Source stream (iterator, file-like, or S3 StreamingBody)
chunk_size: Size of chunks to yield (default 8KB)
"""
self._stream = stream
self._hasher = hashlib.sha256()
self._chunk_size = chunk_size
self._bytes_read = 0
self._finalized = False
self._final_hash: Optional[str] = None
@property
def bytes_read(self) -> int:
"""Total bytes read so far."""
return self._bytes_read
@property
def chunk_size(self) -> int:
"""Chunk size for reading."""
return self._chunk_size
def __iter__(self) -> Generator[bytes, None, None]:
"""Iterate over chunks, computing hash as we go."""
# Handle S3 StreamingBody (has iter_chunks)
if hasattr(self._stream, "iter_chunks"):
for chunk in self._stream.iter_chunks(chunk_size=self._chunk_size):
if chunk:
self._hasher.update(chunk)
self._bytes_read += len(chunk)
yield chunk
# Handle file-like objects with read()
elif hasattr(self._stream, "read"):
while True:
chunk = self._stream.read(self._chunk_size)
if not chunk:
break
self._hasher.update(chunk)
self._bytes_read += len(chunk)
yield chunk
# Handle iterators
else:
for chunk in self._stream:
if chunk:
self._hasher.update(chunk)
self._bytes_read += len(chunk)
yield chunk
self._finalized = True
self._final_hash = self._hasher.hexdigest().lower()
def get_hash(self) -> str:
"""
Get the computed SHA256 hash.
If the stream hasn't been fully consumed yet, this consumes the remaining chunks first.
Returns:
Lowercase hexadecimal SHA256 hash
"""
if not self._finalized:
# Consume remaining stream
for _ in self:
pass
return self._final_hash or self._hasher.hexdigest().lower()
def get_hash_if_complete(self) -> Optional[str]:
"""
Get hash only if stream has been fully consumed.
Returns:
Hash if complete, None otherwise
"""
if self._finalized:
return self._final_hash
return None
class VerifyingStreamWrapper:
"""
Wrapper that yields chunks and verifies hash after streaming completes.
IMPORTANT: Because HTTP streams cannot be "un-sent", if verification
fails after streaming, the client has already received potentially
corrupt data. This wrapper logs an error but cannot prevent delivery.
For guaranteed verification before delivery, use pre-verification mode
which buffers the entire content first.
Usage:
wrapper = VerifyingStreamWrapper(stream, expected_hash)
for chunk in wrapper:
send_to_client(chunk)
wrapper.verify() # Raises ChecksumMismatchError if failed
"""
def __init__(
self,
stream: Any,
expected_hash: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
on_failure: Optional[Callable[[Any], None]] = None,
):
"""
Initialize the verifying stream wrapper.
Args:
stream: Source stream
expected_hash: Expected SHA256 hash to verify against
artifact_id: Optional artifact ID for error context
s3_key: Optional S3 key for error context
chunk_size: Size of chunks to yield
on_failure: Optional callback called on verification failure
"""
if not is_valid_sha256(expected_hash):
raise InvalidHashFormatError(expected_hash)
self._hashing_wrapper = HashingStreamWrapper(stream, chunk_size)
self._expected_hash = expected_hash.lower()
self._artifact_id = artifact_id
self._s3_key = s3_key
self._on_failure = on_failure
self._verified: Optional[bool] = None
@property
def bytes_read(self) -> int:
"""Total bytes read so far."""
return self._hashing_wrapper.bytes_read
@property
def is_verified(self) -> Optional[bool]:
"""
Verification status.
Returns:
True if verified successfully, False if failed, None if not yet complete
"""
return self._verified
def __iter__(self) -> Generator[bytes, None, None]:
"""Iterate over chunks."""
yield from self._hashing_wrapper
def verify(self) -> bool:
"""
Verify the hash after stream is complete.
Must be called after fully consuming the iterator.
Returns:
True if verification passed
Raises:
ChecksumMismatchError: If verification failed
"""
actual_hash = self._hashing_wrapper.get_hash()
if actual_hash == self._expected_hash:
self._verified = True
logger.debug(
f"Verification passed for {self._artifact_id or 'unknown'}: {actual_hash[:16]}..."
)
return True
self._verified = False
error = ChecksumMismatchError(
expected=self._expected_hash,
actual=actual_hash,
artifact_id=self._artifact_id,
s3_key=self._s3_key,
size=self._hashing_wrapper.bytes_read,
)
# Log the failure
logger.error(f"Checksum verification FAILED after streaming: {error.to_dict()}")
# Call failure callback if provided
if self._on_failure:
try:
self._on_failure(error)
except Exception as e:
logger.warning(f"Verification failure callback raised exception: {e}")
raise error
def verify_silent(self) -> bool:
"""
Verify the hash without raising exception.
Returns:
True if verification passed, False otherwise
"""
try:
return self.verify()
except ChecksumMismatchError:
return False
def get_actual_hash(self) -> Optional[str]:
"""Get the actual computed hash (only available after iteration)."""
return self._hashing_wrapper.get_hash_if_complete()
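For orientation, a minimal sketch of how the two wrappers above compose, using an in-memory `io.BytesIO` in place of an S3 `StreamingBody`:

```python
# Minimal sketch exercising the wrappers defined in this module; io.BytesIO
# stands in for an S3 StreamingBody.
import hashlib
import io

from app.checksum import HashingStreamWrapper, VerifyingStreamWrapper

content = b"hello artifact"
expected = hashlib.sha256(content).hexdigest()

# HashingStreamWrapper: the hash is computed as chunks are yielded.
hashing = HashingStreamWrapper(io.BytesIO(content), chunk_size=4)
assert b"".join(hashing) == content   # what would be streamed to the client
assert hashing.get_hash() == expected

# VerifyingStreamWrapper: same streaming, plus an explicit post-hoc check.
verifying = VerifyingStreamWrapper(io.BytesIO(content), expected)
b"".join(verifying)        # consume the stream
assert verifying.verify()  # raises ChecksumMismatchError on mismatch
```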


@@ -48,6 +48,10 @@ class Settings(BaseSettings):
3600 # Presigned URL expiry in seconds (default: 1 hour)
)
# Logging settings
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
log_format: str = "auto" # "json", "standard", or "auto" (json in production)
@property
def database_url(self) -> str:
sslmode = f"?sslmode={self.database_sslmode}" if self.database_sslmode else ""


@@ -0,0 +1,254 @@
"""
Structured logging configuration for Orchard.
This module provides:
- Structured JSON logging for production environments
- Request tracing via X-Request-ID header
- Verification failure logging with context
- Configurable log levels via environment
Usage:
from app.logging_config import setup_logging, get_request_id
setup_logging() # Call once at app startup
request_id = get_request_id() # Get current request's ID
"""
import logging
import json
import sys
import uuid
from datetime import datetime, timezone
from typing import Optional, Any, Dict
from contextvars import ContextVar
from .config import get_settings
# Context variable for request ID (thread-safe)
_request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
def get_request_id() -> Optional[str]:
"""Get the current request's ID from context."""
return _request_id_var.get()
def set_request_id(request_id: Optional[str] = None) -> str:
"""
Set the request ID for the current context.
If no ID is provided, a new UUID is generated.
Returns the request ID that was set.
"""
if request_id is None:
request_id = str(uuid.uuid4())
_request_id_var.set(request_id)
return request_id
def clear_request_id():
"""Clear the request ID from context."""
_request_id_var.set(None)
class JSONFormatter(logging.Formatter):
"""
JSON log formatter for structured logging.
Output format:
{
"timestamp": "2025-01-01T00:00:00.000Z",
"level": "INFO",
"logger": "app.routes",
"message": "Request completed",
"request_id": "abc-123",
"extra": {...}
}
"""
def format(self, record: logging.LogRecord) -> str:
log_entry: Dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
# Add request ID if available
request_id = get_request_id()
if request_id:
log_entry["request_id"] = request_id
# Add exception info if present
if record.exc_info:
log_entry["exception"] = self.formatException(record.exc_info)
# Add extra fields from record
extra_fields: Dict[str, Any] = {}
for key, value in record.__dict__.items():
if key not in (
"name",
"msg",
"args",
"created",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"exc_info",
"exc_text",
"thread",
"threadName",
"message",
"asctime",
):
try:
json.dumps(value) # Ensure serializable
extra_fields[key] = value
except (TypeError, ValueError):
extra_fields[key] = str(value)
if extra_fields:
log_entry["extra"] = extra_fields
return json.dumps(log_entry)
class StandardFormatter(logging.Formatter):
"""
Standard log formatter for development.
Output format:
[2025-01-01 00:00:00] INFO [app.routes] [req-abc123] Request completed
"""
def format(self, record: logging.LogRecord) -> str:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
request_id = get_request_id()
req_str = f" [req-{request_id[:8]}]" if request_id else ""
base_msg = f"[{timestamp}] {record.levelname:5} [{record.name}]{req_str} {record.getMessage()}"
if record.exc_info:
base_msg += "\n" + self.formatException(record.exc_info)
return base_msg
def setup_logging(log_level: Optional[str] = None, json_format: Optional[bool] = None):
"""
Configure logging for the application.
Args:
log_level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL).
Defaults to ORCHARD_LOG_LEVEL env var or INFO.
json_format: Use JSON format. Defaults to True in production.
"""
settings = get_settings()
# Determine log level
if log_level is None:
log_level = getattr(settings, "log_level", "INFO")
effective_level = log_level if log_level else "INFO"
level = getattr(logging, effective_level.upper(), logging.INFO)
# Determine format
if json_format is None:
json_format = settings.is_production
# Create handler
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(level)
# Set formatter
if json_format:
handler.setFormatter(JSONFormatter())
else:
handler.setFormatter(StandardFormatter())
# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(level)
# Remove existing handlers
root_logger.handlers.clear()
root_logger.addHandler(handler)
# Configure specific loggers
for logger_name in ["app", "uvicorn", "uvicorn.access", "uvicorn.error"]:
logger = logging.getLogger(logger_name)
logger.setLevel(level)
logger.handlers.clear()
logger.addHandler(handler)
logger.propagate = False
# Quiet down noisy loggers
logging.getLogger("botocore").setLevel(logging.WARNING)
logging.getLogger("boto3").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
def log_verification_failure(
logger: logging.Logger,
expected_hash: str,
actual_hash: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
project: Optional[str] = None,
package: Optional[str] = None,
size: Optional[int] = None,
user_id: Optional[str] = None,
source_ip: Optional[str] = None,
verification_mode: Optional[str] = None,
):
"""
Log a verification failure with full context.
This creates a structured log entry with all relevant details
for debugging and alerting.
"""
logger.error(
"Checksum verification failed",
extra={
"event": "verification_failure",
"expected_hash": expected_hash,
"actual_hash": actual_hash,
"artifact_id": artifact_id,
"s3_key": s3_key,
"project": project,
"package": package,
"size": size,
"user_id": user_id,
"source_ip": source_ip,
"verification_mode": verification_mode,
"hash_match": expected_hash == actual_hash,
},
)
def log_verification_success(
logger: logging.Logger,
artifact_id: str,
size: Optional[int] = None,
verification_mode: Optional[str] = None,
duration_ms: Optional[float] = None,
):
"""Log a successful verification."""
logger.info(
f"Verification passed for artifact {artifact_id[:16]}...",
extra={
"event": "verification_success",
"artifact_id": artifact_id,
"size": size,
"verification_mode": verification_mode,
"duration_ms": duration_ms,
},
)
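A minimal startup sketch for the module above (the logger name and extra fields are arbitrary):

```python
# Minimal sketch: wire up structured logging at app startup. setup_logging
# and set_request_id are defined in this module; the logger name is arbitrary.
import logging

from app.logging_config import setup_logging, set_request_id

setup_logging(log_level="INFO", json_format=True)
request_id = set_request_id()  # generates a UUID for this context

logger = logging.getLogger("app.example")
# Non-reserved extra fields end up under the "extra" key of the JSON entry.
logger.info("Request completed", extra={"event": "demo", "status": 200})
```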


@@ -97,6 +97,11 @@ from .schemas import (
)
from .metadata import extract_metadata
from .config import get_settings
from .checksum import (
ChecksumMismatchError,
VerifyingStreamWrapper,
sha256_to_base64,
)
router = APIRouter()
@@ -1777,7 +1782,7 @@ def _resolve_artifact_ref(
return artifact
# Download artifact with range request support, download modes, and verification
@router.get("/api/v1/project/{project_name}/{package_name}/+/{ref}")
def download_artifact(
project_name: str,
@@ -1791,7 +1796,34 @@ def download_artifact(
default=None,
description="Download mode: proxy (stream through backend), redirect (302 to presigned URL), presigned (return JSON with URL)",
),
verify: bool = Query(
default=False,
description="Enable checksum verification during download",
),
verify_mode: Optional[Literal["stream", "pre"]] = Query(
default="stream",
description="Verification mode: 'stream' (verify after streaming, logs error if mismatch), 'pre' (verify before streaming, returns 500 if mismatch)",
),
):
"""
Download an artifact by reference (tag name, artifact:hash, tag:name).
Verification modes:
- verify=false (default): No verification, maximum performance
- verify=true&verify_mode=stream: Compute the hash while streaming; verify after completion.
On mismatch, an error is logged, but the content has already been sent.
- verify=true&verify_mode=pre: Download and verify BEFORE streaming to the client.
Higher latency, but guarantees no corrupt data is sent.
Response headers always include:
- X-Checksum-SHA256: The expected SHA256 hash
- X-Content-Length: File size in bytes
- ETag: Artifact ID (SHA256)
- Digest: RFC 3230 format sha-256 hash
When verify=true:
- X-Verified: 'true' if verified, 'false' if verification failed
"""
settings = get_settings()
# Get project and package
@@ -1831,6 +1863,23 @@ def download_artifact(
)
db.commit()
# Build common checksum headers (always included)
checksum_headers = {
"X-Checksum-SHA256": artifact.id,
"X-Content-Length": str(artifact.size),
"ETag": f'"{artifact.id}"',
}
# Add RFC 3230 Digest header
try:
digest_base64 = sha256_to_base64(artifact.id)
checksum_headers["Digest"] = f"sha-256={digest_base64}"
except Exception:
pass # Skip if conversion fails
# Add MD5 checksum if available
if artifact.checksum_md5:
checksum_headers["X-Checksum-MD5"] = artifact.checksum_md5
# Determine download mode (query param overrides server default)
download_mode = mode or settings.download_mode
@@ -1867,7 +1916,7 @@ def download_artifact(
return RedirectResponse(url=presigned_url, status_code=302)
# Proxy mode (default fallback) - stream through backend
# Handle range requests (verification not supported for partial downloads)
if range:
stream, content_length, content_range = storage.get_stream(
artifact.s3_key, range
@@ -1877,9 +1926,11 @@ def download_artifact(
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
"Content-Length": str(content_length),
**checksum_headers,
}
if content_range:
headers["Content-Range"] = content_range
# Note: X-Verified not set for range requests (cannot verify partial content)
return StreamingResponse(
stream,
@@ -1888,16 +1939,88 @@ def download_artifact(
headers=headers,
)
# Full download with optional verification
base_headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
**checksum_headers,
}
# Pre-verification mode: verify before streaming
if verify and verify_mode == "pre":
try:
content = storage.get_verified(artifact.s3_key, artifact.id)
return Response(
content=content,
media_type=artifact.content_type or "application/octet-stream",
headers={
**base_headers,
"Content-Length": str(len(content)),
"X-Verified": "true",
},
)
except ChecksumMismatchError as e:
logger.error(
f"Pre-verification failed for artifact {artifact.id[:16]}...: {e.to_dict()}"
)
raise HTTPException(
status_code=500,
detail={
"error": "checksum_verification_failed",
"message": "Downloaded content does not match expected checksum",
"expected": e.expected,
"actual": e.actual,
"artifact_id": artifact.id,
},
)
# Streaming verification mode: verify while/after streaming
if verify and verify_mode == "stream":
verifying_wrapper, content_length, _ = storage.get_stream_verified(
artifact.s3_key, artifact.id
)
def verified_stream():
"""Generator that yields chunks and verifies after completion."""
try:
for chunk in verifying_wrapper:
yield chunk
# After all chunks yielded, verify
try:
verifying_wrapper.verify()
logger.info(
f"Streaming verification passed for artifact {artifact.id[:16]}..."
)
except ChecksumMismatchError as e:
# Content already sent - log error but cannot reject
logger.error(
f"Streaming verification FAILED for artifact {artifact.id[:16]}...: "
f"expected {e.expected[:16]}..., got {e.actual[:16]}..."
)
except Exception as e:
logger.error(f"Error during streaming download: {e}")
raise
return StreamingResponse(
verified_stream(),
media_type=artifact.content_type or "application/octet-stream",
headers={
**base_headers,
"Content-Length": str(content_length),
"X-Verified": "pending", # Verification happens after streaming
},
)
# No verification - direct streaming
stream, content_length, _ = storage.get_stream(artifact.s3_key)
return StreamingResponse(
stream,
media_type=artifact.content_type or "application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
**base_headers,
"Content-Length": str(content_length),
"X-Verified": "false",
},
)
@@ -1975,6 +2098,11 @@ def head_artifact(
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
):
"""
Get artifact metadata without downloading content.
Returns headers with checksum information for client-side verification.
"""
# Get project and package
project = db.query(Project).filter(Project.name == project_name).first()
if not project:
@@ -1995,15 +2123,32 @@ def head_artifact(
filename = sanitize_filename(artifact.original_name or f"{artifact.id}")
# Build headers with checksum information
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
"Content-Length": str(artifact.size),
"X-Artifact-Id": artifact.id,
"X-Checksum-SHA256": artifact.id,
"X-Content-Length": str(artifact.size),
"ETag": f'"{artifact.id}"',
}
# Add RFC 3230 Digest header
try:
digest_base64 = sha256_to_base64(artifact.id)
headers["Digest"] = f"sha-256={digest_base64}"
except Exception:
pass # Skip if conversion fails
# Add MD5 checksum if available
if artifact.checksum_md5:
headers["X-Checksum-MD5"] = artifact.checksum_md5
return Response(
content=b"",
media_type=artifact.content_type or "application/octet-stream",
headers=headers,
)
@@ -2017,9 +2162,19 @@ def download_artifact_compat(
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
range: Optional[str] = Header(None),
verify: bool = Query(default=False),
verify_mode: Optional[Literal["stream", "pre"]] = Query(default="stream"),
):
return download_artifact(
project_name,
package_name,
ref,
request,
db,
storage,
range,
verify=verify,
verify_mode=verify_mode,
)
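Since stream mode returns `X-Verified: pending` and the server-side check only happens after the body has been sent, a careful client should re-hash the stream itself; a hedged sketch (URL, project, and package names are placeholders):

```python
# Hedged client sketch for verify_mode=stream: X-Verified arrives as
# "pending", so the client re-hashes the streamed body locally.
# The URL, project, and package names are placeholders.
import hashlib
import requests

url = "http://localhost:8000/api/v1/project/myproj/mypkg/+/v1.0"
hasher = hashlib.sha256()
with requests.get(
    url, params={"verify": "true", "verify_mode": "stream"}, stream=True
) as resp:
    resp.raise_for_status()
    expected = resp.headers["X-Checksum-SHA256"]
    for chunk in resp.iter_content(chunk_size=8192):
        hasher.update(chunk)

assert hasher.hexdigest() == expected
```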


@@ -22,6 +22,13 @@ from botocore.exceptions import (
)
from .config import get_settings
from .checksum import (
ChecksumMismatchError,
HashingStreamWrapper,
VerifyingStreamWrapper,
compute_sha256,
is_valid_sha256,
)
settings = get_settings()
logger = logging.getLogger(__name__)
@@ -876,6 +883,95 @@ class S3Storage:
logger.error(f"Unexpected error during storage health check: {e}")
return False
def get_verified(self, s3_key: str, expected_hash: str) -> bytes:
"""
Download and verify content matches expected SHA256 hash.
This method downloads the entire content, computes its hash, and
verifies it matches the expected hash before returning.
Args:
s3_key: The S3 storage key of the file
expected_hash: Expected SHA256 hash (64 hex characters)
Returns:
File content as bytes (only if verification passes)
Raises:
ChecksumMismatchError: If computed hash doesn't match expected
ClientError: If S3 operation fails
"""
if not is_valid_sha256(expected_hash):
raise ValueError(f"Invalid SHA256 hash format: {expected_hash}")
content = self.get(s3_key)
actual_hash = compute_sha256(content)
if actual_hash != expected_hash.lower():
raise ChecksumMismatchError(
expected=expected_hash.lower(),
actual=actual_hash,
s3_key=s3_key,
size=len(content),
)
logger.debug(f"Verification passed for {s3_key}: {actual_hash[:16]}...")
return content
def get_stream_verified(
self,
s3_key: str,
expected_hash: str,
range_header: Optional[str] = None,
) -> Tuple[VerifyingStreamWrapper, int, Optional[str]]:
"""
Get a verifying stream wrapper for an object.
Returns a wrapper that computes the hash as chunks are read and
can verify after streaming completes. Note that verification happens
AFTER content has been streamed to the client.
IMPORTANT: For range requests, verification is not supported because
we cannot verify a partial download against the full file hash.
Args:
s3_key: The S3 storage key of the file
expected_hash: Expected SHA256 hash (64 hex characters)
range_header: Optional HTTP Range header (verification disabled if set)
Returns:
Tuple of (VerifyingStreamWrapper, content_length, content_range)
The wrapper has a verify() method to call after streaming.
Raises:
ValueError: If expected_hash is invalid format
ClientError: If S3 operation fails
"""
if not is_valid_sha256(expected_hash):
raise ValueError(f"Invalid SHA256 hash format: {expected_hash}")
# Get the S3 stream
stream, content_length, content_range = self.get_stream(s3_key, range_header)
# For range requests, we cannot verify (partial content)
# Return a HashingStreamWrapper that just tracks bytes without verification
if range_header or content_range:
logger.debug(
f"Range request for {s3_key} - verification disabled (partial content)"
)
# Return a basic hashing wrapper (caller should not verify)
hashing_wrapper = HashingStreamWrapper(stream)
return hashing_wrapper, content_length, content_range
# Create verifying wrapper
verifying_wrapper = VerifyingStreamWrapper(
stream=stream,
expected_hash=expected_hash,
s3_key=s3_key,
)
return verifying_wrapper, content_length, content_range
def verify_integrity(self, s3_key: str, expected_sha256: str) -> bool:
"""
Verify the integrity of a stored object by downloading and re-hashing.
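For orientation, a hedged sketch of how the two new storage methods map to the endpoint's verification modes (how `S3Storage` is constructed is outside this diff; `storage` and `s3_key` are placeholders):

```python
# Hedged sketch: pairs get_verified / get_stream_verified with the endpoint's
# "pre" and "stream" modes. storage and s3_key are placeholders; S3Storage
# construction is not shown in this diff.
from app.checksum import ChecksumMismatchError  # raised on mismatch

def fetch(storage, s3_key: str, expected_sha256: str, pre: bool = True) -> bytes:
    if pre:
        # Pre-verification: buffer, hash, and check before returning anything.
        return storage.get_verified(s3_key, expected_sha256)
    # Streaming verification: hash as chunks flow, then check afterwards.
    wrapper, _content_length, _ = storage.get_stream_verified(s3_key, expected_sha256)
    body = b"".join(wrapper)
    wrapper.verify()  # raises ChecksumMismatchError on mismatch
    return body
```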


@@ -0,0 +1,675 @@
"""
Tests for checksum calculation, verification, and download verification.
This module tests:
- SHA256 hash computation (bytes and streams)
- HashingStreamWrapper incremental hashing
- VerifyingStreamWrapper with verification
- ChecksumMismatchError exception handling
- Download verification API endpoints
"""
import pytest
import hashlib
import io
from typing import Generator
from app.checksum import (
compute_sha256,
compute_sha256_stream,
verify_checksum,
verify_checksum_strict,
is_valid_sha256,
sha256_to_base64,
HashingStreamWrapper,
VerifyingStreamWrapper,
ChecksumMismatchError,
ChecksumError,
InvalidHashFormatError,
DEFAULT_CHUNK_SIZE,
)
# =============================================================================
# Test Data
# =============================================================================
# Known test vectors
TEST_CONTENT_HELLO = b"Hello, World!"
TEST_HASH_HELLO = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
TEST_CONTENT_EMPTY = b""
TEST_HASH_EMPTY = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
TEST_CONTENT_BINARY = bytes(range(256))
TEST_HASH_BINARY = hashlib.sha256(TEST_CONTENT_BINARY).hexdigest()
# Invalid hashes for testing
INVALID_HASH_TOO_SHORT = "abcd1234"
INVALID_HASH_TOO_LONG = "a" * 65
INVALID_HASH_NON_HEX = "zzzz" + "a" * 60
INVALID_HASH_EMPTY = ""
# =============================================================================
# Unit Tests - SHA256 Computation
# =============================================================================
class TestComputeSHA256:
"""Tests for compute_sha256 function."""
def test_known_content_matches_expected_hash(self):
"""Test SHA256 of known content matches pre-computed hash."""
result = compute_sha256(TEST_CONTENT_HELLO)
assert result == TEST_HASH_HELLO
def test_returns_64_character_hex_string(self):
"""Test result is exactly 64 hex characters."""
result = compute_sha256(TEST_CONTENT_HELLO)
assert len(result) == 64
assert all(c in "0123456789abcdef" for c in result)
def test_returns_lowercase_hex(self):
"""Test result is lowercase."""
result = compute_sha256(TEST_CONTENT_HELLO)
assert result == result.lower()
def test_empty_content_returns_empty_hash(self):
"""Test empty bytes returns SHA256 of empty content."""
result = compute_sha256(TEST_CONTENT_EMPTY)
assert result == TEST_HASH_EMPTY
def test_deterministic_same_input_same_output(self):
"""Test same input always produces same output."""
content = b"test content for determinism"
result1 = compute_sha256(content)
result2 = compute_sha256(content)
assert result1 == result2
def test_different_content_different_hash(self):
"""Test different content produces different hash."""
hash1 = compute_sha256(b"content A")
hash2 = compute_sha256(b"content B")
assert hash1 != hash2
def test_single_bit_change_different_hash(self):
"""Test single bit change produces completely different hash."""
content1 = b"\x00" * 100
content2 = b"\x00" * 99 + b"\x01"
hash1 = compute_sha256(content1)
hash2 = compute_sha256(content2)
assert hash1 != hash2
def test_binary_content(self):
"""Test hashing binary content with all byte values."""
result = compute_sha256(TEST_CONTENT_BINARY)
assert result == TEST_HASH_BINARY
assert len(result) == 64
def test_large_content(self):
"""Test hashing larger content (1MB)."""
large_content = b"x" * (1024 * 1024)
result = compute_sha256(large_content)
expected = hashlib.sha256(large_content).hexdigest()
assert result == expected
def test_none_content_raises_error(self):
"""Test None content raises ChecksumError."""
with pytest.raises(ChecksumError, match="Cannot compute hash of None"):
compute_sha256(None)
class TestComputeSHA256Stream:
"""Tests for compute_sha256_stream function."""
def test_file_like_object(self):
"""Test hashing from file-like object."""
file_obj = io.BytesIO(TEST_CONTENT_HELLO)
result = compute_sha256_stream(file_obj)
assert result == TEST_HASH_HELLO
def test_iterator(self):
"""Test hashing from iterator of chunks."""
def chunk_iterator():
yield b"Hello, "
yield b"World!"
result = compute_sha256_stream(chunk_iterator())
assert result == TEST_HASH_HELLO
def test_various_chunk_sizes_same_result(self):
"""Test different chunk sizes produce same hash."""
content = b"x" * 10000
expected = hashlib.sha256(content).hexdigest()
for chunk_size in [1, 10, 100, 1000, 8192]:
file_obj = io.BytesIO(content)
result = compute_sha256_stream(file_obj, chunk_size=chunk_size)
assert result == expected, f"Failed for chunk_size={chunk_size}"
def test_single_byte_chunks(self):
"""Test with 1-byte chunks (edge case)."""
content = b"ABC"
file_obj = io.BytesIO(content)
result = compute_sha256_stream(file_obj, chunk_size=1)
expected = hashlib.sha256(content).hexdigest()
assert result == expected
def test_empty_stream(self):
"""Test empty stream returns empty content hash."""
file_obj = io.BytesIO(b"")
result = compute_sha256_stream(file_obj)
assert result == TEST_HASH_EMPTY
# =============================================================================
# Unit Tests - Hash Validation
# =============================================================================
class TestIsValidSHA256:
"""Tests for is_valid_sha256 function."""
def test_valid_hash_lowercase(self):
"""Test valid lowercase hash."""
assert is_valid_sha256(TEST_HASH_HELLO) is True
def test_valid_hash_uppercase(self):
"""Test valid uppercase hash."""
assert is_valid_sha256(TEST_HASH_HELLO.upper()) is True
def test_valid_hash_mixed_case(self):
"""Test valid mixed case hash."""
mixed = TEST_HASH_HELLO[:32].upper() + TEST_HASH_HELLO[32:].lower()
assert is_valid_sha256(mixed) is True
def test_invalid_too_short(self):
"""Test hash that's too short."""
assert is_valid_sha256(INVALID_HASH_TOO_SHORT) is False
def test_invalid_too_long(self):
"""Test hash that's too long."""
assert is_valid_sha256(INVALID_HASH_TOO_LONG) is False
def test_invalid_non_hex(self):
"""Test hash with non-hex characters."""
assert is_valid_sha256(INVALID_HASH_NON_HEX) is False
def test_invalid_empty(self):
"""Test empty string."""
assert is_valid_sha256(INVALID_HASH_EMPTY) is False
def test_invalid_none(self):
"""Test None value."""
assert is_valid_sha256(None) is False
class TestSHA256ToBase64:
"""Tests for sha256_to_base64 function."""
def test_converts_to_base64(self):
"""Test conversion to base64."""
result = sha256_to_base64(TEST_HASH_HELLO)
# Verify it's valid base64
import base64
decoded = base64.b64decode(result)
assert len(decoded) == 32 # SHA256 is 32 bytes
def test_invalid_hash_raises_error(self):
"""Test invalid hash raises InvalidHashFormatError."""
with pytest.raises(InvalidHashFormatError):
sha256_to_base64(INVALID_HASH_TOO_SHORT)
# =============================================================================
# Unit Tests - Verification Functions
# =============================================================================
class TestVerifyChecksum:
"""Tests for verify_checksum function."""
def test_matching_checksum_returns_true(self):
"""Test matching checksum returns True."""
result = verify_checksum(TEST_CONTENT_HELLO, TEST_HASH_HELLO)
assert result is True
def test_mismatched_checksum_returns_false(self):
"""Test mismatched checksum returns False."""
wrong_hash = "a" * 64
result = verify_checksum(TEST_CONTENT_HELLO, wrong_hash)
assert result is False
def test_case_insensitive_comparison(self):
"""Test comparison is case-insensitive."""
result = verify_checksum(TEST_CONTENT_HELLO, TEST_HASH_HELLO.upper())
assert result is True
def test_invalid_hash_format_raises_error(self):
"""Test invalid hash format raises error."""
with pytest.raises(InvalidHashFormatError):
verify_checksum(TEST_CONTENT_HELLO, INVALID_HASH_TOO_SHORT)
class TestVerifyChecksumStrict:
"""Tests for verify_checksum_strict function."""
def test_matching_checksum_returns_none(self):
"""Test matching checksum doesn't raise."""
# Should not raise
verify_checksum_strict(TEST_CONTENT_HELLO, TEST_HASH_HELLO)
def test_mismatched_checksum_raises_error(self):
"""Test mismatched checksum raises ChecksumMismatchError."""
wrong_hash = "a" * 64
with pytest.raises(ChecksumMismatchError) as exc_info:
verify_checksum_strict(TEST_CONTENT_HELLO, wrong_hash)
error = exc_info.value
assert error.expected == wrong_hash.lower()
assert error.actual == TEST_HASH_HELLO
assert error.size == len(TEST_CONTENT_HELLO)
def test_error_includes_context(self):
"""Test error includes artifact_id and s3_key context."""
wrong_hash = "a" * 64
with pytest.raises(ChecksumMismatchError) as exc_info:
verify_checksum_strict(
TEST_CONTENT_HELLO,
wrong_hash,
artifact_id="test-artifact-123",
s3_key="fruits/ab/cd/abcd1234...",
)
error = exc_info.value
assert error.artifact_id == "test-artifact-123"
assert error.s3_key == "fruits/ab/cd/abcd1234..."
# =============================================================================
# Unit Tests - HashingStreamWrapper
# =============================================================================
class TestHashingStreamWrapper:
"""Tests for HashingStreamWrapper class."""
def test_computes_correct_hash(self):
"""Test wrapper computes correct hash."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
# Consume the stream
chunks = list(wrapper)
# Verify hash
assert wrapper.get_hash() == TEST_HASH_HELLO
def test_yields_correct_chunks(self):
"""Test wrapper yields all content."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
chunks = list(wrapper)
content = b"".join(chunks)
assert content == TEST_CONTENT_HELLO
def test_tracks_bytes_read(self):
"""Test bytes_read property tracks correctly."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
assert wrapper.bytes_read == 0
list(wrapper) # Consume
assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)
def test_get_hash_before_iteration_consumes_stream(self):
"""Test get_hash() consumes stream if not already done."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
# Call get_hash without iterating
hash_result = wrapper.get_hash()
assert hash_result == TEST_HASH_HELLO
assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)
def test_get_hash_if_complete_before_iteration_returns_none(self):
"""Test get_hash_if_complete returns None before iteration."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
assert wrapper.get_hash_if_complete() is None
def test_get_hash_if_complete_after_iteration_returns_hash(self):
"""Test get_hash_if_complete returns hash after iteration."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
list(wrapper) # Consume
assert wrapper.get_hash_if_complete() == TEST_HASH_HELLO
def test_custom_chunk_size(self):
"""Test custom chunk size is respected."""
content = b"x" * 1000
stream = io.BytesIO(content)
wrapper = HashingStreamWrapper(stream, chunk_size=100)
chunks = list(wrapper)
# Each chunk should be at most 100 bytes
for chunk in chunks[:-1]: # All but last
assert len(chunk) == 100
# Total content should match
assert b"".join(chunks) == content
def test_iterator_interface(self):
"""Test wrapper supports iterator interface."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
# Should be able to use for loop
result = b""
for chunk in wrapper:
result += chunk
assert result == TEST_CONTENT_HELLO
# =============================================================================
# Unit Tests - VerifyingStreamWrapper
# =============================================================================
class TestVerifyingStreamWrapper:
"""Tests for VerifyingStreamWrapper class."""
def test_verify_success(self):
"""Test verification succeeds for matching content."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)
# Consume stream
list(wrapper)
# Verify should succeed
result = wrapper.verify()
assert result is True
assert wrapper.is_verified is True
def test_verify_failure_raises_error(self):
"""Test verification failure raises ChecksumMismatchError."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrong_hash = "a" * 64
wrapper = VerifyingStreamWrapper(stream, wrong_hash)
# Consume stream
list(wrapper)
# Verify should fail
with pytest.raises(ChecksumMismatchError):
wrapper.verify()
assert wrapper.is_verified is False
def test_verify_silent_success(self):
"""Test verify_silent returns True on success."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)
list(wrapper)
result = wrapper.verify_silent()
assert result is True
def test_verify_silent_failure(self):
"""Test verify_silent returns False on failure."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrong_hash = "a" * 64
wrapper = VerifyingStreamWrapper(stream, wrong_hash)
list(wrapper)
result = wrapper.verify_silent()
assert result is False
def test_invalid_expected_hash_raises_error(self):
"""Test invalid expected hash raises error at construction."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
with pytest.raises(InvalidHashFormatError):
VerifyingStreamWrapper(stream, INVALID_HASH_TOO_SHORT)
def test_on_failure_callback(self):
"""Test on_failure callback is called on verification failure."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrong_hash = "a" * 64
callback_called = []
def on_failure(error):
callback_called.append(error)
wrapper = VerifyingStreamWrapper(stream, wrong_hash, on_failure=on_failure)
list(wrapper)
with pytest.raises(ChecksumMismatchError):
wrapper.verify()
assert len(callback_called) == 1
assert isinstance(callback_called[0], ChecksumMismatchError)
def test_get_actual_hash_after_iteration(self):
"""Test get_actual_hash returns hash after iteration."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)
# Before iteration
assert wrapper.get_actual_hash() is None
list(wrapper)
# After iteration
assert wrapper.get_actual_hash() == TEST_HASH_HELLO
def test_includes_context_in_error(self):
"""Test error includes artifact_id and s3_key."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrong_hash = "a" * 64
wrapper = VerifyingStreamWrapper(
stream,
wrong_hash,
artifact_id="test-artifact",
s3_key="test/key",
)
list(wrapper)
with pytest.raises(ChecksumMismatchError) as exc_info:
wrapper.verify()
error = exc_info.value
assert error.artifact_id == "test-artifact"
assert error.s3_key == "test/key"
# =============================================================================
# Unit Tests - ChecksumMismatchError
# =============================================================================
class TestChecksumMismatchError:
"""Tests for ChecksumMismatchError class."""
def test_to_dict(self):
"""Test to_dict returns proper dictionary."""
error = ChecksumMismatchError(
expected="a" * 64,
actual="b" * 64,
artifact_id="test-123",
s3_key="test/key",
size=1024,
)
result = error.to_dict()
assert result["error"] == "checksum_mismatch"
assert result["expected"] == "a" * 64
assert result["actual"] == "b" * 64
assert result["artifact_id"] == "test-123"
assert result["s3_key"] == "test/key"
assert result["size"] == 1024
def test_message_format(self):
"""Test error message format."""
error = ChecksumMismatchError(
expected="a" * 64,
actual="b" * 64,
)
assert "verification failed" in str(error).lower()
assert "expected" in str(error).lower()
def test_custom_message(self):
"""Test custom message is used."""
error = ChecksumMismatchError(
expected="a" * 64,
actual="b" * 64,
message="Custom error message",
)
assert str(error) == "Custom error message"
# =============================================================================
# Corruption Simulation Tests
# =============================================================================
class TestCorruptionDetection:
"""Tests for detecting corrupted content."""
def test_detect_truncated_content(self):
"""Test detection of truncated content."""
original = TEST_CONTENT_HELLO
truncated = original[:-1] # Remove last byte
original_hash = compute_sha256(original)
truncated_hash = compute_sha256(truncated)
assert original_hash != truncated_hash
assert verify_checksum(truncated, original_hash) is False
def test_detect_extra_bytes(self):
"""Test detection of content with extra bytes."""
original = TEST_CONTENT_HELLO
extended = original + b"\x00" # Add null byte
original_hash = compute_sha256(original)
assert verify_checksum(extended, original_hash) is False
def test_detect_single_bit_flip(self):
"""Test detection of single bit flip."""
original = TEST_CONTENT_HELLO
# Flip first bit of first byte
corrupted = bytes([original[0] ^ 0x01]) + original[1:]
original_hash = compute_sha256(original)
assert verify_checksum(corrupted, original_hash) is False
def test_detect_wrong_content(self):
"""Test detection of completely different content."""
original = TEST_CONTENT_HELLO
different = b"Something completely different"
original_hash = compute_sha256(original)
assert verify_checksum(different, original_hash) is False
def test_detect_empty_vs_nonempty(self):
"""Test detection of empty content vs non-empty."""
original = TEST_CONTENT_HELLO
empty = b""
original_hash = compute_sha256(original)
assert verify_checksum(empty, original_hash) is False
def test_streaming_detection_of_corruption(self):
"""Test VerifyingStreamWrapper detects corruption."""
original = b"Original content that will be corrupted"
original_hash = compute_sha256(original)
# Corrupt the content
corrupted = b"Corrupted content that is different"
stream = io.BytesIO(corrupted)
wrapper = VerifyingStreamWrapper(stream, original_hash)
list(wrapper) # Consume
with pytest.raises(ChecksumMismatchError):
wrapper.verify()
# =============================================================================
# Edge Case Tests
# =============================================================================
class TestEdgeCases:
"""Tests for edge cases and boundary conditions."""
def test_null_bytes_in_content(self):
"""Test content with null bytes."""
content = b"\x00\x00\x00"
hash_result = compute_sha256(content)
assert verify_checksum(content, hash_result) is True
def test_whitespace_only_content(self):
"""Test content with only whitespace."""
content = b" \t\n\r "
hash_result = compute_sha256(content)
assert verify_checksum(content, hash_result) is True
def test_large_content_streaming(self):
"""Test streaming verification of large content."""
# 1MB of content
large_content = b"x" * (1024 * 1024)
expected_hash = compute_sha256(large_content)
stream = io.BytesIO(large_content)
wrapper = VerifyingStreamWrapper(stream, expected_hash)
# Consume and verify
chunks = list(wrapper)
assert wrapper.verify() is True
assert b"".join(chunks) == large_content
def test_unicode_bytes_content(self):
"""Test content with unicode bytes."""
content = "Hello, 世界! 🌍".encode("utf-8")
hash_result = compute_sha256(content)
assert verify_checksum(content, hash_result) is True
def test_maximum_chunk_size_larger_than_content(self):
"""Test chunk size larger than content."""
content = b"small"
stream = io.BytesIO(content)
wrapper = HashingStreamWrapper(stream, chunk_size=1024 * 1024)
chunks = list(wrapper)
assert len(chunks) == 1
assert chunks[0] == content
assert wrapper.get_hash() == compute_sha256(content)


@@ -0,0 +1,460 @@
"""
Integration tests for download verification API endpoints.
These tests verify:
- Checksum headers in download responses
- Pre-verification mode
- Streaming verification mode
- HEAD request headers
- Verification failure handling
"""
import pytest
import hashlib
import base64
import io
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def upload_test_file(integration_client):
"""
Factory fixture to upload a test file and return its artifact ID.
Usage:
artifact_id = upload_test_file(project, package, content, tag="v1.0")
"""
def _upload(project_name: str, package_name: str, content: bytes, tag: str = None):
files = {
"file": ("test-file.bin", io.BytesIO(content), "application/octet-stream")
}
data = {}
if tag:
data["tag"] = tag
response = integration_client.post(
f"/api/v1/project/{project_name}/{package_name}/upload",
files=files,
data=data,
)
assert response.status_code == 200, f"Upload failed: {response.text}"
return response.json()["artifact_id"]
return _upload
# =============================================================================
# Integration Tests - Download Headers
# =============================================================================
class TestDownloadChecksumHeaders:
"""Tests for checksum headers in download responses."""
@pytest.mark.integration
def test_download_includes_sha256_header(
self, integration_client, test_package, upload_test_file
):
"""Test download response includes X-Checksum-SHA256 header."""
project_name, package_name = test_package
content = b"Content for SHA256 header test"
# Upload file
artifact_id = upload_test_file(
project_name, package_name, content, tag="sha256-header-test"
)
# Download with proxy mode
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/sha256-header-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
assert "X-Checksum-SHA256" in response.headers
assert response.headers["X-Checksum-SHA256"] == artifact_id
@pytest.mark.integration
def test_download_includes_etag_header(
self, integration_client, test_package, upload_test_file
):
"""Test download response includes ETag header."""
project_name, package_name = test_package
content = b"Content for ETag header test"
artifact_id = upload_test_file(
project_name, package_name, content, tag="etag-test"
)
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/etag-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
assert "ETag" in response.headers
# ETag should be quoted artifact ID
assert response.headers["ETag"] == f'"{artifact_id}"'
@pytest.mark.integration
def test_download_includes_digest_header(
self, integration_client, test_package, upload_test_file
):
"""Test download response includes RFC 3230 Digest header."""
project_name, package_name = test_package
content = b"Content for Digest header test"
sha256 = hashlib.sha256(content).hexdigest()
upload_test_file(project_name, package_name, content, tag="digest-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/digest-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
assert "Digest" in response.headers
# Verify Digest format: sha-256=<base64>
digest = response.headers["Digest"]
assert digest.startswith("sha-256=")
# Verify base64 content matches
b64_hash = digest.split("=", 1)[1]
decoded = base64.b64decode(b64_hash)
assert decoded == bytes.fromhex(sha256)
@pytest.mark.integration
def test_download_includes_content_length_header(
self, integration_client, test_package, upload_test_file
):
"""Test download response includes X-Content-Length header."""
project_name, package_name = test_package
content = b"Content for X-Content-Length test"
upload_test_file(project_name, package_name, content, tag="content-length-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/content-length-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
assert "X-Content-Length" in response.headers
assert response.headers["X-Content-Length"] == str(len(content))
@pytest.mark.integration
def test_download_includes_verified_header_false(
self, integration_client, test_package, upload_test_file
):
"""Test download without verification has X-Verified: false."""
project_name, package_name = test_package
content = b"Content for X-Verified false test"
upload_test_file(project_name, package_name, content, tag="verified-false-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/verified-false-test",
params={"mode": "proxy", "verify": "false"},
)
assert response.status_code == 200
assert "X-Verified" in response.headers
assert response.headers["X-Verified"] == "false"
# =============================================================================
# Integration Tests - Pre-Verification Mode
# =============================================================================
class TestPreVerificationMode:
"""Tests for pre-verification download mode."""
@pytest.mark.integration
def test_pre_verify_success(
self, integration_client, test_package, upload_test_file
):
"""Test pre-verification mode succeeds for valid content."""
project_name, package_name = test_package
content = b"Content for pre-verification success test"
upload_test_file(project_name, package_name, content, tag="pre-verify-success")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/pre-verify-success",
params={"mode": "proxy", "verify": "true", "verify_mode": "pre"},
)
assert response.status_code == 200
assert response.content == content
assert "X-Verified" in response.headers
assert response.headers["X-Verified"] == "true"
@pytest.mark.integration
def test_pre_verify_returns_complete_content(
self, integration_client, test_package, upload_test_file
):
"""Test pre-verification returns complete content."""
project_name, package_name = test_package
# Use binary content to verify no corruption
content = bytes(range(256)) * 10 # 2560 bytes of all byte values
upload_test_file(project_name, package_name, content, tag="pre-verify-content")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/pre-verify-content",
params={"mode": "proxy", "verify": "true", "verify_mode": "pre"},
)
assert response.status_code == 200
assert response.content == content
# =============================================================================
# Integration Tests - Streaming Verification Mode
# =============================================================================
class TestStreamingVerificationMode:
"""Tests for streaming verification download mode."""
@pytest.mark.integration
def test_stream_verify_success(
self, integration_client, test_package, upload_test_file
):
"""Test streaming verification mode succeeds for valid content."""
project_name, package_name = test_package
content = b"Content for streaming verification success test"
upload_test_file(
project_name, package_name, content, tag="stream-verify-success"
)
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/stream-verify-success",
params={"mode": "proxy", "verify": "true", "verify_mode": "stream"},
)
assert response.status_code == 200
assert response.content == content
# X-Verified is "pending" for streaming mode (verified after transfer)
assert "X-Verified" in response.headers
@pytest.mark.integration
def test_stream_verify_large_content(
self, integration_client, test_package, upload_test_file
):
"""Test streaming verification with larger content."""
project_name, package_name = test_package
# 100KB of content
content = b"x" * (100 * 1024)
upload_test_file(project_name, package_name, content, tag="stream-verify-large")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/stream-verify-large",
params={"mode": "proxy", "verify": "true", "verify_mode": "stream"},
)
assert response.status_code == 200
assert response.content == content
# =============================================================================
# Integration Tests - HEAD Request Headers
# =============================================================================
class TestHeadRequestHeaders:
"""Tests for HEAD request checksum headers."""
@pytest.mark.integration
def test_head_includes_sha256_header(
self, integration_client, test_package, upload_test_file
):
"""Test HEAD request includes X-Checksum-SHA256 header."""
project_name, package_name = test_package
content = b"Content for HEAD SHA256 test"
artifact_id = upload_test_file(
project_name, package_name, content, tag="head-sha256-test"
)
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-sha256-test"
)
assert response.status_code == 200
assert "X-Checksum-SHA256" in response.headers
assert response.headers["X-Checksum-SHA256"] == artifact_id
@pytest.mark.integration
def test_head_includes_etag(
self, integration_client, test_package, upload_test_file
):
"""Test HEAD request includes ETag header."""
project_name, package_name = test_package
content = b"Content for HEAD ETag test"
artifact_id = upload_test_file(
project_name, package_name, content, tag="head-etag-test"
)
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-etag-test"
)
assert response.status_code == 200
assert "ETag" in response.headers
assert response.headers["ETag"] == f'"{artifact_id}"'
@pytest.mark.integration
def test_head_includes_digest(
self, integration_client, test_package, upload_test_file
):
"""Test HEAD request includes Digest header."""
project_name, package_name = test_package
content = b"Content for HEAD Digest test"
upload_test_file(project_name, package_name, content, tag="head-digest-test")
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-digest-test"
)
assert response.status_code == 200
assert "Digest" in response.headers
assert response.headers["Digest"].startswith("sha-256=")
@pytest.mark.integration
def test_head_includes_content_length(
self, integration_client, test_package, upload_test_file
):
"""Test HEAD request includes X-Content-Length header."""
project_name, package_name = test_package
content = b"Content for HEAD Content-Length test"
upload_test_file(project_name, package_name, content, tag="head-length-test")
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-length-test"
)
assert response.status_code == 200
assert "X-Content-Length" in response.headers
assert response.headers["X-Content-Length"] == str(len(content))
@pytest.mark.integration
def test_head_no_body(self, integration_client, test_package, upload_test_file):
"""Test HEAD request returns no body."""
project_name, package_name = test_package
content = b"Content for HEAD no-body test"
upload_test_file(project_name, package_name, content, tag="head-no-body-test")
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-no-body-test"
)
assert response.status_code == 200
assert response.content == b""
# =============================================================================
# Integration Tests - Range Requests
# =============================================================================
class TestRangeRequestHeaders:
"""Tests for range request handling with checksum headers."""
@pytest.mark.integration
def test_range_request_includes_checksum_headers(
self, integration_client, test_package, upload_test_file
):
"""Test range request includes checksum headers."""
project_name, package_name = test_package
content = b"Content for range request checksum header test"
upload_test_file(project_name, package_name, content, tag="range-checksum-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/range-checksum-test",
headers={"Range": "bytes=0-9"},
params={"mode": "proxy"},
)
assert response.status_code == 206
assert "X-Checksum-SHA256" in response.headers
# Checksum is for the FULL file, not the range
assert len(response.headers["X-Checksum-SHA256"]) == 64
# =============================================================================
# Integration Tests - Client-Side Verification
# =============================================================================
class TestClientSideVerification:
"""Tests demonstrating client-side verification using headers."""
@pytest.mark.integration
def test_client_can_verify_downloaded_content(
self, integration_client, test_package, upload_test_file
):
"""Test client can verify downloaded content using header."""
project_name, package_name = test_package
content = b"Content for client-side verification test"
upload_test_file(project_name, package_name, content, tag="client-verify-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/client-verify-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
# Get expected hash from header
expected_hash = response.headers["X-Checksum-SHA256"]
# Compute actual hash of downloaded content
actual_hash = hashlib.sha256(response.content).hexdigest()
# Verify match
assert actual_hash == expected_hash
@pytest.mark.integration
def test_client_can_verify_using_digest_header(
self, integration_client, test_package, upload_test_file
):
"""Test client can verify using RFC 3230 Digest header."""
project_name, package_name = test_package
content = b"Content for Digest header verification"
upload_test_file(project_name, package_name, content, tag="digest-verify-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/digest-verify-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
# Parse Digest header
digest_header = response.headers["Digest"]
assert digest_header.startswith("sha-256=")
b64_hash = digest_header.split("=", 1)[1]
expected_hash_bytes = base64.b64decode(b64_hash)
# Compute actual hash of downloaded content
actual_hash_bytes = hashlib.sha256(response.content).digest()
# Verify match
assert actual_hash_bytes == expected_hash_bytes