Add download verification with SHA256 checksum support (#26, #27, #29)

- Add checksum.py module with SHA256 utilities and HashingStreamWrapper
- Add ChecksumMismatchError and VerifyingStreamWrapper for verification
- Add get_verified() and get_stream_verified() to storage layer
- Add verify query param and verify_mode (pre/stream) to download API
- Add checksum headers: X-Checksum-SHA256, Digest, ETag, X-Verified
- Add comprehensive unit tests for checksum utilities
- Add integration tests for download verification API
This commit is contained in:
Mondo Diaz
2026-01-07 12:35:55 -06:00
parent 08dce6cbb8
commit 9cb5cff526
5 changed files with 1890 additions and 12 deletions

513
backend/app/checksum.py Normal file
View File

@@ -0,0 +1,513 @@
"""
Checksum utilities for download verification.
This module provides functions and classes for computing and verifying
SHA256 checksums during artifact downloads.
Key components:
- compute_sha256(): Compute SHA256 of bytes content
- compute_sha256_stream(): Compute SHA256 from an iterable stream
- HashingStreamWrapper: Wrapper that computes hash while streaming
- VerifyingStreamWrapper: Wrapper that verifies hash after streaming
- verify_checksum(): Verify content against expected hash
- ChecksumMismatchError: Exception for verification failures
"""
import hashlib
import logging
import re
import base64
from typing import (
Iterator,
Generator,
Optional,
Tuple,
Union,
BinaryIO,
Any,
Callable,
)
logger = logging.getLogger(__name__)
# Default chunk size for streaming operations (8KB)
DEFAULT_CHUNK_SIZE = 8 * 1024
# Regex pattern for valid SHA256 hash (64 hex characters)
SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}$")
class ChecksumError(Exception):
"""Base exception for checksum operations."""
pass
class ChecksumMismatchError(ChecksumError):
"""
Raised when computed checksum does not match expected checksum.
Attributes:
expected: The expected SHA256 hash
actual: The actual computed SHA256 hash
artifact_id: Optional artifact ID for context
s3_key: Optional S3 key for debugging
size: Optional file size
"""
def __init__(
self,
expected: str,
actual: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
size: Optional[int] = None,
message: Optional[str] = None,
):
self.expected = expected
self.actual = actual
self.artifact_id = artifact_id
self.s3_key = s3_key
self.size = size
if message:
self.message = message
else:
self.message = (
f"Checksum verification failed: "
f"expected {expected[:16]}..., got {actual[:16]}..."
)
super().__init__(self.message)
def to_dict(self) -> dict:
"""Convert to dictionary for logging/API responses."""
return {
"error": "checksum_mismatch",
"expected": self.expected,
"actual": self.actual,
"artifact_id": self.artifact_id,
"s3_key": self.s3_key,
"size": self.size,
"message": self.message,
}
class InvalidHashFormatError(ChecksumError):
"""Raised when a hash string is not valid SHA256 format."""
def __init__(self, hash_value: str):
self.hash_value = hash_value
message = f"Invalid SHA256 hash format: '{hash_value[:32]}...'"
super().__init__(message)
def is_valid_sha256(hash_value: str) -> bool:
"""
Check if a string is a valid SHA256 hash (64 hex characters).
Args:
hash_value: String to validate
Returns:
True if valid SHA256 format, False otherwise
"""
if not hash_value:
return False
return bool(SHA256_PATTERN.match(hash_value))
def compute_sha256(content: bytes) -> str:
"""
Compute SHA256 hash of bytes content.
Args:
content: Bytes content to hash
Returns:
Lowercase hexadecimal SHA256 hash (64 characters)
Raises:
ChecksumError: If hash computation fails
"""
if content is None:
raise ChecksumError("Cannot compute hash of None content")
try:
return hashlib.sha256(content).hexdigest().lower()
except Exception as e:
raise ChecksumError(f"Hash computation failed: {e}") from e
def compute_sha256_stream(
stream: Any,
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> str:
"""
Compute SHA256 hash from a stream or file-like object.
Reads the stream in chunks to minimize memory usage for large files.
Args:
stream: Iterator yielding bytes or file-like object with read()
chunk_size: Size of chunks to read (default 8KB)
Returns:
Lowercase hexadecimal SHA256 hash (64 characters)
Raises:
ChecksumError: If hash computation fails
"""
try:
hasher = hashlib.sha256()
# Handle file-like objects with read()
if hasattr(stream, "read"):
while True:
chunk = stream.read(chunk_size)
if not chunk:
break
hasher.update(chunk)
else:
# Handle iterators
for chunk in stream:
if chunk:
hasher.update(chunk)
return hasher.hexdigest().lower()
except Exception as e:
raise ChecksumError(f"Stream hash computation failed: {e}") from e
def verify_checksum(content: bytes, expected: str) -> bool:
"""
Verify that content matches expected SHA256 hash.
Args:
content: Bytes content to verify
expected: Expected SHA256 hash (case-insensitive)
Returns:
True if hash matches, False otherwise
Raises:
InvalidHashFormatError: If expected hash is not valid format
ChecksumError: If hash computation fails
"""
if not is_valid_sha256(expected):
raise InvalidHashFormatError(expected)
actual = compute_sha256(content)
return actual == expected.lower()
def verify_checksum_strict(
content: bytes,
expected: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
) -> None:
"""
Verify content matches expected hash, raising exception on mismatch.
Args:
content: Bytes content to verify
expected: Expected SHA256 hash (case-insensitive)
artifact_id: Optional artifact ID for error context
s3_key: Optional S3 key for error context
Raises:
InvalidHashFormatError: If expected hash is not valid format
ChecksumMismatchError: If verification fails
ChecksumError: If hash computation fails
"""
if not is_valid_sha256(expected):
raise InvalidHashFormatError(expected)
actual = compute_sha256(content)
if actual != expected.lower():
raise ChecksumMismatchError(
expected=expected.lower(),
actual=actual,
artifact_id=artifact_id,
s3_key=s3_key,
size=len(content),
)
def sha256_to_base64(hex_hash: str) -> str:
"""
Convert SHA256 hex string to base64 encoding (for RFC 3230 Digest header).
Args:
hex_hash: SHA256 hash as 64-character hex string
Returns:
Base64-encoded hash string
"""
if not is_valid_sha256(hex_hash):
raise InvalidHashFormatError(hex_hash)
hash_bytes = bytes.fromhex(hex_hash)
return base64.b64encode(hash_bytes).decode("ascii")
class HashingStreamWrapper:
"""
Wrapper that computes SHA256 hash incrementally as chunks are read.
This allows computing the hash while streaming content to a client,
without buffering the entire content in memory.
Usage:
wrapper = HashingStreamWrapper(stream)
for chunk in wrapper:
send_to_client(chunk)
final_hash = wrapper.get_hash()
Attributes:
chunk_size: Size of chunks to yield
bytes_read: Total bytes processed so far
"""
def __init__(
self,
stream: Any,
chunk_size: int = DEFAULT_CHUNK_SIZE,
):
"""
Initialize the hashing stream wrapper.
Args:
stream: Source stream (iterator, file-like, or S3 StreamingBody)
chunk_size: Size of chunks to yield (default 8KB)
"""
self._stream = stream
self._hasher = hashlib.sha256()
self._chunk_size = chunk_size
self._bytes_read = 0
self._finalized = False
self._final_hash: Optional[str] = None
@property
def bytes_read(self) -> int:
"""Total bytes read so far."""
return self._bytes_read
@property
def chunk_size(self) -> int:
"""Chunk size for reading."""
return self._chunk_size
def __iter__(self) -> Generator[bytes, None, None]:
"""Iterate over chunks, computing hash as we go."""
# Handle S3 StreamingBody (has iter_chunks)
if hasattr(self._stream, "iter_chunks"):
for chunk in self._stream.iter_chunks(chunk_size=self._chunk_size):
if chunk:
self._hasher.update(chunk)
self._bytes_read += len(chunk)
yield chunk
# Handle file-like objects with read()
elif hasattr(self._stream, "read"):
while True:
chunk = self._stream.read(self._chunk_size)
if not chunk:
break
self._hasher.update(chunk)
self._bytes_read += len(chunk)
yield chunk
# Handle iterators
else:
for chunk in self._stream:
if chunk:
self._hasher.update(chunk)
self._bytes_read += len(chunk)
yield chunk
self._finalized = True
self._final_hash = self._hasher.hexdigest().lower()
def get_hash(self) -> str:
"""
Get the computed SHA256 hash.
If stream hasn't been fully consumed, consumes remaining chunks.
Returns:
Lowercase hexadecimal SHA256 hash
"""
if not self._finalized:
# Consume remaining stream
for _ in self:
pass
return self._final_hash or self._hasher.hexdigest().lower()
def get_hash_if_complete(self) -> Optional[str]:
"""
Get hash only if stream has been fully consumed.
Returns:
Hash if complete, None otherwise
"""
if self._finalized:
return self._final_hash
return None
class VerifyingStreamWrapper:
"""
Wrapper that yields chunks and verifies hash after streaming completes.
IMPORTANT: Because HTTP streams cannot be "un-sent", if verification
fails after streaming, the client has already received potentially
corrupt data. This wrapper logs an error but cannot prevent delivery.
For guaranteed verification before delivery, use pre-verification mode
which buffers the entire content first.
Usage:
wrapper = VerifyingStreamWrapper(stream, expected_hash)
for chunk in wrapper:
send_to_client(chunk)
wrapper.verify() # Raises ChecksumMismatchError if failed
"""
def __init__(
self,
stream: Any,
expected_hash: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
on_failure: Optional[Callable[[Any], None]] = None,
):
"""
Initialize the verifying stream wrapper.
Args:
stream: Source stream
expected_hash: Expected SHA256 hash to verify against
artifact_id: Optional artifact ID for error context
s3_key: Optional S3 key for error context
chunk_size: Size of chunks to yield
on_failure: Optional callback called on verification failure
"""
if not is_valid_sha256(expected_hash):
raise InvalidHashFormatError(expected_hash)
self._hashing_wrapper = HashingStreamWrapper(stream, chunk_size)
self._expected_hash = expected_hash.lower()
self._artifact_id = artifact_id
self._s3_key = s3_key
self._on_failure = on_failure
self._verified: Optional[bool] = None
@property
def bytes_read(self) -> int:
"""Total bytes read so far."""
return self._hashing_wrapper.bytes_read
@property
def is_verified(self) -> Optional[bool]:
"""
Verification status.
Returns:
True if verified successfully, False if failed, None if not yet complete
"""
return self._verified
def __iter__(self) -> Generator[bytes, None, None]:
"""Iterate over chunks."""
yield from self._hashing_wrapper
def verify(self) -> bool:
"""
Verify the hash after stream is complete.
Must be called after fully consuming the iterator.
Returns:
True if verification passed
Raises:
ChecksumMismatchError: If verification failed
"""
actual_hash = self._hashing_wrapper.get_hash()
if actual_hash == self._expected_hash:
self._verified = True
logger.debug(
f"Verification passed for {self._artifact_id or 'unknown'}: {actual_hash[:16]}..."
)
return True
self._verified = False
error = ChecksumMismatchError(
expected=self._expected_hash,
actual=actual_hash,
artifact_id=self._artifact_id,
s3_key=self._s3_key,
size=self._hashing_wrapper.bytes_read,
)
# Log the failure
logger.error(f"Checksum verification FAILED after streaming: {error.to_dict()}")
# Call failure callback if provided
if self._on_failure:
try:
self._on_failure(error)
except Exception as e:
logger.warning(f"Verification failure callback raised exception: {e}")
raise error
def verify_silent(self) -> bool:
"""
Verify the hash without raising exception.
Returns:
True if verification passed, False otherwise
"""
try:
return self.verify()
except ChecksumMismatchError:
return False
def get_actual_hash(self) -> Optional[str]:
"""Get the actual computed hash (only available after iteration)."""
return self._hashing_wrapper.get_hash_if_complete()
def create_verifying_generator(
stream: Any,
expected_hash: str,
artifact_id: Optional[str] = None,
s3_key: Optional[str] = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> Tuple[Generator[bytes, None, None], "VerifyingStreamWrapper"]:
"""
Create a verifying generator and its wrapper for manual verification.
This is useful when you need to iterate and later verify manually.
Args:
stream: Source stream
expected_hash: Expected SHA256 hash
artifact_id: Optional artifact ID
s3_key: Optional S3 key
chunk_size: Chunk size for reading
Returns:
Tuple of (generator, wrapper) where wrapper has verify() method
"""
wrapper = VerifyingStreamWrapper(
stream=stream,
expected_hash=expected_hash,
artifact_id=artifact_id,
s3_key=s3_key,
chunk_size=chunk_size,
)
return iter(wrapper), wrapper

View File

@@ -97,6 +97,11 @@ from .schemas import (
) )
from .metadata import extract_metadata from .metadata import extract_metadata
from .config import get_settings from .config import get_settings
from .checksum import (
ChecksumMismatchError,
VerifyingStreamWrapper,
sha256_to_base64,
)
router = APIRouter() router = APIRouter()
@@ -1777,7 +1782,7 @@ def _resolve_artifact_ref(
return artifact return artifact
# Download artifact with range request support and download modes # Download artifact with range request support, download modes, and verification
@router.get("/api/v1/project/{project_name}/{package_name}/+/{ref}") @router.get("/api/v1/project/{project_name}/{package_name}/+/{ref}")
def download_artifact( def download_artifact(
project_name: str, project_name: str,
@@ -1791,7 +1796,34 @@ def download_artifact(
default=None, default=None,
description="Download mode: proxy (stream through backend), redirect (302 to presigned URL), presigned (return JSON with URL)", description="Download mode: proxy (stream through backend), redirect (302 to presigned URL), presigned (return JSON with URL)",
), ),
verify: bool = Query(
default=False,
description="Enable checksum verification during download",
),
verify_mode: Optional[Literal["stream", "pre"]] = Query(
default="stream",
description="Verification mode: 'stream' (verify after streaming, logs error if mismatch), 'pre' (verify before streaming, returns 500 if mismatch)",
),
): ):
"""
Download an artifact by reference (tag name, artifact:hash, tag:name).
Verification modes:
- verify=false (default): No verification, maximum performance
- verify=true&verify_mode=stream: Compute hash while streaming, verify after completion.
If mismatch, logs error but content already sent.
- verify=true&verify_mode=pre: Download and verify BEFORE streaming to client.
Higher latency but guarantees no corrupt data sent.
Response headers always include:
- X-Checksum-SHA256: The expected SHA256 hash
- X-Content-Length: File size in bytes
- ETag: Artifact ID (SHA256)
- Digest: RFC 3230 format sha-256 hash
When verify=true:
- X-Verified: 'true' if verified, 'false' if verification failed
"""
settings = get_settings() settings = get_settings()
# Get project and package # Get project and package
@@ -1831,6 +1863,23 @@ def download_artifact(
) )
db.commit() db.commit()
# Build common checksum headers (always included)
checksum_headers = {
"X-Checksum-SHA256": artifact.id,
"X-Content-Length": str(artifact.size),
"ETag": f'"{artifact.id}"',
}
# Add RFC 3230 Digest header
try:
digest_base64 = sha256_to_base64(artifact.id)
checksum_headers["Digest"] = f"sha-256={digest_base64}"
except Exception:
pass # Skip if conversion fails
# Add MD5 checksum if available
if artifact.checksum_md5:
checksum_headers["X-Checksum-MD5"] = artifact.checksum_md5
# Determine download mode (query param overrides server default) # Determine download mode (query param overrides server default)
download_mode = mode or settings.download_mode download_mode = mode or settings.download_mode
@@ -1867,7 +1916,7 @@ def download_artifact(
return RedirectResponse(url=presigned_url, status_code=302) return RedirectResponse(url=presigned_url, status_code=302)
# Proxy mode (default fallback) - stream through backend # Proxy mode (default fallback) - stream through backend
# Handle range requests # Handle range requests (verification not supported for partial downloads)
if range: if range:
stream, content_length, content_range = storage.get_stream( stream, content_length, content_range = storage.get_stream(
artifact.s3_key, range artifact.s3_key, range
@@ -1877,9 +1926,11 @@ def download_artifact(
"Content-Disposition": f'attachment; filename="{filename}"', "Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes", "Accept-Ranges": "bytes",
"Content-Length": str(content_length), "Content-Length": str(content_length),
**checksum_headers,
} }
if content_range: if content_range:
headers["Content-Range"] = content_range headers["Content-Range"] = content_range
# Note: X-Verified not set for range requests (cannot verify partial content)
return StreamingResponse( return StreamingResponse(
stream, stream,
@@ -1888,16 +1939,88 @@ def download_artifact(
headers=headers, headers=headers,
) )
# Full download # Full download with optional verification
base_headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
**checksum_headers,
}
# Pre-verification mode: verify before streaming
if verify and verify_mode == "pre":
try:
content = storage.get_verified(artifact.s3_key, artifact.id)
return Response(
content=content,
media_type=artifact.content_type or "application/octet-stream",
headers={
**base_headers,
"Content-Length": str(len(content)),
"X-Verified": "true",
},
)
except ChecksumMismatchError as e:
logger.error(
f"Pre-verification failed for artifact {artifact.id[:16]}...: {e.to_dict()}"
)
raise HTTPException(
status_code=500,
detail={
"error": "checksum_verification_failed",
"message": "Downloaded content does not match expected checksum",
"expected": e.expected,
"actual": e.actual,
"artifact_id": artifact.id,
},
)
# Streaming verification mode: verify while/after streaming
if verify and verify_mode == "stream":
verifying_wrapper, content_length, _ = storage.get_stream_verified(
artifact.s3_key, artifact.id
)
def verified_stream():
"""Generator that yields chunks and verifies after completion."""
try:
for chunk in verifying_wrapper:
yield chunk
# After all chunks yielded, verify
try:
verifying_wrapper.verify()
logger.info(
f"Streaming verification passed for artifact {artifact.id[:16]}..."
)
except ChecksumMismatchError as e:
# Content already sent - log error but cannot reject
logger.error(
f"Streaming verification FAILED for artifact {artifact.id[:16]}...: "
f"expected {e.expected[:16]}..., got {e.actual[:16]}..."
)
except Exception as e:
logger.error(f"Error during streaming download: {e}")
raise
return StreamingResponse(
verified_stream(),
media_type=artifact.content_type or "application/octet-stream",
headers={
**base_headers,
"Content-Length": str(content_length),
"X-Verified": "pending", # Verification happens after streaming
},
)
# No verification - direct streaming
stream, content_length, _ = storage.get_stream(artifact.s3_key) stream, content_length, _ = storage.get_stream(artifact.s3_key)
return StreamingResponse( return StreamingResponse(
stream, stream,
media_type=artifact.content_type or "application/octet-stream", media_type=artifact.content_type or "application/octet-stream",
headers={ headers={
"Content-Disposition": f'attachment; filename="{filename}"', **base_headers,
"Accept-Ranges": "bytes",
"Content-Length": str(content_length), "Content-Length": str(content_length),
"X-Verified": "false",
}, },
) )
@@ -1975,6 +2098,11 @@ def head_artifact(
db: Session = Depends(get_db), db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage), storage: S3Storage = Depends(get_storage),
): ):
"""
Get artifact metadata without downloading content.
Returns headers with checksum information for client-side verification.
"""
# Get project and package # Get project and package
project = db.query(Project).filter(Project.name == project_name).first() project = db.query(Project).filter(Project.name == project_name).first()
if not project: if not project:
@@ -1995,15 +2123,32 @@ def head_artifact(
filename = sanitize_filename(artifact.original_name or f"{artifact.id}") filename = sanitize_filename(artifact.original_name or f"{artifact.id}")
# Build headers with checksum information
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
"Content-Length": str(artifact.size),
"X-Artifact-Id": artifact.id,
"X-Checksum-SHA256": artifact.id,
"X-Content-Length": str(artifact.size),
"ETag": f'"{artifact.id}"',
}
# Add RFC 3230 Digest header
try:
digest_base64 = sha256_to_base64(artifact.id)
headers["Digest"] = f"sha-256={digest_base64}"
except Exception:
pass # Skip if conversion fails
# Add MD5 checksum if available
if artifact.checksum_md5:
headers["X-Checksum-MD5"] = artifact.checksum_md5
return Response( return Response(
content=b"", content=b"",
media_type=artifact.content_type or "application/octet-stream", media_type=artifact.content_type or "application/octet-stream",
headers={ headers=headers,
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
"Content-Length": str(artifact.size),
"X-Artifact-Id": artifact.id,
},
) )
@@ -2017,9 +2162,19 @@ def download_artifact_compat(
db: Session = Depends(get_db), db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage), storage: S3Storage = Depends(get_storage),
range: Optional[str] = Header(None), range: Optional[str] = Header(None),
verify: bool = Query(default=False),
verify_mode: Optional[Literal["stream", "pre"]] = Query(default="stream"),
): ):
return download_artifact( return download_artifact(
project_name, package_name, ref, request, db, storage, range project_name,
package_name,
ref,
request,
db,
storage,
range,
verify=verify,
verify_mode=verify_mode,
) )

