Files
orchard/backend/app/checksum.py

478 lines
14 KiB
Python

"""
Checksum utilities for download verification.
This module provides functions and classes for computing and verifying
SHA256 checksums during artifact downloads.
Key components:
- compute_sha256(): Compute SHA256 of bytes content
- compute_sha256_stream(): Compute SHA256 from an iterable stream
- HashingStreamWrapper: Wrapper that computes hash while streaming
- VerifyingStreamWrapper: Wrapper that verifies hash after streaming
- verify_checksum(): Verify content against expected hash
- ChecksumMismatchError: Exception for verification failures
"""
import hashlib
import logging
import re
import base64
from typing import (
Generator,
Optional,
Any,
Callable,
)
# Module-level logger; handlers and levels are configured by the application.
logger = logging.getLogger(__name__)

# Default chunk size for streaming operations (8 KiB).
DEFAULT_CHUNK_SIZE = 8 * 1024

# Regex pattern for a valid SHA256 hash: exactly 64 hex characters.
# ``\Z`` anchors at the true end of the string. ``$`` would also match just
# before a trailing newline, letting values such as "<64 hex chars>\n" pass
# validation when the pattern is used with match().
SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}\Z")
class ChecksumError(Exception):
    """Root of the exception hierarchy for checksum operations.

    Catch this type to handle every checksum-related failure raised by
    this module in one place.
    """
class ChecksumMismatchError(ChecksumError):
    """
    Signals that a computed checksum differs from the expected one.

    Attributes:
        expected: The expected SHA256 hash
        actual: The actual computed SHA256 hash
        artifact_id: Optional artifact ID for context
        s3_key: Optional S3 key for debugging
        size: Optional file size
        message: Human-readable description used as the exception text
    """

    def __init__(
        self,
        expected: str,
        actual: str,
        artifact_id: Optional[str] = None,
        s3_key: Optional[str] = None,
        size: Optional[int] = None,
        message: Optional[str] = None,
    ):
        """
        Store the mismatch details and build the exception message.

        Args:
            expected: The expected SHA256 hash
            actual: The actual computed SHA256 hash
            artifact_id: Optional artifact ID for context
            s3_key: Optional S3 key for debugging
            size: Optional file size
            message: Optional custom message; when empty or omitted a
                default showing the first 16 characters of both hashes
                is generated
        """
        self.expected, self.actual = expected, actual
        self.artifact_id, self.s3_key, self.size = artifact_id, s3_key, size
        # A falsy message (None or "") falls back to the generated summary.
        self.message = message or (
            f"Checksum verification failed: "
            f"expected {expected[:16]}..., got {actual[:16]}..."
        )
        super().__init__(self.message)

    def to_dict(self) -> dict:
        """Convert to dictionary for logging/API responses."""
        payload = {"error": "checksum_mismatch"}
        for field in (
            "expected",
            "actual",
            "artifact_id",
            "s3_key",
            "size",
            "message",
        ):
            payload[field] = getattr(self, field)
        return payload
class InvalidHashFormatError(ChecksumError):
    """
    Raised when a hash string is not valid SHA256 format.

    Attributes:
        hash_value: The rejected value, exactly as supplied by the caller
    """

    def __init__(self, hash_value: str):
        """
        Build the error for a rejected hash value.

        Args:
            hash_value: The value that failed validation. Normally a string,
                but any object is accepted so that rejecting ``None`` or
                another non-string value still produces this error.
        """
        self.hash_value = hash_value
        # This error is raised for any value that fails validation, which
        # includes None and other non-strings. Slicing those directly would
        # raise TypeError and hide the real problem, so show their repr().
        shown = hash_value if isinstance(hash_value, str) else repr(hash_value)
        # Only the first 32 characters are echoed to keep messages short.
        message = f"Invalid SHA256 hash format: '{shown[:32]}...'"
        super().__init__(message)
