478 lines
14 KiB
Python
478 lines
14 KiB
Python
"""
|
|
Checksum utilities for download verification.
|
|
|
|
This module provides functions and classes for computing and verifying
|
|
SHA256 checksums during artifact downloads.
|
|
|
|
Key components:
|
|
- compute_sha256(): Compute SHA256 of bytes content
|
|
- compute_sha256_stream(): Compute SHA256 from an iterable stream
|
|
- HashingStreamWrapper: Wrapper that computes hash while streaming
|
|
- VerifyingStreamWrapper: Wrapper that verifies hash after streaming
|
|
- verify_checksum(): Verify content against expected hash
|
|
- ChecksumMismatchError: Exception for verification failures
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
import re
|
|
import base64
|
|
from typing import (
|
|
Generator,
|
|
Optional,
|
|
Any,
|
|
Callable,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Default chunk size for streaming operations (8KB)
|
|
DEFAULT_CHUNK_SIZE = 8 * 1024
|
|
|
|
# Regex pattern for valid SHA256 hash (64 hex characters)
|
|
SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}$")
|
|
|
|
|
|
class ChecksumError(Exception):
    """Base exception for all checksum-related failures in this module."""
|
|
|
|
|
|
class ChecksumMismatchError(ChecksumError):
    """
    Raised when a computed checksum differs from the expected checksum.

    Attributes:
        expected: The expected SHA256 hash
        actual: The actual computed SHA256 hash
        artifact_id: Optional artifact ID for context
        s3_key: Optional S3 key for debugging
        size: Optional file size
    """

    def __init__(
        self,
        expected: str,
        actual: str,
        artifact_id: Optional[str] = None,
        s3_key: Optional[str] = None,
        size: Optional[int] = None,
        message: Optional[str] = None,
    ):
        """
        Record the mismatch details and build a human-readable message.

        Args:
            expected: Expected SHA256 hash
            actual: Actually computed SHA256 hash
            artifact_id: Optional artifact ID for context
            s3_key: Optional S3 key for debugging
            size: Optional file size in bytes
            message: Optional explicit message; a default showing truncated
                hashes is generated when omitted.
        """
        self.expected = expected
        self.actual = actual
        self.artifact_id = artifact_id
        self.s3_key = s3_key
        self.size = size
        # Fall back to a generic message with truncated hashes when the
        # caller did not supply one.
        self.message = message or (
            f"Checksum verification failed: "
            f"expected {expected[:16]}..., got {actual[:16]}..."
        )
        super().__init__(self.message)

    def to_dict(self) -> dict:
        """Convert to dictionary for logging/API responses."""
        payload = {"error": "checksum_mismatch"}
        for field in ("expected", "actual", "artifact_id", "s3_key", "size", "message"):
            payload[field] = getattr(self, field)
        return payload
|
|
|
|
|
|
class InvalidHashFormatError(ChecksumError):
    """Raised when a hash string is not valid SHA256 format (64 hex chars)."""

    def __init__(self, hash_value: str):
        """
        Args:
            hash_value: The offending value, kept for caller inspection.
        """
        self.hash_value = hash_value
        # Coerce to str before slicing: validation helpers reject None and
        # other non-string values, so this constructor must not itself raise
        # TypeError on `hash_value[:32]` for those inputs.
        message = f"Invalid SHA256 hash format: '{str(hash_value)[:32]}...'"
        super().__init__(message)
|
|
|
|
|
|
def is_valid_sha256(hash_value: str) -> bool:
    """
    Check if a value is a valid SHA256 hash (64 hex characters).

    Args:
        hash_value: Value to validate; None, empty, or non-string values
            are simply invalid.

    Returns:
        True if valid SHA256 format, False otherwise
    """
    # Guard both falsy values and non-strings: re.match raises TypeError on
    # e.g. an int, but this predicate promises a clean bool for any input.
    if not hash_value or not isinstance(hash_value, str):
        return False
    return SHA256_PATTERN.match(hash_value) is not None
|
|
|
|
|
|
def compute_sha256(content: bytes) -> str:
    """
    Compute SHA256 hash of bytes content.

    Args:
        content: Bytes-like content to hash

    Returns:
        Lowercase hexadecimal SHA256 hash (64 characters)

    Raises:
        ChecksumError: If content is None or hash computation fails
    """
    if content is None:
        raise ChecksumError("Cannot compute hash of None content")

    try:
        # hexdigest() already returns lowercase hex, so no extra
        # normalization is needed.
        return hashlib.sha256(content).hexdigest()
    except Exception as e:
        # Wrap any failure (e.g. non-bytes input) in the module's error type.
        raise ChecksumError(f"Hash computation failed: {e}") from e
|
|
|
|
|
|
def compute_sha256_stream(
    stream: Any,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> str:
    """
    Compute SHA256 hash from a stream or file-like object.

    Reads the stream in chunks to minimize memory usage for large files.

    Args:
        stream: Iterator yielding bytes or file-like object with read()
        chunk_size: Size of chunks to read (default 8KB)

    Returns:
        Lowercase hexadecimal SHA256 hash (64 characters)

    Raises:
        ChecksumError: If hash computation fails
    """
    try:
        hasher = hashlib.sha256()

        # File-like objects expose read(); consume fixed-size chunks to EOF.
        if hasattr(stream, "read"):
            while True:
                chunk = stream.read(chunk_size)
                if not chunk:
                    break
                hasher.update(chunk)
        else:
            # Otherwise treat it as an iterable of bytes chunks; skip
            # empty/None entries some producers emit.
            for chunk in stream:
                if chunk:
                    hasher.update(chunk)

        # hexdigest() already returns lowercase hex — no normalization needed.
        return hasher.hexdigest()
    except Exception as e:
        raise ChecksumError(f"Stream hash computation failed: {e}") from e
|
|
|
|
|
|
def verify_checksum(content: bytes, expected: str) -> bool:
    """
    Verify that content matches an expected SHA256 hash.

    Args:
        content: Bytes content to verify
        expected: Expected SHA256 hash (case-insensitive)

    Returns:
        True if hash matches, False otherwise

    Raises:
        InvalidHashFormatError: If expected hash is not valid format
        ChecksumError: If hash computation fails
    """
    # Reject malformed expected hashes up front.
    if not is_valid_sha256(expected):
        raise InvalidHashFormatError(expected)

    # Comparison is case-insensitive: normalize the expected side.
    return compute_sha256(content) == expected.lower()
|
|
|
|
|
|
def verify_checksum_strict(
    content: bytes,
    expected: str,
    artifact_id: Optional[str] = None,
    s3_key: Optional[str] = None,
) -> None:
    """
    Verify content against an expected hash, raising on any mismatch.

    Args:
        content: Bytes content to verify
        expected: Expected SHA256 hash (case-insensitive)
        artifact_id: Optional artifact ID for error context
        s3_key: Optional S3 key for error context

    Raises:
        InvalidHashFormatError: If expected hash is not valid format
        ChecksumMismatchError: If verification fails
        ChecksumError: If hash computation fails
    """
    if not is_valid_sha256(expected):
        raise InvalidHashFormatError(expected)

    wanted = expected.lower()
    actual = compute_sha256(content)
    if actual == wanted:
        return

    # Mismatch: surface full context for the caller/logs.
    raise ChecksumMismatchError(
        expected=wanted,
        actual=actual,
        artifact_id=artifact_id,
        s3_key=s3_key,
        size=len(content),
    )
|
|
|
|
|
|
def sha256_to_base64(hex_hash: str) -> str:
    """
    Convert a SHA256 hex digest to base64 (for the RFC 3230 Digest header).

    Args:
        hex_hash: SHA256 hash as 64-character hex string

    Returns:
        Base64-encoded hash string

    Raises:
        InvalidHashFormatError: If hex_hash is not a valid SHA256 hex string
    """
    if not is_valid_sha256(hex_hash):
        raise InvalidHashFormatError(hex_hash)

    # Decode the hex digest to raw bytes, then re-encode as base64 text.
    return base64.b64encode(bytes.fromhex(hex_hash)).decode("ascii")
|
|
|
|
|
|
class HashingStreamWrapper:
    """
    Wrapper that computes SHA256 hash incrementally as chunks are read.

    This allows computing the hash while streaming content to a client,
    without buffering the entire content in memory.

    Usage:
        wrapper = HashingStreamWrapper(stream)
        for chunk in wrapper:
            send_to_client(chunk)
        final_hash = wrapper.get_hash()

    Attributes:
        chunk_size: Size of chunks to yield
        bytes_read: Total bytes processed so far
    """

    def __init__(
        self,
        stream: Any,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
    ):
        """
        Initialize the hashing stream wrapper.

        Args:
            stream: Source stream (iterator, file-like, or S3 StreamingBody)
            chunk_size: Size of chunks to yield (default 8KB)
        """
        self._stream = stream
        self._chunk_size = chunk_size
        self._hasher = hashlib.sha256()
        self._bytes_read = 0
        self._final_hash: Optional[str] = None
        self._finalized = False

    @property
    def bytes_read(self) -> int:
        """Total bytes read so far."""
        return self._bytes_read

    @property
    def chunk_size(self) -> int:
        """Chunk size for reading."""
        return self._chunk_size

    def _raw_chunks(self) -> Generator[bytes, None, None]:
        """Yield raw chunks from the source, whichever interface it offers."""
        source = self._stream
        if hasattr(source, "iter_chunks"):
            # S3 StreamingBody exposes iter_chunks().
            yield from source.iter_chunks(chunk_size=self._chunk_size)
        elif hasattr(source, "read"):
            # File-like object: read until EOF (empty read).
            while True:
                piece = source.read(self._chunk_size)
                if not piece:
                    return
                yield piece
        else:
            # Plain iterable of bytes chunks.
            yield from source

    def __iter__(self) -> Generator[bytes, None, None]:
        """Iterate over chunks, folding each into the running digest."""
        for piece in self._raw_chunks():
            if not piece:
                continue  # skip empty/None chunks some producers emit
            self._hasher.update(piece)
            self._bytes_read += len(piece)
            yield piece

        # Only reached on normal exhaustion of the source stream.
        self._finalized = True
        self._final_hash = self._hasher.hexdigest().lower()

    def get_hash(self) -> str:
        """
        Get the computed SHA256 hash.

        If stream hasn't been fully consumed, consumes remaining chunks.

        Returns:
            Lowercase hexadecimal SHA256 hash
        """
        if not self._finalized:
            # Drain whatever the consumer left unread so the digest covers
            # the full stream.
            for _ in self:
                pass

        return self._final_hash or self._hasher.hexdigest().lower()

    def get_hash_if_complete(self) -> Optional[str]:
        """
        Get hash only if stream has been fully consumed.

        Returns:
            Hash if complete, None otherwise
        """
        return self._final_hash if self._finalized else None
|
|
|
|
|
|
class VerifyingStreamWrapper:
    """
    Wrapper that yields chunks and verifies hash after streaming completes.

    IMPORTANT: Because HTTP streams cannot be "un-sent", if verification
    fails after streaming, the client has already received potentially
    corrupt data. This wrapper logs an error but cannot prevent delivery.

    For guaranteed verification before delivery, use pre-verification mode
    which buffers the entire content first.

    Usage:
        wrapper = VerifyingStreamWrapper(stream, expected_hash)
        for chunk in wrapper:
            send_to_client(chunk)
        wrapper.verify()  # Raises ChecksumMismatchError if failed
    """

    def __init__(
        self,
        stream: Any,
        expected_hash: str,
        artifact_id: Optional[str] = None,
        s3_key: Optional[str] = None,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        on_failure: Optional[Callable[[Any], None]] = None,
    ):
        """
        Initialize the verifying stream wrapper.

        Args:
            stream: Source stream
            expected_hash: Expected SHA256 hash to verify against
            artifact_id: Optional artifact ID for error context
            s3_key: Optional S3 key for error context
            chunk_size: Size of chunks to yield
            on_failure: Optional callback called on verification failure

        Raises:
            InvalidHashFormatError: If expected_hash is not valid SHA256 format
        """
        if not is_valid_sha256(expected_hash):
            raise InvalidHashFormatError(expected_hash)

        # Delegate hashing to a HashingStreamWrapper; compare lowercase.
        self._expected_hash = expected_hash.lower()
        self._hashing_wrapper = HashingStreamWrapper(stream, chunk_size)
        self._artifact_id = artifact_id
        self._s3_key = s3_key
        self._on_failure = on_failure
        self._verified: Optional[bool] = None

    @property
    def bytes_read(self) -> int:
        """Total bytes read so far."""
        return self._hashing_wrapper.bytes_read

    @property
    def is_verified(self) -> Optional[bool]:
        """
        Verification status.

        Returns:
            True if verified successfully, False if failed, None if not yet complete
        """
        return self._verified

    def __iter__(self) -> Generator[bytes, None, None]:
        """Iterate over chunks."""
        yield from self._hashing_wrapper

    def _fail(self, actual_hash: str) -> None:
        """Build the mismatch error, log it, notify the callback, and raise."""
        error = ChecksumMismatchError(
            expected=self._expected_hash,
            actual=actual_hash,
            artifact_id=self._artifact_id,
            s3_key=self._s3_key,
            size=self._hashing_wrapper.bytes_read,
        )

        logger.error(f"Checksum verification FAILED after streaming: {error.to_dict()}")

        if self._on_failure:
            try:
                self._on_failure(error)
            except Exception as e:
                # The callback must not mask the verification failure itself.
                logger.warning(f"Verification failure callback raised exception: {e}")

        raise error

    def verify(self) -> bool:
        """
        Verify the hash after stream is complete.

        Must be called after fully consuming the iterator.

        Returns:
            True if verification passed

        Raises:
            ChecksumMismatchError: If verification failed
        """
        actual_hash = self._hashing_wrapper.get_hash()

        if actual_hash != self._expected_hash:
            self._verified = False
            self._fail(actual_hash)  # always raises

        self._verified = True
        logger.debug(
            f"Verification passed for {self._artifact_id or 'unknown'}: {actual_hash[:16]}..."
        )
        return True

    def verify_silent(self) -> bool:
        """
        Verify the hash without raising exception.

        Returns:
            True if verification passed, False otherwise
        """
        try:
            self.verify()
        except ChecksumMismatchError:
            return False
        return True

    def get_actual_hash(self) -> Optional[str]:
        """Get the actual computed hash (only available after iteration)."""
        return self._hashing_wrapper.get_hash_if_complete()