View File

@@ -22,6 +22,13 @@ from botocore.exceptions import (
) )
from .config import get_settings from .config import get_settings
from .checksum import (
ChecksumMismatchError,
HashingStreamWrapper,
VerifyingStreamWrapper,
compute_sha256,
is_valid_sha256,
)
settings = get_settings() settings = get_settings()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -876,6 +883,95 @@ class S3Storage:
logger.error(f"Unexpected error during storage health check: {e}") logger.error(f"Unexpected error during storage health check: {e}")
return False return False
def get_verified(self, s3_key: str, expected_hash: str) -> bytes:
"""
Download and verify content matches expected SHA256 hash.
This method downloads the entire content, computes its hash, and
verifies it matches the expected hash before returning.
Args:
s3_key: The S3 storage key of the file
expected_hash: Expected SHA256 hash (64 hex characters)
Returns:
File content as bytes (only if verification passes)
Raises:
ChecksumMismatchError: If computed hash doesn't match expected
ClientError: If S3 operation fails
"""
if not is_valid_sha256(expected_hash):
raise ValueError(f"Invalid SHA256 hash format: {expected_hash}")
content = self.get(s3_key)
actual_hash = compute_sha256(content)
if actual_hash != expected_hash.lower():
raise ChecksumMismatchError(
expected=expected_hash.lower(),
actual=actual_hash,
s3_key=s3_key,
size=len(content),
)
logger.debug(f"Verification passed for {s3_key}: {actual_hash[:16]}...")
return content
def get_stream_verified(
self,
s3_key: str,
expected_hash: str,
range_header: Optional[str] = None,
) -> Tuple[VerifyingStreamWrapper, int, Optional[str]]:
"""
Get a verifying stream wrapper for an object.
Returns a wrapper that computes the hash as chunks are read and
can verify after streaming completes. Note that verification happens
AFTER content has been streamed to the client.
IMPORTANT: For range requests, verification is not supported because
we cannot verify a partial download against the full file hash.
Args:
s3_key: The S3 storage key of the file
expected_hash: Expected SHA256 hash (64 hex characters)
range_header: Optional HTTP Range header (verification disabled if set)
Returns:
Tuple of (VerifyingStreamWrapper, content_length, content_range)
The wrapper has a verify() method to call after streaming.
Raises:
ValueError: If expected_hash is invalid format
ClientError: If S3 operation fails
"""
if not is_valid_sha256(expected_hash):
raise ValueError(f"Invalid SHA256 hash format: {expected_hash}")
# Get the S3 stream
stream, content_length, content_range = self.get_stream(s3_key, range_header)
# For range requests, we cannot verify (partial content)
# Return a HashingStreamWrapper that just tracks bytes without verification
if range_header or content_range:
logger.debug(
f"Range request for {s3_key} - verification disabled (partial content)"
)
# Return a basic hashing wrapper (caller should not verify)
hashing_wrapper = HashingStreamWrapper(stream)
return hashing_wrapper, content_length, content_range
# Create verifying wrapper
verifying_wrapper = VerifyingStreamWrapper(
stream=stream,
expected_hash=expected_hash,
s3_key=s3_key,
)
return verifying_wrapper, content_length, content_range
def verify_integrity(self, s3_key: str, expected_sha256: str) -> bool: def verify_integrity(self, s3_key: str, expected_sha256: str) -> bool:
""" """
Verify the integrity of a stored object by downloading and re-hashing. Verify the integrity of a stored object by downloading and re-hashing.

