""" Checksum utilities for download verification. This module provides functions and classes for computing and verifying SHA256 checksums during artifact downloads. Key components: - compute_sha256(): Compute SHA256 of bytes content - compute_sha256_stream(): Compute SHA256 from an iterable stream - HashingStreamWrapper: Wrapper that computes hash while streaming - VerifyingStreamWrapper: Wrapper that verifies hash after streaming - verify_checksum(): Verify content against expected hash - ChecksumMismatchError: Exception for verification failures """ import hashlib import logging import re import base64 from typing import ( Generator, Optional, Any, Callable, ) logger = logging.getLogger(__name__) # Default chunk size for streaming operations (8KB) DEFAULT_CHUNK_SIZE = 8 * 1024 # Regex pattern for valid SHA256 hash (64 hex characters) SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}$") class ChecksumError(Exception): """Base exception for checksum operations.""" pass class ChecksumMismatchError(ChecksumError): """ Raised when computed checksum does not match expected checksum. Attributes: expected: The expected SHA256 hash actual: The actual computed SHA256 hash artifact_id: Optional artifact ID for context s3_key: Optional S3 key for debugging size: Optional file size """ def __init__( self, expected: str, actual: str, artifact_id: Optional[str] = None, s3_key: Optional[str] = None, size: Optional[int] = None, message: Optional[str] = None, ): self.expected = expected self.actual = actual self.artifact_id = artifact_id self.s3_key = s3_key self.size = size if message: self.message = message else: self.message = ( f"Checksum verification failed: " f"expected {expected[:16]}..., got {actual[:16]}..." ) super().__init__(self.message) def to_dict(self) -> dict: """Convert to dictionary for logging/API responses.""" return { "error": "checksum_mismatch", "expected": self.expected, "actual": self.actual, "artifact_id": self.artifact_id, "s3_key": self.s3_key, "size": self.size, "message": self.message, } class InvalidHashFormatError(ChecksumError): """Raised when a hash string is not valid SHA256 format.""" def __init__(self, hash_value: str): self.hash_value = hash_value message = f"Invalid SHA256 hash format: '{hash_value[:32]}...'" super().__init__(message) def is_valid_sha256(hash_value: str) -> bool: """ Check if a string is a valid SHA256 hash (64 hex characters). Args: hash_value: String to validate Returns: True if valid SHA256 format, False otherwise """ if not hash_value: return False return bool(SHA256_PATTERN.match(hash_value)) def compute_sha256(content: bytes) -> str: """ Compute SHA256 hash of bytes content. Args: content: Bytes content to hash Returns: Lowercase hexadecimal SHA256 hash (64 characters) Raises: ChecksumError: If hash computation fails """ if content is None: raise ChecksumError("Cannot compute hash of None content") try: return hashlib.sha256(content).hexdigest().lower() except Exception as e: raise ChecksumError(f"Hash computation failed: {e}") from e def compute_sha256_stream( stream: Any, chunk_size: int = DEFAULT_CHUNK_SIZE, ) -> str: """ Compute SHA256 hash from a stream or file-like object. Reads the stream in chunks to minimize memory usage for large files. Args: stream: Iterator yielding bytes or file-like object with read() chunk_size: Size of chunks to read (default 8KB) Returns: Lowercase hexadecimal SHA256 hash (64 characters) Raises: ChecksumError: If hash computation fails """ try: hasher = hashlib.sha256() # Handle file-like objects with read() if hasattr(stream, "read"): while True: chunk = stream.read(chunk_size) if not chunk: break hasher.update(chunk) else: # Handle iterators for chunk in stream: if chunk: hasher.update(chunk) return hasher.hexdigest().lower() except Exception as e: raise ChecksumError(f"Stream hash computation failed: {e}") from e def verify_checksum(content: bytes, expected: str) -> bool: """ Verify that content matches expected SHA256 hash. Args: content: Bytes content to verify expected: Expected SHA256 hash (case-insensitive) Returns: True if hash matches, False otherwise Raises: InvalidHashFormatError: If expected hash is not valid format ChecksumError: If hash computation fails """ if not is_valid_sha256(expected): raise InvalidHashFormatError(expected) actual = compute_sha256(content) return actual == expected.lower() def verify_checksum_strict( content: bytes, expected: str, artifact_id: Optional[str] = None, s3_key: Optional[str] = None, ) -> None: """ Verify content matches expected hash, raising exception on mismatch. Args: content: Bytes content to verify expected: Expected SHA256 hash (case-insensitive) artifact_id: Optional artifact ID for error context s3_key: Optional S3 key for error context Raises: InvalidHashFormatError: If expected hash is not valid format ChecksumMismatchError: If verification fails ChecksumError: If hash computation fails """ if not is_valid_sha256(expected): raise InvalidHashFormatError(expected) actual = compute_sha256(content) if actual != expected.lower(): raise ChecksumMismatchError( expected=expected.lower(), actual=actual, artifact_id=artifact_id, s3_key=s3_key, size=len(content), ) def sha256_to_base64(hex_hash: str) -> str: """ Convert SHA256 hex string to base64 encoding (for RFC 3230 Digest header). Args: hex_hash: SHA256 hash as 64-character hex string Returns: Base64-encoded hash string """ if not is_valid_sha256(hex_hash): raise InvalidHashFormatError(hex_hash) hash_bytes = bytes.fromhex(hex_hash) return base64.b64encode(hash_bytes).decode("ascii") class HashingStreamWrapper: """ Wrapper that computes SHA256 hash incrementally as chunks are read. This allows computing the hash while streaming content to a client, without buffering the entire content in memory. Usage: wrapper = HashingStreamWrapper(stream) for chunk in wrapper: send_to_client(chunk) final_hash = wrapper.get_hash() Attributes: chunk_size: Size of chunks to yield bytes_read: Total bytes processed so far """ def __init__( self, stream: Any, chunk_size: int = DEFAULT_CHUNK_SIZE, ): """ Initialize the hashing stream wrapper. Args: stream: Source stream (iterator, file-like, or S3 StreamingBody) chunk_size: Size of chunks to yield (default 8KB) """ self._stream = stream self._hasher = hashlib.sha256() self._chunk_size = chunk_size self._bytes_read = 0 self._finalized = False self._final_hash: Optional[str] = None @property def bytes_read(self) -> int: """Total bytes read so far.""" return self._bytes_read @property def chunk_size(self) -> int: """Chunk size for reading.""" return self._chunk_size def __iter__(self) -> Generator[bytes, None, None]: """Iterate over chunks, computing hash as we go.""" # Handle S3 StreamingBody (has iter_chunks) if hasattr(self._stream, "iter_chunks"): for chunk in self._stream.iter_chunks(chunk_size=self._chunk_size): if chunk: self._hasher.update(chunk) self._bytes_read += len(chunk) yield chunk # Handle file-like objects with read() elif hasattr(self._stream, "read"): while True: chunk = self._stream.read(self._chunk_size) if not chunk: break self._hasher.update(chunk) self._bytes_read += len(chunk) yield chunk # Handle iterators else: for chunk in self._stream: if chunk: self._hasher.update(chunk) self._bytes_read += len(chunk) yield chunk self._finalized = True self._final_hash = self._hasher.hexdigest().lower() def get_hash(self) -> str: """ Get the computed SHA256 hash. If stream hasn't been fully consumed, consumes remaining chunks. Returns: Lowercase hexadecimal SHA256 hash """ if not self._finalized: # Consume remaining stream for _ in self: pass return self._final_hash or self._hasher.hexdigest().lower() def get_hash_if_complete(self) -> Optional[str]: """ Get hash only if stream has been fully consumed. Returns: Hash if complete, None otherwise """ if self._finalized: return self._final_hash return None class VerifyingStreamWrapper: """ Wrapper that yields chunks and verifies hash after streaming completes. IMPORTANT: Because HTTP streams cannot be "un-sent", if verification fails after streaming, the client has already received potentially corrupt data. This wrapper logs an error but cannot prevent delivery. For guaranteed verification before delivery, use pre-verification mode which buffers the entire content first. Usage: wrapper = VerifyingStreamWrapper(stream, expected_hash) for chunk in wrapper: send_to_client(chunk) wrapper.verify() # Raises ChecksumMismatchError if failed """ def __init__( self, stream: Any, expected_hash: str, artifact_id: Optional[str] = None, s3_key: Optional[str] = None, chunk_size: int = DEFAULT_CHUNK_SIZE, on_failure: Optional[Callable[[Any], None]] = None, ): """ Initialize the verifying stream wrapper. Args: stream: Source stream expected_hash: Expected SHA256 hash to verify against artifact_id: Optional artifact ID for error context s3_key: Optional S3 key for error context chunk_size: Size of chunks to yield on_failure: Optional callback called on verification failure """ if not is_valid_sha256(expected_hash): raise InvalidHashFormatError(expected_hash) self._hashing_wrapper = HashingStreamWrapper(stream, chunk_size) self._expected_hash = expected_hash.lower() self._artifact_id = artifact_id self._s3_key = s3_key self._on_failure = on_failure self._verified: Optional[bool] = None @property def bytes_read(self) -> int: """Total bytes read so far.""" return self._hashing_wrapper.bytes_read @property def is_verified(self) -> Optional[bool]: """ Verification status. Returns: True if verified successfully, False if failed, None if not yet complete """ return self._verified def __iter__(self) -> Generator[bytes, None, None]: """Iterate over chunks.""" yield from self._hashing_wrapper def verify(self) -> bool: """ Verify the hash after stream is complete. Must be called after fully consuming the iterator. Returns: True if verification passed Raises: ChecksumMismatchError: If verification failed """ actual_hash = self._hashing_wrapper.get_hash() if actual_hash == self._expected_hash: self._verified = True logger.debug( f"Verification passed for {self._artifact_id or 'unknown'}: {actual_hash[:16]}..." ) return True self._verified = False error = ChecksumMismatchError( expected=self._expected_hash, actual=actual_hash, artifact_id=self._artifact_id, s3_key=self._s3_key, size=self._hashing_wrapper.bytes_read, ) # Log the failure logger.error(f"Checksum verification FAILED after streaming: {error.to_dict()}") # Call failure callback if provided if self._on_failure: try: self._on_failure(error) except Exception as e: logger.warning(f"Verification failure callback raised exception: {e}") raise error def verify_silent(self) -> bool: """ Verify the hash without raising exception. Returns: True if verification passed, False otherwise """ try: return self.verify() except ChecksumMismatchError: return False def get_actual_hash(self) -> Optional[str]: """Get the actual computed hash (only available after iteration).""" return self._hashing_wrapper.get_hash_if_complete()