Add download verification with SHA256 checksum support (#26, #27, #28, #29)

Mondo Diaz
2026-01-07 13:36:46 -06:00
parent 08dce6cbb8
commit 35fda65d38
8 changed files with 2157 additions and 12 deletions

backend/app/checksum.py (new file, +477 lines)

@@ -0,0 +1,477 @@
"""
Checksum utilities for download verification.
This module provides functions and classes for computing and verifying
SHA256 checksums during artifact downloads.
Key components:
- compute_sha256(): Compute SHA256 of bytes content
- compute_sha256_stream(): Compute SHA256 from an iterable stream
- HashingStreamWrapper: Wrapper that computes hash while streaming
- VerifyingStreamWrapper: Wrapper that verifies hash after streaming
- verify_checksum(): Verify content against expected hash
- ChecksumMismatchError: Exception for verification failures
"""
import hashlib
import logging
import re
import base64
from typing import (
Generator,
Optional,
Any,
Callable,
)
logger = logging.getLogger(__name__)
# Default chunk size for streaming operations (8KB)
DEFAULT_CHUNK_SIZE = 8 * 1024
# Regex pattern for valid SHA256 hash (64 hex characters)
SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}$")
class ChecksumError(Exception):
"""Base exception for checksum operations."""
pass
class ChecksumMismatchError(ChecksumError):
"""
Raised when computed checksum does not match expected checksum.
Attributes:
expected: The expected SHA256 hash
actual: The actual computed SHA256 hash
artifact_id: Optional artifact ID for context
s3_key: Optional S3 key for debugging
size: Optional file size
"""
def __init__(
self,
expected: str,
actual: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
size: Optional[int] = None,
message: Optional[str] = None,
):
self.expected = expected
self.actual = actual
self.artifact_id = artifact_id
self.s3_key = s3_key
self.size = size
if message:
self.message = message
else:
self.message = (
f"Checksum verification failed: "
f"expected {expected[:16]}..., got {actual[:16]}..."
)
super().__init__(self.message)
def to_dict(self) -> dict:
"""Convert to dictionary for logging/API responses."""
return {
"error": "checksum_mismatch",
"expected": self.expected,
"actual": self.actual,
"artifact_id": self.artifact_id,
"s3_key": self.s3_key,
"size": self.size,
"message": self.message,
}
class InvalidHashFormatError(ChecksumError):
"""Raised when a hash string is not valid SHA256 format."""
def __init__(self, hash_value: str):
self.hash_value = hash_value
message = f"Invalid SHA256 hash format: '{hash_value[:32]}...'"
super().__init__(message)
def is_valid_sha256(hash_value: str) -> bool:
"""
Check if a string is a valid SHA256 hash (64 hex characters).
Args:
hash_value: String to validate
Returns:
True if valid SHA256 format, False otherwise
"""
if not hash_value:
return False
return bool(SHA256_PATTERN.match(hash_value))
def compute_sha256(content: bytes) -> str:
"""
Compute SHA256 hash of bytes content.
Args:
content: Bytes content to hash
Returns:
Lowercase hexadecimal SHA256 hash (64 characters)
Raises:
ChecksumError: If hash computation fails
"""
if content is None:
raise ChecksumError("Cannot compute hash of None content")
try:
return hashlib.sha256(content).hexdigest().lower()
except Exception as e:
raise ChecksumError(f"Hash computation failed: {e}") from e
def compute_sha256_stream(
stream: Any,
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> str:
"""
Compute SHA256 hash from a stream or file-like object.
Reads the stream in chunks to minimize memory usage for large files.
Args:
stream: Iterator yielding bytes or file-like object with read()
chunk_size: Size of chunks to read (default 8KB)
Returns:
Lowercase hexadecimal SHA256 hash (64 characters)
Raises:
ChecksumError: If hash computation fails
"""
try:
hasher = hashlib.sha256()
# Handle file-like objects with read()
if hasattr(stream, "read"):
while True:
chunk = stream.read(chunk_size)
if not chunk:
break
hasher.update(chunk)
else:
# Handle iterators
for chunk in stream:
if chunk:
hasher.update(chunk)
return hasher.hexdigest().lower()
except Exception as e:
raise ChecksumError(f"Stream hash computation failed: {e}") from e
def verify_checksum(content: bytes, expected: str) -> bool:
"""
Verify that content matches expected SHA256 hash.
Args:
content: Bytes content to verify
expected: Expected SHA256 hash (case-insensitive)
Returns:
True if hash matches, False otherwise
Raises:
InvalidHashFormatError: If expected hash is not valid format
ChecksumError: If hash computation fails
"""
if not is_valid_sha256(expected):
raise InvalidHashFormatError(expected)
actual = compute_sha256(content)
return actual == expected.lower()
def verify_checksum_strict(
content: bytes,
expected: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
) -> None:
"""
Verify content matches expected hash, raising exception on mismatch.
Args:
content: Bytes content to verify
expected: Expected SHA256 hash (case-insensitive)
artifact_id: Optional artifact ID for error context
s3_key: Optional S3 key for error context
Raises:
InvalidHashFormatError: If expected hash is not valid format
ChecksumMismatchError: If verification fails
ChecksumError: If hash computation fails
"""
if not is_valid_sha256(expected):
raise InvalidHashFormatError(expected)
actual = compute_sha256(content)
if actual != expected.lower():
raise ChecksumMismatchError(
expected=expected.lower(),
actual=actual,
artifact_id=artifact_id,
s3_key=s3_key,
size=len(content),
)
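A short sketch contrasting the boolean and strict variants; the first hash is the well-known SHA256 of b"hello world", the second is deliberately wrong:

from app.checksum import (
    ChecksumMismatchError,
    verify_checksum,
    verify_checksum_strict,
)

data = b"hello world"
good = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"

verify_checksum(data, good)            # True
verify_checksum(data, good.upper())    # True (comparison is case-insensitive)

try:
    verify_checksum_strict(data, "0" * 64, artifact_id="demo")
except ChecksumMismatchError as err:
    print(err.to_dict()["message"])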
def sha256_to_base64(hex_hash: str) -> str:
"""
Convert SHA256 hex string to base64 encoding (for RFC 3230 Digest header).
Args:
hex_hash: SHA256 hash as 64-character hex string
Returns:
Base64-encoded hash string
"""
if not is_valid_sha256(hex_hash):
raise InvalidHashFormatError(hex_hash)
hash_bytes = bytes.fromhex(hex_hash)
return base64.b64encode(hash_bytes).decode("ascii")
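For illustration, the SHA256 of empty content in both representations (these are well-known constant values):

from app.checksum import sha256_to_base64

empty_hex = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
print(sha256_to_base64(empty_hex))
# 47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=
# sent as: Digest: sha-256=47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=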
class HashingStreamWrapper:
"""
Wrapper that computes SHA256 hash incrementally as chunks are read.
This allows computing the hash while streaming content to a client,
without buffering the entire content in memory.
Usage:
wrapper = HashingStreamWrapper(stream)
for chunk in wrapper:
send_to_client(chunk)
final_hash = wrapper.get_hash()
Attributes:
chunk_size: Size of chunks to yield
bytes_read: Total bytes processed so far
"""
def __init__(
self,
stream: Any,
chunk_size: int = DEFAULT_CHUNK_SIZE,
):
"""
Initialize the hashing stream wrapper.
Args:
stream: Source stream (iterator, file-like, or S3 StreamingBody)
chunk_size: Size of chunks to yield (default 8KB)
"""
self._stream = stream
self._hasher = hashlib.sha256()
self._chunk_size = chunk_size
self._bytes_read = 0
self._finalized = False
self._final_hash: Optional[str] = None
@property
def bytes_read(self) -> int:
"""Total bytes read so far."""
return self._bytes_read
@property
def chunk_size(self) -> int:
"""Chunk size for reading."""
return self._chunk_size
def __iter__(self) -> Generator[bytes, None, None]:
"""Iterate over chunks, computing hash as we go."""
# Handle S3 StreamingBody (has iter_chunks)
if hasattr(self._stream, "iter_chunks"):
for chunk in self._stream.iter_chunks(chunk_size=self._chunk_size):
if chunk:
self._hasher.update(chunk)
self._bytes_read += len(chunk)
yield chunk
# Handle file-like objects with read()
elif hasattr(self._stream, "read"):
while True:
chunk = self._stream.read(self._chunk_size)
if not chunk:
break
self._hasher.update(chunk)
self._bytes_read += len(chunk)
yield chunk
# Handle iterators
else:
for chunk in self._stream:
if chunk:
self._hasher.update(chunk)
self._bytes_read += len(chunk)
yield chunk
self._finalized = True
self._final_hash = self._hasher.hexdigest().lower()
def get_hash(self) -> str:
"""
Get the computed SHA256 hash.
If stream hasn't been fully consumed, consumes remaining chunks.
Returns:
Lowercase hexadecimal SHA256 hash
"""
if not self._finalized:
# Consume remaining stream
for _ in self:
pass
return self._final_hash or self._hasher.hexdigest().lower()
def get_hash_if_complete(self) -> Optional[str]:
"""
Get hash only if stream has been fully consumed.
Returns:
Hash if complete, None otherwise
"""
if self._finalized:
return self._final_hash
return None
class VerifyingStreamWrapper:
"""
Wrapper that yields chunks and verifies hash after streaming completes.
IMPORTANT: Because HTTP streams cannot be "un-sent", if verification
fails after streaming, the client has already received potentially
corrupt data. This wrapper logs an error but cannot prevent delivery.
For guaranteed verification before delivery, use pre-verification mode
which buffers the entire content first.
Usage:
wrapper = VerifyingStreamWrapper(stream, expected_hash)
for chunk in wrapper:
send_to_client(chunk)
wrapper.verify() # Raises ChecksumMismatchError if failed
"""
def __init__(
self,
stream: Any,
expected_hash: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
on_failure: Optional[Callable[[Any], None]] = None,
):
"""
Initialize the verifying stream wrapper.
Args:
stream: Source stream
expected_hash: Expected SHA256 hash to verify against
artifact_id: Optional artifact ID for error context
s3_key: Optional S3 key for error context
chunk_size: Size of chunks to yield
on_failure: Optional callback called on verification failure
"""
if not is_valid_sha256(expected_hash):
raise InvalidHashFormatError(expected_hash)
self._hashing_wrapper = HashingStreamWrapper(stream, chunk_size)
self._expected_hash = expected_hash.lower()
self._artifact_id = artifact_id
self._s3_key = s3_key
self._on_failure = on_failure
self._verified: Optional[bool] = None
@property
def bytes_read(self) -> int:
"""Total bytes read so far."""
return self._hashing_wrapper.bytes_read
@property
def is_verified(self) -> Optional[bool]:
"""
Verification status.
Returns:
True if verified successfully, False if failed, None if not yet complete
"""
return self._verified
def __iter__(self) -> Generator[bytes, None, None]:
"""Iterate over chunks."""
yield from self._hashing_wrapper
def verify(self) -> bool:
"""
Verify the hash after stream is complete.
Must be called after fully consuming the iterator.
Returns:
True if verification passed
Raises:
ChecksumMismatchError: If verification failed
"""
actual_hash = self._hashing_wrapper.get_hash()
if actual_hash == self._expected_hash:
self._verified = True
logger.debug(
f"Verification passed for {self._artifact_id or 'unknown'}: {actual_hash[:16]}..."
)
return True
self._verified = False
error = ChecksumMismatchError(
expected=self._expected_hash,
actual=actual_hash,
artifact_id=self._artifact_id,
s3_key=self._s3_key,
size=self._hashing_wrapper.bytes_read,
)
# Log the failure
logger.error(f"Checksum verification FAILED after streaming: {error.to_dict()}")
# Call failure callback if provided
if self._on_failure:
try:
self._on_failure(error)
except Exception as e:
logger.warning(f"Verification failure callback raised exception: {e}")
raise error
def verify_silent(self) -> bool:
"""
Verify the hash without raising exception.
Returns:
True if verification passed, False otherwise
"""
try:
return self.verify()
except ChecksumMismatchError:
return False
def get_actual_hash(self) -> Optional[str]:
"""Get the actual computed hash (only available after iteration)."""
return self._hashing_wrapper.get_hash_if_complete()
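A minimal sketch of the wrapper with the on_failure callback, using a deliberately wrong expected hash so the failure path is exercised:

from app.checksum import ChecksumMismatchError, VerifyingStreamWrapper

def record_failure(err: ChecksumMismatchError) -> None:
    print("verification failed:", err.to_dict())

wrapper = VerifyingStreamWrapper(
    iter([b"corrupted chunk"]),
    expected_hash="0" * 64,          # wrong on purpose
    artifact_id="demo",
    on_failure=record_failure,
)
for chunk in wrapper:
    pass                             # chunks would be sent to the client here
print(wrapper.verify_silent())       # False; record_failure was invoked, nothing raised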


@@ -48,6 +48,10 @@ class Settings(BaseSettings):
3600 # Presigned URL expiry in seconds (default: 1 hour)
)
# Logging settings
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
log_format: str = "auto" # "json", "standard", or "auto" (json in production)
@property
def database_url(self) -> str:
sslmode = f"?sslmode={self.database_sslmode}" if self.database_sslmode else ""


@@ -0,0 +1,254 @@
"""
Structured logging configuration for Orchard.
This module provides:
- Structured JSON logging for production environments
- Request tracing via X-Request-ID header
- Verification failure logging with context
- Configurable log levels via environment
Usage:
from app.logging_config import setup_logging, get_request_id
setup_logging() # Call once at app startup
request_id = get_request_id() # Get current request's ID
"""
import logging
import json
import sys
import uuid
from datetime import datetime, timezone
from typing import Optional, Any, Dict
from contextvars import ContextVar
from .config import get_settings
# Context variable for request ID (thread-safe)
_request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
def get_request_id() -> Optional[str]:
"""Get the current request's ID from context."""
return _request_id_var.get()
def set_request_id(request_id: Optional[str] = None) -> str:
"""
Set the request ID for the current context.
If no ID provided, generates a new UUID.
Returns the request ID that was set.
"""
if request_id is None:
request_id = str(uuid.uuid4())
_request_id_var.set(request_id)
return request_id
def clear_request_id():
"""Clear the request ID from context."""
_request_id_var.set(None)
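A minimal sketch of wiring these helpers into request handling; the FastAPI app object and middleware registration are assumptions, not part of this module:

from fastapi import FastAPI, Request
from app.logging_config import clear_request_id, set_request_id

app = FastAPI()

@app.middleware("http")
async def request_id_middleware(request: Request, call_next):
    # Reuse the caller's X-Request-ID if present, otherwise generate one
    req_id = set_request_id(request.headers.get("X-Request-ID"))
    try:
        response = await call_next(request)
        response.headers["X-Request-ID"] = req_id
        return response
    finally:
        clear_request_id()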
class JSONFormatter(logging.Formatter):
"""
JSON log formatter for structured logging.
Output format:
{
"timestamp": "2025-01-01T00:00:00.000Z",
"level": "INFO",
"logger": "app.routes",
"message": "Request completed",
"request_id": "abc-123",
"extra": {...}
}
"""
def format(self, record: logging.LogRecord) -> str:
log_entry: Dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
# Add request ID if available
request_id = get_request_id()
if request_id:
log_entry["request_id"] = request_id
# Add exception info if present
if record.exc_info:
log_entry["exception"] = self.formatException(record.exc_info)
# Add extra fields from record
extra_fields: Dict[str, Any] = {}
for key, value in record.__dict__.items():
if key not in (
"name",
"msg",
"args",
"created",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"exc_info",
"exc_text",
"thread",
"threadName",
"message",
"asctime",
):
try:
json.dumps(value) # Ensure serializable
extra_fields[key] = value
except (TypeError, ValueError):
extra_fields[key] = str(value)
if extra_fields:
log_entry["extra"] = extra_fields
return json.dumps(log_entry)
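For illustration, a log call with extra fields would be rendered by this formatter roughly as follows (timestamp shortened; request_id appears only when set):

import logging

logging.getLogger("app.routes").info(
    "Request completed",
    extra={"artifact_id": "abc123", "duration_ms": 12.5},
)
# {"timestamp": "2025-01-01T00:00:00+00:00", "level": "INFO", "logger": "app.routes",
#  "message": "Request completed", "extra": {"artifact_id": "abc123", "duration_ms": 12.5}}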
class StandardFormatter(logging.Formatter):
"""
Standard log formatter for development.
Output format:
[2025-01-01 00:00:00] INFO [app.routes] [req-abc123] Request completed
"""
def format(self, record: logging.LogRecord) -> str:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
request_id = get_request_id()
req_str = f" [req-{request_id[:8]}]" if request_id else ""
base_msg = f"[{timestamp}] {record.levelname:5} [{record.name}]{req_str} {record.getMessage()}"
if record.exc_info:
base_msg += "\n" + self.formatException(record.exc_info)
return base_msg
def setup_logging(log_level: Optional[str] = None, json_format: Optional[bool] = None):
"""
Configure logging for the application.
Args:
log_level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL).
Defaults to ORCHARD_LOG_LEVEL env var or INFO.
json_format: Use JSON format. Defaults to True in production.
"""
settings = get_settings()
# Determine log level
if log_level is None:
log_level = getattr(settings, "log_level", "INFO")
effective_level = log_level if log_level else "INFO"
level = getattr(logging, effective_level.upper(), logging.INFO)
# Determine format
if json_format is None:
json_format = settings.is_production
# Create handler
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(level)
# Set formatter
if json_format:
handler.setFormatter(JSONFormatter())
else:
handler.setFormatter(StandardFormatter())
# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(level)
# Remove existing handlers
root_logger.handlers.clear()
root_logger.addHandler(handler)
# Configure specific loggers
for logger_name in ["app", "uvicorn", "uvicorn.access", "uvicorn.error"]:
logger = logging.getLogger(logger_name)
logger.setLevel(level)
logger.handlers.clear()
logger.addHandler(handler)
logger.propagate = False
# Quiet down noisy loggers
logging.getLogger("botocore").setLevel(logging.WARNING)
logging.getLogger("boto3").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
def log_verification_failure(
logger: logging.Logger,
expected_hash: str,
actual_hash: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
project: Optional[str] = None,
package: Optional[str] = None,
size: Optional[int] = None,
user_id: Optional[str] = None,
source_ip: Optional[str] = None,
verification_mode: Optional[str] = None,
):
"""
Log a verification failure with full context.
This creates a structured log entry with all relevant details
for debugging and alerting.
"""
logger.error(
"Checksum verification failed",
extra={
"event": "verification_failure",
"expected_hash": expected_hash,
"actual_hash": actual_hash,
"artifact_id": artifact_id,
"s3_key": s3_key,
"project": project,
"package": package,
"size": size,
"user_id": user_id,
"source_ip": source_ip,
"verification_mode": verification_mode,
"hash_match": expected_hash == actual_hash,
},
)
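A call-site sketch that feeds the exception's own fields into the structured log entry (the content, hash, and key are placeholders):

import logging
from app.checksum import ChecksumMismatchError, verify_checksum_strict

log = logging.getLogger("app.storage")
try:
    verify_checksum_strict(b"payload", "0" * 64, artifact_id="demo", s3_key="demo/key")
except ChecksumMismatchError as err:
    log_verification_failure(
        log,
        expected_hash=err.expected,
        actual_hash=err.actual,
        artifact_id=err.artifact_id,
        s3_key=err.s3_key,
        size=err.size,
        verification_mode="pre",
    )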
def log_verification_success(
logger: logging.Logger,
artifact_id: str,
size: Optional[int] = None,
verification_mode: Optional[str] = None,
duration_ms: Optional[float] = None,
):
"""Log a successful verification."""
logger.info(
f"Verification passed for artifact {artifact_id[:16]}...",
extra={
"event": "verification_success",
"artifact_id": artifact_id,
"size": size,
"verification_mode": verification_mode,
"duration_ms": duration_ms,
},
)


@@ -97,6 +97,11 @@ from .schemas import (
)
from .metadata import extract_metadata
from .config import get_settings
from .checksum import (
ChecksumMismatchError,
VerifyingStreamWrapper,
sha256_to_base64,
)
router = APIRouter()
@@ -1777,7 +1782,7 @@ def _resolve_artifact_ref(
return artifact
# Download artifact with range request support, download modes, and verification
@router.get("/api/v1/project/{project_name}/{package_name}/+/{ref}")
def download_artifact(
project_name: str,
@@ -1791,7 +1796,34 @@ def download_artifact(
default=None,
description="Download mode: proxy (stream through backend), redirect (302 to presigned URL), presigned (return JSON with URL)",
),
verify: bool = Query(
default=False,
description="Enable checksum verification during download",
),
verify_mode: Optional[Literal["stream", "pre"]] = Query(
default="stream",
description="Verification mode: 'stream' (verify after streaming, logs error if mismatch), 'pre' (verify before streaming, returns 500 if mismatch)",
),
):
"""
Download an artifact by reference (tag name, artifact:hash, tag:name).
Verification modes:
- verify=false (default): No verification, maximum performance
- verify=true&verify_mode=stream: Compute hash while streaming, verify after completion.
If mismatch, logs error but content already sent.
- verify=true&verify_mode=pre: Download and verify BEFORE streaming to client.
Higher latency but guarantees no corrupt data sent.
Response headers always include:
- X-Checksum-SHA256: The expected SHA256 hash
- X-Content-Length: File size in bytes
- ETag: Artifact ID (SHA256)
- Digest: RFC 3230 format sha-256 hash
The X-Verified header reports the outcome:
- 'true' when pre-verification succeeded (verify=true&verify_mode=pre)
- 'pending' when verification happens after streaming; the result is only logged (verify=true&verify_mode=stream)
- 'false' when verification was not requested
- not set for range requests, since partial content cannot be verified
"""
settings = get_settings()
# Get project and package
@@ -1831,6 +1863,23 @@ def download_artifact(
)
db.commit()
# Build common checksum headers (always included)
checksum_headers = {
"X-Checksum-SHA256": artifact.id,
"X-Content-Length": str(artifact.size),
"ETag": f'"{artifact.id}"',
}
# Add RFC 3230 Digest header
try:
digest_base64 = sha256_to_base64(artifact.id)
checksum_headers["Digest"] = f"sha-256={digest_base64}"
except Exception:
pass # Skip if conversion fails
# Add MD5 checksum if available
if artifact.checksum_md5:
checksum_headers["X-Checksum-MD5"] = artifact.checksum_md5
# Determine download mode (query param overrides server default)
download_mode = mode or settings.download_mode
@@ -1867,7 +1916,7 @@ def download_artifact(
return RedirectResponse(url=presigned_url, status_code=302)
# Proxy mode (default fallback) - stream through backend
# Handle range requests (verification not supported for partial downloads)
if range:
stream, content_length, content_range = storage.get_stream(
artifact.s3_key, range
@@ -1877,9 +1926,11 @@ def download_artifact(
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
"Content-Length": str(content_length),
**checksum_headers,
}
if content_range:
headers["Content-Range"] = content_range
# Note: X-Verified not set for range requests (cannot verify partial content)
return StreamingResponse(
stream,
@@ -1888,16 +1939,88 @@ def download_artifact(
headers=headers,
)
# Full download with optional verification
base_headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
**checksum_headers,
}
# Pre-verification mode: verify before streaming
if verify and verify_mode == "pre":
try:
content = storage.get_verified(artifact.s3_key, artifact.id)
return Response(
content=content,
media_type=artifact.content_type or "application/octet-stream",
headers={
**base_headers,
"Content-Length": str(len(content)),
"X-Verified": "true",
},
)
except ChecksumMismatchError as e:
logger.error(
f"Pre-verification failed for artifact {artifact.id[:16]}...: {e.to_dict()}"
)
raise HTTPException(
status_code=500,
detail={
"error": "checksum_verification_failed",
"message": "Downloaded content does not match expected checksum",
"expected": e.expected,
"actual": e.actual,
"artifact_id": artifact.id,
},
)
# Streaming verification mode: verify while/after streaming
if verify and verify_mode == "stream":
verifying_wrapper, content_length, _ = storage.get_stream_verified(
artifact.s3_key, artifact.id
)
def verified_stream():
"""Generator that yields chunks and verifies after completion."""
try:
for chunk in verifying_wrapper:
yield chunk
# After all chunks yielded, verify
try:
verifying_wrapper.verify()
logger.info(
f"Streaming verification passed for artifact {artifact.id[:16]}..."
)
except ChecksumMismatchError as e:
# Content already sent - log error but cannot reject
logger.error(
f"Streaming verification FAILED for artifact {artifact.id[:16]}...: "
f"expected {e.expected[:16]}..., got {e.actual[:16]}..."
)
except Exception as e:
logger.error(f"Error during streaming download: {e}")
raise
return StreamingResponse(
verified_stream(),
media_type=artifact.content_type or "application/octet-stream",
headers={
**base_headers,
"Content-Length": str(content_length),
"X-Verified": "pending", # Verification happens after streaming
},
)
# No verification - direct streaming
stream, content_length, _ = storage.get_stream(artifact.s3_key)
return StreamingResponse(
stream,
media_type=artifact.content_type or "application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
**base_headers,
"Content-Length": str(content_length),
"X-Verified": "false",
},
)
@@ -1975,6 +2098,11 @@ def head_artifact(
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
):
"""
Get artifact metadata without downloading content.
Returns headers with checksum information for client-side verification.
"""
# Get project and package
project = db.query(Project).filter(Project.name == project_name).first()
if not project:
@@ -1995,15 +2123,32 @@ def head_artifact(
filename = sanitize_filename(artifact.original_name or f"{artifact.id}")
# Build headers with checksum information
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
"Content-Length": str(artifact.size),
"X-Artifact-Id": artifact.id,
"X-Checksum-SHA256": artifact.id,
"X-Content-Length": str(artifact.size),
"ETag": f'"{artifact.id}"',
}
# Add RFC 3230 Digest header
try:
digest_base64 = sha256_to_base64(artifact.id)
headers["Digest"] = f"sha-256={digest_base64}"
except Exception:
pass # Skip if conversion fails
# Add MD5 checksum if available
if artifact.checksum_md5:
headers["X-Checksum-MD5"] = artifact.checksum_md5
return Response(
content=b"",
media_type=artifact.content_type or "application/octet-stream",
headers=headers,
)
@@ -2017,9 +2162,19 @@ def download_artifact_compat(
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
range: Optional[str] = Header(None),
verify: bool = Query(default=False),
verify_mode: Optional[Literal["stream", "pre"]] = Query(default="stream"),
):
return download_artifact(
project_name,
package_name,
ref,
request,
db,
storage,
range,
verify=verify,
verify_mode=verify_mode,
)
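A client-side sketch of the new parameters and headers, assuming the requests library; the host, project, and package names are placeholders:

import hashlib
import requests

url = "https://orchard.example.com/api/v1/project/demo/pkg/+/latest"
resp = requests.get(url, params={"verify": "true", "verify_mode": "stream"}, stream=True)
resp.raise_for_status()

hasher = hashlib.sha256()
with open("artifact.bin", "wb") as out:
    for chunk in resp.iter_content(chunk_size=8192):
        hasher.update(chunk)
        out.write(chunk)

# Independently confirm the download against the advertised checksum
assert hasher.hexdigest() == resp.headers["X-Checksum-SHA256"].lower()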


@@ -22,6 +22,13 @@ from botocore.exceptions import (
)
from .config import get_settings
from .checksum import (
ChecksumMismatchError,
HashingStreamWrapper,
VerifyingStreamWrapper,
compute_sha256,
is_valid_sha256,
)
settings = get_settings()
logger = logging.getLogger(__name__)
@@ -876,6 +883,95 @@ class S3Storage:
logger.error(f"Unexpected error during storage health check: {e}")
return False
def get_verified(self, s3_key: str, expected_hash: str) -> bytes:
"""
Download and verify content matches expected SHA256 hash.
This method downloads the entire content, computes its hash, and
verifies it matches the expected hash before returning.
Args:
s3_key: The S3 storage key of the file
expected_hash: Expected SHA256 hash (64 hex characters)
Returns:
File content as bytes (only if verification passes)
Raises:
ChecksumMismatchError: If computed hash doesn't match expected
ClientError: If S3 operation fails
"""
if not is_valid_sha256(expected_hash):
raise ValueError(f"Invalid SHA256 hash format: {expected_hash}")
content = self.get(s3_key)
actual_hash = compute_sha256(content)
if actual_hash != expected_hash.lower():
raise ChecksumMismatchError(
expected=expected_hash.lower(),
actual=actual_hash,
s3_key=s3_key,
size=len(content),
)
logger.debug(f"Verification passed for {s3_key}: {actual_hash[:16]}...")
return content
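A call-site sketch for the buffering variant (the key and expected hash are placeholders; storage is an S3Storage instance obtained elsewhere):

expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
try:
    content = storage.get_verified("artifacts/demo-object", expected)
except ChecksumMismatchError as err:
    logger.error("refusing to serve corrupt object: %s", err.to_dict())
    raise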
def get_stream_verified(
self,
s3_key: str,
expected_hash: str,
range_header: Optional[str] = None,
) -> Tuple[VerifyingStreamWrapper, int, Optional[str]]:
"""
Get a verifying stream wrapper for an object.
Returns a wrapper that computes the hash as chunks are read and
can verify after streaming completes. Note that verification happens
AFTER content has been streamed to the client.
IMPORTANT: For range requests, verification is not supported because
we cannot verify a partial download against the full file hash.
Args:
s3_key: The S3 storage key of the file
expected_hash: Expected SHA256 hash (64 hex characters)
range_header: Optional HTTP Range header (verification disabled if set)
Returns:
Tuple of (VerifyingStreamWrapper, content_length, content_range)
The wrapper has a verify() method to call after streaming.
Raises:
ValueError: If expected_hash is invalid format
ClientError: If S3 operation fails
"""
if not is_valid_sha256(expected_hash):
raise ValueError(f"Invalid SHA256 hash format: {expected_hash}")
# Get the S3 stream
stream, content_length, content_range = self.get_stream(s3_key, range_header)
# For range requests, we cannot verify (partial content)
# Return a HashingStreamWrapper that just tracks bytes without verification
if range_header or content_range:
logger.debug(
f"Range request for {s3_key} - verification disabled (partial content)"
)
# Return a basic hashing wrapper (caller should not verify)
hashing_wrapper = HashingStreamWrapper(stream)
return hashing_wrapper, content_length, content_range
# Create verifying wrapper
verifying_wrapper = VerifyingStreamWrapper(
stream=stream,
expected_hash=expected_hash,
s3_key=s3_key,
)
return verifying_wrapper, content_length, content_range
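And a sketch of the streaming variant, reusing the placeholders above; verification only runs after the body has been fully produced:

wrapper, length, _ = storage.get_stream_verified("artifacts/demo-object", expected)
for chunk in wrapper:
    pass                      # chunks would be forwarded to the client here
if not wrapper.verify_silent():
    logger.error("object %s failed verification after delivery", "artifacts/demo-object")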
def verify_integrity(self, s3_key: str, expected_sha256: str) -> bool:
"""
Verify the integrity of a stored object by downloading and re-hashing.