View File

@@ -0,0 +1,675 @@
"""
Tests for checksum calculation, verification, and download verification.
This module tests:
- SHA256 hash computation (bytes and streams)
- HashingStreamWrapper incremental hashing
- VerifyingStreamWrapper with verification
- ChecksumMismatchError exception handling
- Download verification API endpoints
"""
import pytest
import hashlib
import io
from typing import Generator
from app.checksum import (
compute_sha256,
compute_sha256_stream,
verify_checksum,
verify_checksum_strict,
is_valid_sha256,
sha256_to_base64,
HashingStreamWrapper,
VerifyingStreamWrapper,
ChecksumMismatchError,
ChecksumError,
InvalidHashFormatError,
DEFAULT_CHUNK_SIZE,
)
# =============================================================================
# Test Data
# =============================================================================
# Known test vectors
TEST_CONTENT_HELLO = b"Hello, World!"
TEST_HASH_HELLO = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
TEST_CONTENT_EMPTY = b""
TEST_HASH_EMPTY = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
TEST_CONTENT_BINARY = bytes(range(256))
TEST_HASH_BINARY = hashlib.sha256(TEST_CONTENT_BINARY).hexdigest()
# Invalid hashes for testing
INVALID_HASH_TOO_SHORT = "abcd1234"
INVALID_HASH_TOO_LONG = "a" * 65
INVALID_HASH_NON_HEX = "zzzz" + "a" * 60
INVALID_HASH_EMPTY = ""
# =============================================================================
# Unit Tests - SHA256 Computation
# =============================================================================
class TestComputeSHA256:
"""Tests for compute_sha256 function."""
def test_known_content_matches_expected_hash(self):
"""Test SHA256 of known content matches pre-computed hash."""
result = compute_sha256(TEST_CONTENT_HELLO)
assert result == TEST_HASH_HELLO
def test_returns_64_character_hex_string(self):
"""Test result is exactly 64 hex characters."""
result = compute_sha256(TEST_CONTENT_HELLO)
assert len(result) == 64
assert all(c in "0123456789abcdef" for c in result)
def test_returns_lowercase_hex(self):
"""Test result is lowercase."""
result = compute_sha256(TEST_CONTENT_HELLO)
assert result == result.lower()
def test_empty_content_returns_empty_hash(self):
"""Test empty bytes returns SHA256 of empty content."""
result = compute_sha256(TEST_CONTENT_EMPTY)
assert result == TEST_HASH_EMPTY
def test_deterministic_same_input_same_output(self):
"""Test same input always produces same output."""
content = b"test content for determinism"
result1 = compute_sha256(content)
result2 = compute_sha256(content)
assert result1 == result2
def test_different_content_different_hash(self):
"""Test different content produces different hash."""
hash1 = compute_sha256(b"content A")
hash2 = compute_sha256(b"content B")
assert hash1 != hash2
def test_single_bit_change_different_hash(self):
"""Test single bit change produces completely different hash."""
content1 = b"\x00" * 100
content2 = b"\x00" * 99 + b"\x01"
hash1 = compute_sha256(content1)
hash2 = compute_sha256(content2)
assert hash1 != hash2
def test_binary_content(self):
"""Test hashing binary content with all byte values."""
result = compute_sha256(TEST_CONTENT_BINARY)
assert result == TEST_HASH_BINARY
assert len(result) == 64
def test_large_content(self):
"""Test hashing larger content (1MB)."""
large_content = b"x" * (1024 * 1024)
result = compute_sha256(large_content)
expected = hashlib.sha256(large_content).hexdigest()
assert result == expected
def test_none_content_raises_error(self):
"""Test None content raises ChecksumError."""
with pytest.raises(ChecksumError, match="Cannot compute hash of None"):
compute_sha256(None)
class TestComputeSHA256Stream:
"""Tests for compute_sha256_stream function."""
def test_file_like_object(self):
"""Test hashing from file-like object."""
file_obj = io.BytesIO(TEST_CONTENT_HELLO)
result = compute_sha256_stream(file_obj)
assert result == TEST_HASH_HELLO
def test_iterator(self):
"""Test hashing from iterator of chunks."""
def chunk_iterator():
yield b"Hello, "
yield b"World!"
result = compute_sha256_stream(chunk_iterator())
assert result == TEST_HASH_HELLO
def test_various_chunk_sizes_same_result(self):
"""Test different chunk sizes produce same hash."""
content = b"x" * 10000
expected = hashlib.sha256(content).hexdigest()
for chunk_size in [1, 10, 100, 1000, 8192]:
file_obj = io.BytesIO(content)
result = compute_sha256_stream(file_obj, chunk_size=chunk_size)
assert result == expected, f"Failed for chunk_size={chunk_size}"
def test_single_byte_chunks(self):
"""Test with 1-byte chunks (edge case)."""
content = b"ABC"
file_obj = io.BytesIO(content)
result = compute_sha256_stream(file_obj, chunk_size=1)
expected = hashlib.sha256(content).hexdigest()
assert result == expected
def test_empty_stream(self):
"""Test empty stream returns empty content hash."""
file_obj = io.BytesIO(b"")
result = compute_sha256_stream(file_obj)
assert result == TEST_HASH_EMPTY
# =============================================================================
# Unit Tests - Hash Validation
# =============================================================================
class TestIsValidSHA256:
"""Tests for is_valid_sha256 function."""
def test_valid_hash_lowercase(self):
"""Test valid lowercase hash."""
assert is_valid_sha256(TEST_HASH_HELLO) is True
def test_valid_hash_uppercase(self):
"""Test valid uppercase hash."""
assert is_valid_sha256(TEST_HASH_HELLO.upper()) is True
def test_valid_hash_mixed_case(self):
"""Test valid mixed case hash."""
mixed = TEST_HASH_HELLO[:32].upper() + TEST_HASH_HELLO[32:].lower()
assert is_valid_sha256(mixed) is True
def test_invalid_too_short(self):
"""Test hash that's too short."""
assert is_valid_sha256(INVALID_HASH_TOO_SHORT) is False
def test_invalid_too_long(self):
"""Test hash that's too long."""
assert is_valid_sha256(INVALID_HASH_TOO_LONG) is False
def test_invalid_non_hex(self):
"""Test hash with non-hex characters."""
assert is_valid_sha256(INVALID_HASH_NON_HEX) is False
def test_invalid_empty(self):
"""Test empty string."""
assert is_valid_sha256(INVALID_HASH_EMPTY) is False
def test_invalid_none(self):
"""Test None value."""
assert is_valid_sha256(None) is False
class TestSHA256ToBase64:
"""Tests for sha256_to_base64 function."""
def test_converts_to_base64(self):
"""Test conversion to base64."""
result = sha256_to_base64(TEST_HASH_HELLO)
# Verify it's valid base64
import base64
decoded = base64.b64decode(result)
assert len(decoded) == 32 # SHA256 is 32 bytes
def test_invalid_hash_raises_error(self):
"""Test invalid hash raises InvalidHashFormatError."""
with pytest.raises(InvalidHashFormatError):
sha256_to_base64(INVALID_HASH_TOO_SHORT)
# =============================================================================
# Unit Tests - Verification Functions
# =============================================================================
class TestVerifyChecksum:
"""Tests for verify_checksum function."""
def test_matching_checksum_returns_true(self):
"""Test matching checksum returns True."""
result = verify_checksum(TEST_CONTENT_HELLO, TEST_HASH_HELLO)
assert result is True
def test_mismatched_checksum_returns_false(self):
"""Test mismatched checksum returns False."""
wrong_hash = "a" * 64
result = verify_checksum(TEST_CONTENT_HELLO, wrong_hash)
assert result is False
def test_case_insensitive_comparison(self):
"""Test comparison is case-insensitive."""
result = verify_checksum(TEST_CONTENT_HELLO, TEST_HASH_HELLO.upper())
assert result is True
def test_invalid_hash_format_raises_error(self):
"""Test invalid hash format raises error."""
with pytest.raises(InvalidHashFormatError):
verify_checksum(TEST_CONTENT_HELLO, INVALID_HASH_TOO_SHORT)
class TestVerifyChecksumStrict:
"""Tests for verify_checksum_strict function."""
def test_matching_checksum_returns_none(self):
"""Test matching checksum doesn't raise."""
# Should not raise
verify_checksum_strict(TEST_CONTENT_HELLO, TEST_HASH_HELLO)
def test_mismatched_checksum_raises_error(self):
"""Test mismatched checksum raises ChecksumMismatchError."""
wrong_hash = "a" * 64
with pytest.raises(ChecksumMismatchError) as exc_info:
verify_checksum_strict(TEST_CONTENT_HELLO, wrong_hash)
error = exc_info.value
assert error.expected == wrong_hash.lower()
assert error.actual == TEST_HASH_HELLO
assert error.size == len(TEST_CONTENT_HELLO)
def test_error_includes_context(self):
"""Test error includes artifact_id and s3_key context."""
wrong_hash = "a" * 64
with pytest.raises(ChecksumMismatchError) as exc_info:
verify_checksum_strict(
TEST_CONTENT_HELLO,
wrong_hash,
artifact_id="test-artifact-123",
s3_key="fruits/ab/cd/abcd1234...",
)
error = exc_info.value
assert error.artifact_id == "test-artifact-123"
assert error.s3_key == "fruits/ab/cd/abcd1234..."
# =============================================================================
# Unit Tests - HashingStreamWrapper
# =============================================================================
class TestHashingStreamWrapper:
"""Tests for HashingStreamWrapper class."""
def test_computes_correct_hash(self):
"""Test wrapper computes correct hash."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
# Consume the stream
chunks = list(wrapper)
# Verify hash
assert wrapper.get_hash() == TEST_HASH_HELLO
def test_yields_correct_chunks(self):
"""Test wrapper yields all content."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
chunks = list(wrapper)
content = b"".join(chunks)
assert content == TEST_CONTENT_HELLO
def test_tracks_bytes_read(self):
"""Test bytes_read property tracks correctly."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
assert wrapper.bytes_read == 0
list(wrapper) # Consume
assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)
def test_get_hash_before_iteration_consumes_stream(self):
"""Test get_hash() consumes stream if not already done."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
# Call get_hash without iterating
hash_result = wrapper.get_hash()
assert hash_result == TEST_HASH_HELLO
assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)
def test_get_hash_if_complete_before_iteration_returns_none(self):
"""Test get_hash_if_complete returns None before iteration."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
assert wrapper.get_hash_if_complete() is None
def test_get_hash_if_complete_after_iteration_returns_hash(self):
"""Test get_hash_if_complete returns hash after iteration."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
list(wrapper) # Consume
assert wrapper.get_hash_if_complete() == TEST_HASH_HELLO
def test_custom_chunk_size(self):
"""Test custom chunk size is respected."""
content = b"x" * 1000
stream = io.BytesIO(content)
wrapper = HashingStreamWrapper(stream, chunk_size=100)
chunks = list(wrapper)
# Each chunk should be at most 100 bytes
for chunk in chunks[:-1]: # All but last
assert len(chunk) == 100
# Total content should match
assert b"".join(chunks) == content
def test_iterator_interface(self):
"""Test wrapper supports iterator interface."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = HashingStreamWrapper(stream)
# Should be able to use for loop
result = b""
for chunk in wrapper:
result += chunk
assert result == TEST_CONTENT_HELLO
# =============================================================================
# Unit Tests - VerifyingStreamWrapper
# =============================================================================
class TestVerifyingStreamWrapper:
"""Tests for VerifyingStreamWrapper class."""
def test_verify_success(self):
"""Test verification succeeds for matching content."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)
# Consume stream
list(wrapper)
# Verify should succeed
result = wrapper.verify()
assert result is True
assert wrapper.is_verified is True
def test_verify_failure_raises_error(self):
"""Test verification failure raises ChecksumMismatchError."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrong_hash = "a" * 64
wrapper = VerifyingStreamWrapper(stream, wrong_hash)
# Consume stream
list(wrapper)
# Verify should fail
with pytest.raises(ChecksumMismatchError):
wrapper.verify()
assert wrapper.is_verified is False
def test_verify_silent_success(self):
"""Test verify_silent returns True on success."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)
list(wrapper)
result = wrapper.verify_silent()
assert result is True
def test_verify_silent_failure(self):
"""Test verify_silent returns False on failure."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrong_hash = "a" * 64
wrapper = VerifyingStreamWrapper(stream, wrong_hash)
list(wrapper)
result = wrapper.verify_silent()
assert result is False
def test_invalid_expected_hash_raises_error(self):
"""Test invalid expected hash raises error at construction."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
with pytest.raises(InvalidHashFormatError):
VerifyingStreamWrapper(stream, INVALID_HASH_TOO_SHORT)
def test_on_failure_callback(self):
"""Test on_failure callback is called on verification failure."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrong_hash = "a" * 64
callback_called = []
def on_failure(error):
callback_called.append(error)
wrapper = VerifyingStreamWrapper(stream, wrong_hash, on_failure=on_failure)
list(wrapper)
with pytest.raises(ChecksumMismatchError):
wrapper.verify()
assert len(callback_called) == 1
assert isinstance(callback_called[0], ChecksumMismatchError)
def test_get_actual_hash_after_iteration(self):
"""Test get_actual_hash returns hash after iteration."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)
# Before iteration
assert wrapper.get_actual_hash() is None
list(wrapper)
# After iteration
assert wrapper.get_actual_hash() == TEST_HASH_HELLO
def test_includes_context_in_error(self):
"""Test error includes artifact_id and s3_key."""
stream = io.BytesIO(TEST_CONTENT_HELLO)
wrong_hash = "a" * 64
wrapper = VerifyingStreamWrapper(
stream,
wrong_hash,
artifact_id="test-artifact",
s3_key="test/key",
)
list(wrapper)
with pytest.raises(ChecksumMismatchError) as exc_info:
wrapper.verify()
error = exc_info.value
assert error.artifact_id == "test-artifact"
assert error.s3_key == "test/key"
# =============================================================================
# Unit Tests - ChecksumMismatchError
# =============================================================================
class TestChecksumMismatchError:
"""Tests for ChecksumMismatchError class."""
def test_to_dict(self):
"""Test to_dict returns proper dictionary."""
error = ChecksumMismatchError(
expected="a" * 64,
actual="b" * 64,
artifact_id="test-123",
s3_key="test/key",
size=1024,
)
result = error.to_dict()
assert result["error"] == "checksum_mismatch"
assert result["expected"] == "a" * 64
assert result["actual"] == "b" * 64
assert result["artifact_id"] == "test-123"
assert result["s3_key"] == "test/key"
assert result["size"] == 1024
def test_message_format(self):
"""Test error message format."""
error = ChecksumMismatchError(
expected="a" * 64,
actual="b" * 64,
)
assert "verification failed" in str(error).lower()
assert "expected" in str(error).lower()
def test_custom_message(self):
"""Test custom message is used."""
error = ChecksumMismatchError(
expected="a" * 64,
actual="b" * 64,
message="Custom error message",
)
assert str(error) == "Custom error message"
# =============================================================================
# Corruption Simulation Tests
# =============================================================================
class TestCorruptionDetection:
"""Tests for detecting corrupted content."""
def test_detect_truncated_content(self):
"""Test detection of truncated content."""
original = TEST_CONTENT_HELLO
truncated = original[:-1] # Remove last byte
original_hash = compute_sha256(original)
truncated_hash = compute_sha256(truncated)
assert original_hash != truncated_hash
assert verify_checksum(truncated, original_hash) is False
def test_detect_extra_bytes(self):
"""Test detection of content with extra bytes."""
original = TEST_CONTENT_HELLO
extended = original + b"\x00" # Add null byte
original_hash = compute_sha256(original)
assert verify_checksum(extended, original_hash) is False
def test_detect_single_bit_flip(self):
"""Test detection of single bit flip."""
original = TEST_CONTENT_HELLO
# Flip first bit of first byte
corrupted = bytes([original[0] ^ 0x01]) + original[1:]
original_hash = compute_sha256(original)
assert verify_checksum(corrupted, original_hash) is False
def test_detect_wrong_content(self):
"""Test detection of completely different content."""
original = TEST_CONTENT_HELLO
different = b"Something completely different"
original_hash = compute_sha256(original)
assert verify_checksum(different, original_hash) is False
def test_detect_empty_vs_nonempty(self):
"""Test detection of empty content vs non-empty."""
original = TEST_CONTENT_HELLO
empty = b""
original_hash = compute_sha256(original)
assert verify_checksum(empty, original_hash) is False
def test_streaming_detection_of_corruption(self):
"""Test VerifyingStreamWrapper detects corruption."""
original = b"Original content that will be corrupted"
original_hash = compute_sha256(original)
# Corrupt the content
corrupted = b"Corrupted content that is different"
stream = io.BytesIO(corrupted)
wrapper = VerifyingStreamWrapper(stream, original_hash)
list(wrapper) # Consume
with pytest.raises(ChecksumMismatchError):
wrapper.verify()
# =============================================================================
# Edge Case Tests
# =============================================================================
class TestEdgeCases:
"""Tests for edge cases and boundary conditions."""
def test_null_bytes_in_content(self):
"""Test content with null bytes."""
content = b"\x00\x00\x00"
hash_result = compute_sha256(content)
assert verify_checksum(content, hash_result) is True
def test_whitespace_only_content(self):
"""Test content with only whitespace."""
content = b" \t\n\r "
hash_result = compute_sha256(content)
assert verify_checksum(content, hash_result) is True
def test_large_content_streaming(self):
"""Test streaming verification of large content."""
# 1MB of content
large_content = b"x" * (1024 * 1024)
expected_hash = compute_sha256(large_content)
stream = io.BytesIO(large_content)
wrapper = VerifyingStreamWrapper(stream, expected_hash)
# Consume and verify
chunks = list(wrapper)
assert wrapper.verify() is True
assert b"".join(chunks) == large_content
def test_unicode_bytes_content(self):
"""Test content with unicode bytes."""
content = "Hello, 世界! 🌍".encode("utf-8")
hash_result = compute_sha256(content)
assert verify_checksum(content, hash_result) is True
def test_maximum_chunk_size_larger_than_content(self):
"""Test chunk size larger than content."""
content = b"small"
stream = io.BytesIO(content)
wrapper = HashingStreamWrapper(stream, chunk_size=1024 * 1024)
chunks = list(wrapper)
assert len(chunks) == 1
assert chunks[0] == content
assert wrapper.get_hash() == compute_sha256(content)

