This commit is contained in:
477
backend/app/checksum.py
Normal file
477
backend/app/checksum.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
Checksum utilities for download verification.
|
||||
|
||||
This module provides functions and classes for computing and verifying
|
||||
SHA256 checksums during artifact downloads.
|
||||
|
||||
Key components:
|
||||
- compute_sha256(): Compute SHA256 of bytes content
|
||||
- compute_sha256_stream(): Compute SHA256 from an iterable stream
|
||||
- HashingStreamWrapper: Wrapper that computes hash while streaming
|
||||
- VerifyingStreamWrapper: Wrapper that verifies hash after streaming
|
||||
- verify_checksum(): Verify content against expected hash
|
||||
- ChecksumMismatchError: Exception for verification failures
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import base64
|
||||
from typing import (
|
||||
Generator,
|
||||
Optional,
|
||||
Any,
|
||||
Callable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default chunk size for streaming operations (8KB)
|
||||
DEFAULT_CHUNK_SIZE = 8 * 1024
|
||||
|
||||
# Regex pattern for valid SHA256 hash (64 hex characters)
|
||||
SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}$")
|
||||
|
||||
|
||||
class ChecksumError(Exception):
    """Base exception raised for any checksum operation failure."""
|
||||
|
||||
|
||||
class ChecksumMismatchError(ChecksumError):
    """
    Signals that a computed checksum differs from the expected checksum.

    Attributes:
        expected: The expected SHA256 hash
        actual: The actual computed SHA256 hash
        artifact_id: Optional artifact ID for context
        s3_key: Optional S3 key for debugging
        size: Optional file size
    """

    def __init__(
        self,
        expected: str,
        actual: str,
        artifact_id: Optional[str] = None,
        s3_key: Optional[str] = None,
        size: Optional[int] = None,
        message: Optional[str] = None,
    ):
        self.expected = expected
        self.actual = actual
        self.artifact_id = artifact_id
        self.s3_key = s3_key
        self.size = size
        # Fall back to a truncated-hash summary when no explicit message given.
        self.message = message or (
            f"Checksum verification failed: "
            f"expected {expected[:16]}..., got {actual[:16]}..."
        )
        super().__init__(self.message)

    def to_dict(self) -> dict:
        """Serialize the failure details for logging or API responses."""
        return dict(
            error="checksum_mismatch",
            expected=self.expected,
            actual=self.actual,
            artifact_id=self.artifact_id,
            s3_key=self.s3_key,
            size=self.size,
            message=self.message,
        )
|
||||
|
||||
|
||||
class InvalidHashFormatError(ChecksumError):
    """
    Raised when a hash string is not a valid SHA256 hex digest.

    Attributes:
        hash_value: The offending value, stored exactly as provided.
    """

    def __init__(self, hash_value: Any):
        self.hash_value = hash_value
        # Coerce to str before slicing: the invalid value may be None or a
        # non-string, and the original `hash_value[:32]` raised TypeError for
        # those, masking the real validation failure. For str inputs the
        # resulting message is unchanged (str(s) is s).
        preview = str(hash_value)[:32]
        message = f"Invalid SHA256 hash format: '{preview}...'"
        super().__init__(message)
|
||||
|
||||
|
||||
def is_valid_sha256(hash_value: str) -> bool:
    """
    Check whether a value is a valid SHA256 hash (exactly 64 hex characters).

    Args:
        hash_value: Value to validate (non-strings are simply invalid)

    Returns:
        True if valid SHA256 format, False otherwise
    """
    # Reject non-strings and empty values up front so malformed input
    # (e.g. None or an int) returns False instead of raising TypeError
    # from the regex engine.
    if not isinstance(hash_value, str) or not hash_value:
        return False
    # fullmatch() instead of match() with ^...$ anchors: `$` also matches
    # just before a trailing newline, so the anchored pattern wrongly
    # accepted "<64 hex chars>\n".
    return re.fullmatch(r"[a-fA-F0-9]{64}", hash_value) is not None
|
||||
|
||||
|
||||
def compute_sha256(content: bytes) -> str:
    """
    Compute the SHA256 digest of a bytes payload.

    Args:
        content: Bytes content to hash

    Returns:
        Lowercase hexadecimal SHA256 hash (64 characters)

    Raises:
        ChecksumError: If content is None or hashing fails
    """
    if content is None:
        raise ChecksumError("Cannot compute hash of None content")

    try:
        digest = hashlib.sha256(content).hexdigest()
    except Exception as e:
        raise ChecksumError(f"Hash computation failed: {e}") from e
    # hexdigest() already emits lowercase; lower() keeps the contract explicit.
    return digest.lower()