def is_valid_sha256(hash_value: str) -> bool:
    """
    Check if a string is a valid SHA256 hash (exactly 64 hex characters).

    Non-string input (``None``, ``bytes``, numbers, ...) is reported as
    invalid instead of raising, so callers can validate arbitrary values
    without guarding the call themselves.

    Args:
        hash_value: String to validate

    Returns:
        True if valid SHA256 format, False otherwise
    """
    # A str pattern cannot be applied to bytes or ints; treat such values as
    # invalid rather than letting ``re`` raise TypeError.
    if not isinstance(hash_value, str):
        return False
    # Cheap length guard. It also rejects "<64 hex chars>\n", which a
    # ``$``-anchored pattern used with match() would accept.
    if len(hash_value) != 64:
        return False
    # fullmatch() requires the whole string to match, regardless of how the
    # pattern itself is anchored.
    return SHA256_PATTERN.fullmatch(hash_value) is not None
def compute_sha256(content: bytes) -> str:
    """
    Return the SHA256 digest of an in-memory bytes payload.

    Args:
        content: Bytes content to hash

    Returns:
        Lowercase hexadecimal SHA256 hash (64 characters)

    Raises:
        ChecksumError: If content is None or hash computation fails
    """
    if content is None:
        raise ChecksumError("Cannot compute hash of None content")

    try:
        digest = hashlib.sha256(content)
        hex_digest = digest.hexdigest()
    except Exception as exc:
        # Surface any hashing problem (e.g. unsupported input type) as the
        # module's own error type, keeping the original cause attached.
        raise ChecksumError(f"Hash computation failed: {exc}") from exc
    return hex_digest.lower()