View File

@@ -0,0 +1,439 @@
"""
Integration tests for download verification API endpoints.
These tests verify:
- Checksum headers in download responses
- Pre-verification mode
- Streaming verification mode
- HEAD request headers
- Verification failure handling
"""
import pytest
import hashlib
import base64
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def test_content():
"""Generate test content with known hash."""
content = b"Test content for download verification"
sha256 = hashlib.sha256(content).hexdigest()
return content, sha256
# =============================================================================
# Integration Tests - Download Headers
# =============================================================================
class TestDownloadChecksumHeaders:
"""Tests for checksum headers in download responses."""
@pytest.mark.integration
def test_download_includes_sha256_header(
self, integration_client, test_package, upload_test_file
):
"""Test download response includes X-Checksum-SHA256 header."""
project_name, package_name = test_package
content = b"Content for SHA256 header test"
# Upload file
artifact_id = upload_test_file(
project_name, package_name, content, tag="sha256-header-test"
)
# Download with proxy mode
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/sha256-header-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
assert "X-Checksum-SHA256" in response.headers
assert response.headers["X-Checksum-SHA256"] == artifact_id
@pytest.mark.integration
def test_download_includes_etag_header(
self, integration_client, test_package, upload_test_file
):
"""Test download response includes ETag header."""
project_name, package_name = test_package
content = b"Content for ETag header test"
artifact_id = upload_test_file(
project_name, package_name, content, tag="etag-test"
)
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/etag-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
assert "ETag" in response.headers
# ETag should be quoted artifact ID
assert response.headers["ETag"] == f'"{artifact_id}"'
@pytest.mark.integration
def test_download_includes_digest_header(
self, integration_client, test_package, upload_test_file
):
"""Test download response includes RFC 3230 Digest header."""
project_name, package_name = test_package
content = b"Content for Digest header test"
sha256 = hashlib.sha256(content).hexdigest()
upload_test_file(project_name, package_name, content, tag="digest-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/digest-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
assert "Digest" in response.headers
# Verify Digest format: sha-256=<base64>
digest = response.headers["Digest"]
assert digest.startswith("sha-256=")
# Verify base64 content matches
b64_hash = digest.split("=", 1)[1]
decoded = base64.b64decode(b64_hash)
assert decoded == bytes.fromhex(sha256)
@pytest.mark.integration
def test_download_includes_content_length_header(
self, integration_client, test_package, upload_test_file
):
"""Test download response includes X-Content-Length header."""
project_name, package_name = test_package
content = b"Content for X-Content-Length test"
upload_test_file(project_name, package_name, content, tag="content-length-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/content-length-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
assert "X-Content-Length" in response.headers
assert response.headers["X-Content-Length"] == str(len(content))
@pytest.mark.integration
def test_download_includes_verified_header_false(
self, integration_client, test_package, upload_test_file
):
"""Test download without verification has X-Verified: false."""
project_name, package_name = test_package
content = b"Content for X-Verified false test"
upload_test_file(project_name, package_name, content, tag="verified-false-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/verified-false-test",
params={"mode": "proxy", "verify": "false"},
)
assert response.status_code == 200
assert "X-Verified" in response.headers
assert response.headers["X-Verified"] == "false"
# =============================================================================
# Integration Tests - Pre-Verification Mode
# =============================================================================
class TestPreVerificationMode:
"""Tests for pre-verification download mode."""
@pytest.mark.integration
def test_pre_verify_success(
self, integration_client, test_package, upload_test_file
):
"""Test pre-verification mode succeeds for valid content."""
project_name, package_name = test_package
content = b"Content for pre-verification success test"
upload_test_file(project_name, package_name, content, tag="pre-verify-success")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/pre-verify-success",
params={"mode": "proxy", "verify": "true", "verify_mode": "pre"},
)
assert response.status_code == 200
assert response.content == content
assert "X-Verified" in response.headers
assert response.headers["X-Verified"] == "true"
@pytest.mark.integration
def test_pre_verify_returns_complete_content(
self, integration_client, test_package, upload_test_file
):
"""Test pre-verification returns complete content."""
project_name, package_name = test_package
# Use binary content to verify no corruption
content = bytes(range(256)) * 10 # 2560 bytes of all byte values
upload_test_file(project_name, package_name, content, tag="pre-verify-content")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/pre-verify-content",
params={"mode": "proxy", "verify": "true", "verify_mode": "pre"},
)
assert response.status_code == 200
assert response.content == content
# =============================================================================
# Integration Tests - Streaming Verification Mode
# =============================================================================
class TestStreamingVerificationMode:
"""Tests for streaming verification download mode."""
@pytest.mark.integration
def test_stream_verify_success(
self, integration_client, test_package, upload_test_file
):
"""Test streaming verification mode succeeds for valid content."""
project_name, package_name = test_package
content = b"Content for streaming verification success test"
upload_test_file(
project_name, package_name, content, tag="stream-verify-success"
)
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/stream-verify-success",
params={"mode": "proxy", "verify": "true", "verify_mode": "stream"},
)
assert response.status_code == 200
assert response.content == content
# X-Verified is "pending" for streaming mode (verified after transfer)
assert "X-Verified" in response.headers
@pytest.mark.integration
def test_stream_verify_large_content(
self, integration_client, test_package, upload_test_file
):
"""Test streaming verification with larger content."""
project_name, package_name = test_package
# 100KB of content
content = b"x" * (100 * 1024)
upload_test_file(project_name, package_name, content, tag="stream-verify-large")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/stream-verify-large",
params={"mode": "proxy", "verify": "true", "verify_mode": "stream"},
)
assert response.status_code == 200
assert response.content == content
# =============================================================================
# Integration Tests - HEAD Request Headers
# =============================================================================
class TestHeadRequestHeaders:
"""Tests for HEAD request checksum headers."""
@pytest.mark.integration
def test_head_includes_sha256_header(
self, integration_client, test_package, upload_test_file
):
"""Test HEAD request includes X-Checksum-SHA256 header."""
project_name, package_name = test_package
content = b"Content for HEAD SHA256 test"
artifact_id = upload_test_file(
project_name, package_name, content, tag="head-sha256-test"
)
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-sha256-test"
)
assert response.status_code == 200
assert "X-Checksum-SHA256" in response.headers
assert response.headers["X-Checksum-SHA256"] == artifact_id
@pytest.mark.integration
def test_head_includes_etag(
self, integration_client, test_package, upload_test_file
):
"""Test HEAD request includes ETag header."""
project_name, package_name = test_package
content = b"Content for HEAD ETag test"
artifact_id = upload_test_file(
project_name, package_name, content, tag="head-etag-test"
)
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-etag-test"
)
assert response.status_code == 200
assert "ETag" in response.headers
assert response.headers["ETag"] == f'"{artifact_id}"'
@pytest.mark.integration
def test_head_includes_digest(
self, integration_client, test_package, upload_test_file
):
"""Test HEAD request includes Digest header."""
project_name, package_name = test_package
content = b"Content for HEAD Digest test"
upload_test_file(project_name, package_name, content, tag="head-digest-test")
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-digest-test"
)
assert response.status_code == 200
assert "Digest" in response.headers
assert response.headers["Digest"].startswith("sha-256=")
@pytest.mark.integration
def test_head_includes_content_length(
self, integration_client, test_package, upload_test_file
):
"""Test HEAD request includes X-Content-Length header."""
project_name, package_name = test_package
content = b"Content for HEAD Content-Length test"
upload_test_file(project_name, package_name, content, tag="head-length-test")
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-length-test"
)
assert response.status_code == 200
assert "X-Content-Length" in response.headers
assert response.headers["X-Content-Length"] == str(len(content))
@pytest.mark.integration
def test_head_no_body(self, integration_client, test_package, upload_test_file):
"""Test HEAD request returns no body."""
project_name, package_name = test_package
content = b"Content for HEAD no-body test"
upload_test_file(project_name, package_name, content, tag="head-no-body-test")
response = integration_client.head(
f"/api/v1/project/{project_name}/{package_name}/+/head-no-body-test"
)
assert response.status_code == 200
assert response.content == b""
# =============================================================================
# Integration Tests - Range Requests
# =============================================================================
class TestRangeRequestHeaders:
"""Tests for range request handling with checksum headers."""
@pytest.mark.integration
def test_range_request_includes_checksum_headers(
self, integration_client, test_package, upload_test_file
):
"""Test range request includes checksum headers."""
project_name, package_name = test_package
content = b"Content for range request checksum header test"
upload_test_file(project_name, package_name, content, tag="range-checksum-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/range-checksum-test",
headers={"Range": "bytes=0-9"},
params={"mode": "proxy"},
)
assert response.status_code == 206
assert "X-Checksum-SHA256" in response.headers
# Checksum is for the FULL file, not the range
assert len(response.headers["X-Checksum-SHA256"]) == 64
# =============================================================================
# Integration Tests - Client-Side Verification
# =============================================================================
class TestClientSideVerification:
"""Tests demonstrating client-side verification using headers."""
@pytest.mark.integration
def test_client_can_verify_downloaded_content(
self, integration_client, test_package, upload_test_file
):
"""Test client can verify downloaded content using header."""
project_name, package_name = test_package
content = b"Content for client-side verification test"
upload_test_file(project_name, package_name, content, tag="client-verify-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/client-verify-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
# Get expected hash from header
expected_hash = response.headers["X-Checksum-SHA256"]
# Compute actual hash of downloaded content
actual_hash = hashlib.sha256(response.content).hexdigest()
# Verify match
assert actual_hash == expected_hash
@pytest.mark.integration
def test_client_can_verify_using_digest_header(
self, integration_client, test_package, upload_test_file
):
"""Test client can verify using RFC 3230 Digest header."""
project_name, package_name = test_package
content = b"Content for Digest header verification"
upload_test_file(project_name, package_name, content, tag="digest-verify-test")
response = integration_client.get(
f"/api/v1/project/{project_name}/{package_name}/+/digest-verify-test",
params={"mode": "proxy"},
)
assert response.status_code == 200
# Parse Digest header
digest_header = response.headers["Digest"]
assert digest_header.startswith("sha-256=")
b64_hash = digest_header.split("=", 1)[1]
expected_hash_bytes = base64.b64decode(b64_hash)
# Compute actual hash of downloaded content
actual_hash_bytes = hashlib.sha256(response.content).digest()
# Verify match
assert actual_hash_bytes == expected_hash_bytes