|
||||
|
||||
|
||||
def compute_sha256_stream(
    stream: Any,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> str:
    """
    Compute a SHA256 digest from a stream or file-like object.

    The source is consumed chunk by chunk so memory use stays bounded even
    for very large payloads.

    Args:
        stream: Iterator yielding bytes, or file-like object exposing read()
        chunk_size: Number of bytes to read per call (default 8KB)

    Returns:
        Lowercase hexadecimal SHA256 hash (64 characters)

    Raises:
        ChecksumError: If hash computation fails
    """
    try:
        digest = hashlib.sha256()

        if hasattr(stream, "read"):
            # File-like source: pull fixed-size chunks until EOF (falsy read).
            while chunk := stream.read(chunk_size):
                digest.update(chunk)
        else:
            # Iterable source: skip empty/None chunks defensively.
            for chunk in stream:
                if chunk:
                    digest.update(chunk)

        return digest.hexdigest().lower()
    except Exception as e:
        raise ChecksumError(f"Stream hash computation failed: {e}") from e
|
||||
|
||||
|
||||
def verify_checksum(content: bytes, expected: str) -> bool:
    """
    Check whether content hashes to the expected SHA256 value.

    Args:
        content: Bytes content to verify
        expected: Expected SHA256 hash (case-insensitive)

    Returns:
        True if hash matches, False otherwise

    Raises:
        InvalidHashFormatError: If expected hash is not valid format
        ChecksumError: If hash computation fails
    """
    if not is_valid_sha256(expected):
        raise InvalidHashFormatError(expected)

    # Compare in lowercase: compute_sha256 always emits lowercase hex.
    return compute_sha256(content) == expected.lower()
|
||||
|
||||
|
||||
def verify_checksum_strict(
    content: bytes,
    expected: str,
    artifact_id: Optional[str] = None,
    s3_key: Optional[str] = None,
) -> None:
    """
    Verify content against an expected hash, raising on any mismatch.

    Args:
        content: Bytes content to verify
        expected: Expected SHA256 hash (case-insensitive)
        artifact_id: Optional artifact ID for error context
        s3_key: Optional S3 key for error context

    Raises:
        InvalidHashFormatError: If expected hash is not valid format
        ChecksumMismatchError: If verification fails
        ChecksumError: If hash computation fails
    """
    if not is_valid_sha256(expected):
        raise InvalidHashFormatError(expected)

    normalized = expected.lower()
    actual = compute_sha256(content)
    if actual == normalized:
        return

    raise ChecksumMismatchError(
        expected=normalized,
        actual=actual,
        artifact_id=artifact_id,
        s3_key=s3_key,
        size=len(content),
    )
|
||||
|
||||
|
||||
def sha256_to_base64(hex_hash: str) -> str:
    """
    Convert a SHA256 hex digest to base64 form (for the RFC 3230 Digest header).

    Args:
        hex_hash: SHA256 hash as 64-character hex string

    Returns:
        Base64-encoded hash string

    Raises:
        InvalidHashFormatError: If hex_hash is not a valid SHA256 hex string
    """
    if not is_valid_sha256(hex_hash):
        raise InvalidHashFormatError(hex_hash)

    return base64.b64encode(bytes.fromhex(hex_hash)).decode("ascii")
|
||||
|
||||
|
||||
class HashingStreamWrapper:
    """
    Stream wrapper that folds every chunk into a SHA256 digest as it passes.

    The hash is computed on the fly, so content can be relayed to a client
    without ever buffering the whole payload in memory.

    Usage:
        wrapper = HashingStreamWrapper(stream)
        for chunk in wrapper:
            send_to_client(chunk)
        final_hash = wrapper.get_hash()

    Attributes:
        chunk_size: Size of chunks to yield
        bytes_read: Total bytes processed so far
    """

    def __init__(
        self,
        stream: Any,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
    ):
        """
        Initialize the hashing stream wrapper.

        Args:
            stream: Source stream (iterator, file-like, or S3 StreamingBody)
            chunk_size: Size of chunks to yield (default 8KB)
        """
        self._stream = stream
        self._hasher = hashlib.sha256()
        self._chunk_size = chunk_size
        self._bytes_read = 0
        self._finalized = False
        self._final_hash: Optional[str] = None

    @property
    def bytes_read(self) -> int:
        """Total bytes read so far."""
        return self._bytes_read

    @property
    def chunk_size(self) -> int:
        """Chunk size for reading."""
        return self._chunk_size

    def _source_chunks(self) -> Generator[bytes, None, None]:
        """Yield raw chunks from the underlying stream, whatever its shape."""
        if hasattr(self._stream, "iter_chunks"):
            # S3 StreamingBody exposes iter_chunks().
            yield from self._stream.iter_chunks(chunk_size=self._chunk_size)
        elif hasattr(self._stream, "read"):
            # Plain file-like object: read until a falsy chunk signals EOF.
            while True:
                piece = self._stream.read(self._chunk_size)
                if not piece:
                    return
                yield piece
        else:
            # Anything already iterable.
            yield from self._stream

    def __iter__(self) -> Generator[bytes, None, None]:
        """Yield chunks, updating the running digest and byte counter."""
        for piece in self._source_chunks():
            # Skip empty chunks; they carry no data and would distort nothing,
            # but the original contract never yields falsy chunks to callers.
            if not piece:
                continue
            self._hasher.update(piece)
            self._bytes_read += len(piece)
            yield piece

        self._finalized = True
        self._final_hash = self._hasher.hexdigest().lower()

    def get_hash(self) -> str:
        """
        Get the computed SHA256 hash.

        If the stream has not been fully consumed yet, the remaining chunks
        are drained (and hashed) first.

        Returns:
            Lowercase hexadecimal SHA256 hash
        """
        if not self._finalized:
            # Drain whatever is left so the digest covers the full stream.
            for _ in self:
                pass

        return self._final_hash or self._hasher.hexdigest().lower()

    def get_hash_if_complete(self) -> Optional[str]:
        """
        Get hash only if the stream has been fully consumed.

        Returns:
            Hash if complete, None otherwise
        """
        return self._final_hash if self._finalized else None
|
||||
|
||||
|
||||
class VerifyingStreamWrapper:
    """
    Stream wrapper that checks the SHA256 digest once streaming finishes.

    IMPORTANT: HTTP responses cannot be recalled. If verification fails
    after the chunks have been streamed, the client has already received
    potentially corrupt data; this wrapper can only log the failure, not
    prevent delivery.

    For guaranteed verification before delivery, use pre-verification mode
    which buffers the entire content first.

    Usage:
        wrapper = VerifyingStreamWrapper(stream, expected_hash)
        for chunk in wrapper:
            send_to_client(chunk)
        wrapper.verify()  # Raises ChecksumMismatchError if failed
    """

    def __init__(
        self,
        stream: Any,
        expected_hash: str,
        artifact_id: Optional[str] = None,
        s3_key: Optional[str] = None,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        on_failure: Optional[Callable[[Any], None]] = None,
    ):
        """
        Initialize the verifying stream wrapper.

        Args:
            stream: Source stream
            expected_hash: Expected SHA256 hash to verify against
            artifact_id: Optional artifact ID for error context
            s3_key: Optional S3 key for error context
            chunk_size: Size of chunks to yield
            on_failure: Optional callback called on verification failure

        Raises:
            InvalidHashFormatError: If expected_hash is not valid SHA256 format
        """
        if not is_valid_sha256(expected_hash):
            raise InvalidHashFormatError(expected_hash)

        # Delegate all hashing/counting to the hashing wrapper.
        self._hashing_wrapper = HashingStreamWrapper(stream, chunk_size)
        self._expected_hash = expected_hash.lower()
        self._artifact_id = artifact_id
        self._s3_key = s3_key
        self._on_failure = on_failure
        self._verified: Optional[bool] = None

    @property
    def bytes_read(self) -> int:
        """Total bytes read so far."""
        return self._hashing_wrapper.bytes_read

    @property
    def is_verified(self) -> Optional[bool]:
        """
        Verification status.

        Returns:
            True if verified successfully, False if failed, None if not yet complete
        """
        return self._verified

    def __iter__(self) -> Generator[bytes, None, None]:
        """Yield chunks straight from the underlying hashing wrapper."""
        yield from self._hashing_wrapper

    def verify(self) -> bool:
        """
        Verify the hash after the stream is complete.

        Must be called after fully consuming the iterator.

        Returns:
            True if verification passed

        Raises:
            ChecksumMismatchError: If verification failed
        """
        # get_hash() drains any unconsumed remainder of the stream first.
        actual_hash = self._hashing_wrapper.get_hash()

        if actual_hash == self._expected_hash:
            self._verified = True
            logger.debug(
                f"Verification passed for {self._artifact_id or 'unknown'}: {actual_hash[:16]}..."
            )
            return True

        self._verified = False
        mismatch = ChecksumMismatchError(
            expected=self._expected_hash,
            actual=actual_hash,
            artifact_id=self._artifact_id,
            s3_key=self._s3_key,
            size=self._hashing_wrapper.bytes_read,
        )

        logger.error(f"Checksum verification FAILED after streaming: {mismatch.to_dict()}")

        if self._on_failure:
            try:
                self._on_failure(mismatch)
            except Exception as e:
                # A broken callback must not mask the checksum failure itself.
                logger.warning(f"Verification failure callback raised exception: {e}")

        raise mismatch

    def verify_silent(self) -> bool:
        """
        Verify the hash without raising an exception.

        Returns:
            True if verification passed, False otherwise
        """
        try:
            return self.verify()
        except ChecksumMismatchError:
            return False

    def get_actual_hash(self) -> Optional[str]:
        """Get the actual computed hash (only available after iteration)."""
        return self._hashing_wrapper.get_hash_if_complete()
|
||||
Reference in New Issue
Block a user