def compute_sha256_stream(
    stream: Any,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> str:
    """
    Hash a stream or file-like object with SHA256, one chunk at a time.

    The source is consumed incrementally so large inputs never have to be
    held in memory as a whole.

    Args:
        stream: Iterator yielding bytes or file-like object with read()
        chunk_size: Size of chunks to read (default 8KB); only used when
            the stream exposes read()

    Returns:
        Lowercase hexadecimal SHA256 hash (64 characters)

    Raises:
        ChecksumError: If hash computation fails
    """
    try:
        digest = hashlib.sha256()
        if not hasattr(stream, "read"):
            # Plain iterable of chunks; empty chunks carry no data.
            for piece in stream:
                if piece:
                    digest.update(piece)
        else:
            # File-like object: read until an empty (falsy) chunk appears.
            piece = stream.read(chunk_size)
            while piece:
                digest.update(piece)
                piece = stream.read(chunk_size)
        return digest.hexdigest().lower()
    except Exception as exc:
        raise ChecksumError(f"Stream hash computation failed: {exc}") from exc
def verify_checksum(content: bytes, expected: str) -> bool:
    """
    Report whether content hashes to the expected SHA256 value.

    Args:
        content: Bytes content to verify
        expected: Expected SHA256 hash (case-insensitive)

    Returns:
        True if hash matches, False otherwise

    Raises:
        InvalidHashFormatError: If expected hash is not valid format
        ChecksumError: If hash computation fails
    """
    if is_valid_sha256(expected):
        # The computed digest is lowercase, so normalise the expected side.
        return compute_sha256(content) == expected.lower()
    raise InvalidHashFormatError(expected)
def verify_checksum_strict(
    content: bytes,
    expected: str,
    artifact_id: Optional[str] = None,
    s3_key: Optional[str] = None,
) -> None:
    """
    Check content against the expected hash and raise on any mismatch.

    Returns normally (with None) only when the hashes agree.

    Args:
        content: Bytes content to verify
        expected: Expected SHA256 hash (case-insensitive)
        artifact_id: Optional artifact ID for error context
        s3_key: Optional S3 key for error context

    Raises:
        InvalidHashFormatError: If expected hash is not valid format
        ChecksumMismatchError: If verification fails
        ChecksumError: If hash computation fails
    """
    if not is_valid_sha256(expected):
        raise InvalidHashFormatError(expected)

    computed = compute_sha256(content)
    wanted = expected.lower()
    if computed == wanted:
        return

    raise ChecksumMismatchError(
        expected=wanted,
        actual=computed,
        artifact_id=artifact_id,
        s3_key=s3_key,
        size=len(content),
    )
def sha256_to_base64(hex_hash: str) -> str:
    """
    Re-encode a SHA256 hex digest as base64 (for RFC 3230 Digest header).

    Args:
        hex_hash: SHA256 hash as 64-character hex string

    Returns:
        Base64-encoded hash string

    Raises:
        InvalidHashFormatError: If hex_hash is not valid SHA256 format
    """
    if is_valid_sha256(hex_hash):
        raw_digest = bytes.fromhex(hex_hash)
        encoded = base64.b64encode(raw_digest)
        return encoded.decode("ascii")
    raise InvalidHashFormatError(hex_hash)
class HashingStreamWrapper:
    """
    Wrapper that computes SHA256 hash incrementally as chunks are read.

    This allows computing the hash while streaming content to a client,
    without buffering the entire content in memory.

    A fully consumed wrapper yields nothing when iterated again, so the
    finished digest and byte count are never fed additional data. A plain
    iterable source (for example a list of chunks) is bound to a single
    iterator, so resuming after partial consumption, e.g. through
    get_hash(), continues where the previous iteration stopped instead of
    restarting from the first chunk and hashing it twice.

    Usage:
        wrapper = HashingStreamWrapper(stream)
        for chunk in wrapper:
            send_to_client(chunk)
        final_hash = wrapper.get_hash()

    Attributes:
        chunk_size: Chunk size requested from read()/iter_chunks(); plain
            iterables yield whatever chunk sizes they produce
        bytes_read: Total bytes processed so far
    """

    def __init__(
        self,
        stream: Any,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
    ):
        """
        Initialize the hashing stream wrapper.

        Args:
            stream: Source stream (iterator, file-like, or S3 StreamingBody)
            chunk_size: Size of chunks to yield (default 8KB)
        """
        self._stream = stream
        self._hasher = hashlib.sha256()
        self._chunk_size = chunk_size
        self._bytes_read = 0
        self._finalized = False
        self._final_hash: Optional[str] = None
        # Iterator over a plain-iterable source, created lazily on first
        # use and then shared by every later iteration of this wrapper.
        self._source_iter: Optional[Any] = None

    @property
    def bytes_read(self) -> int:
        """Total bytes read so far."""
        return self._bytes_read

    @property
    def chunk_size(self) -> int:
        """Chunk size for reading."""
        return self._chunk_size

    def _read_chunks(self) -> Generator[bytes, None, None]:
        """Yield chunks from a file-like source until read() returns nothing."""
        while True:
            chunk = self._stream.read(self._chunk_size)
            if not chunk:
                return
            yield chunk

    def _chunk_source(self) -> Any:
        """Pick the chunk iterator matching the wrapped stream's interface."""
        stream = self._stream
        # Handle S3 StreamingBody (has iter_chunks)
        if hasattr(stream, "iter_chunks"):
            return stream.iter_chunks(chunk_size=self._chunk_size)
        # Handle file-like objects with read()
        if hasattr(stream, "read"):
            return self._read_chunks()
        # Handle plain iterables. iter() on a container returns a fresh
        # iterator positioned at the start on every call, so keep one and
        # reuse it; otherwise a resumed iteration would re-hash chunks.
        if self._source_iter is None:
            self._source_iter = iter(stream)
        return self._source_iter

    def __iter__(self) -> Generator[bytes, None, None]:
        """
        Iterate over chunks, computing hash as we go.

        Empty chunks are skipped. Once the source is exhausted the hash is
        finalized; iterating again afterwards yields nothing.
        """
        if self._finalized:
            # The digest is complete; reading the source again could feed
            # extra data into the hasher and change the reported hash.
            return
        for chunk in self._chunk_source():
            if chunk:
                self._hasher.update(chunk)
                self._bytes_read += len(chunk)
                yield chunk
        self._finalized = True
        self._final_hash = self._hasher.hexdigest().lower()

    def get_hash(self) -> str:
        """
        Get the computed SHA256 hash.

        If stream hasn't been fully consumed, consumes remaining chunks
        (discarding them) so the hash covers the whole stream.

        Returns:
            Lowercase hexadecimal SHA256 hash
        """
        if not self._finalized:
            # Consume remaining stream
            for _ in self:
                pass
        return self._final_hash or self._hasher.hexdigest().lower()

    def get_hash_if_complete(self) -> Optional[str]:
        """
        Get hash only if stream has been fully consumed.

        Returns:
            Hash if complete, None otherwise
        """
        if self._finalized:
            return self._final_hash
        return None
class VerifyingStreamWrapper:
    """
    Stream wrapper that hashes chunks as they pass through and checks the
    result against an expected SHA256 once streaming is done.

    IMPORTANT: Because HTTP streams cannot be "un-sent", if verification
    fails after streaming, the client has already received potentially
    corrupt data. This wrapper logs an error but cannot prevent delivery.

    For guaranteed verification before delivery, use pre-verification mode
    which buffers the entire content first.

    Usage:
        wrapper = VerifyingStreamWrapper(stream, expected_hash)
        for chunk in wrapper:
            send_to_client(chunk)
        wrapper.verify()  # Raises ChecksumMismatchError if failed
    """

    def __init__(
        self,
        stream: Any,
        expected_hash: str,
        artifact_id: Optional[str] = None,
        s3_key: Optional[str] = None,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        on_failure: Optional[Callable[[Any], None]] = None,
    ):
        """
        Initialize the verifying stream wrapper.

        Args:
            stream: Source stream
            expected_hash: Expected SHA256 hash to verify against
            artifact_id: Optional artifact ID for error context
            s3_key: Optional S3 key for error context
            chunk_size: Size of chunks to yield
            on_failure: Optional callback called on verification failure;
                it receives the ChecksumMismatchError instance

        Raises:
            InvalidHashFormatError: If expected_hash is not valid format
        """
        # Reject a malformed expectation before touching the stream.
        if not is_valid_sha256(expected_hash):
            raise InvalidHashFormatError(expected_hash)

        self._expected_hash = expected_hash.lower()
        self._artifact_id = artifact_id
        self._s3_key = s3_key
        self._on_failure = on_failure
        # None until verify() has run; afterwards True or False.
        self._verified: Optional[bool] = None
        self._hashing_wrapper = HashingStreamWrapper(stream, chunk_size)

    @property
    def bytes_read(self) -> int:
        """Total bytes read so far."""
        return self._hashing_wrapper.bytes_read

    @property
    def is_verified(self) -> Optional[bool]:
        """
        Verification status.

        Returns:
            True if verified successfully, False if failed, None if not yet
            complete
        """
        return self._verified

    def __iter__(self) -> Generator[bytes, None, None]:
        """Iterate over chunks, hashing each one via the inner wrapper."""
        yield from self._hashing_wrapper

    def verify(self) -> bool:
        """
        Verify the hash after stream is complete.

        Intended to be called after fully consuming the iterator; any
        chunks still unread are consumed by the hashing wrapper first.

        Returns:
            True if verification passed

        Raises:
            ChecksumMismatchError: If verification failed
        """
        actual_hash = self._hashing_wrapper.get_hash()
        self._verified = actual_hash == self._expected_hash

        if self._verified:
            logger.debug(
                f"Verification passed for {self._artifact_id or 'unknown'}: {actual_hash[:16]}..."
            )
            return True

        mismatch = ChecksumMismatchError(
            expected=self._expected_hash,
            actual=actual_hash,
            artifact_id=self._artifact_id,
            s3_key=self._s3_key,
            size=self._hashing_wrapper.bytes_read,
        )
        # Log the failure
        logger.error(f"Checksum verification FAILED after streaming: {mismatch.to_dict()}")

        # Notify the optional callback; a failing callback must not hide
        # the mismatch, so its exception is logged and dropped.
        if self._on_failure:
            try:
                self._on_failure(mismatch)
            except Exception as exc:
                logger.warning(f"Verification failure callback raised exception: {exc}")

        raise mismatch

    def verify_silent(self) -> bool:
        """
        Verify the hash without raising exception.

        Returns:
            True if verification passed, False otherwise
        """
        try:
            outcome = self.verify()
        except ChecksumMismatchError:
            outcome = False
        return outcome

    def get_actual_hash(self) -> Optional[str]:
        """Get the actual computed hash (only available after iteration)."""
        return self._hashing_wrapper.get_hash_if_complete()