# File: orchard/backend/tests/test_checksum_verification.py
# (repository viewer metadata: 676 lines, 22 KiB, Python)

"""
Tests for checksum calculation, verification, and download verification.
This module tests:
- SHA256 hash computation (bytes and streams)
- HashingStreamWrapper incremental hashing
- VerifyingStreamWrapper with verification
- ChecksumMismatchError exception handling
- Download verification API endpoints
"""
import pytest
import hashlib
import io
from typing import Generator
from app.checksum import (
compute_sha256,
compute_sha256_stream,
verify_checksum,
verify_checksum_strict,
is_valid_sha256,
sha256_to_base64,
HashingStreamWrapper,
VerifyingStreamWrapper,
ChecksumMismatchError,
ChecksumError,
InvalidHashFormatError,
DEFAULT_CHUNK_SIZE,
)
# =============================================================================
# Test Data
# =============================================================================
# Known test vectors: content paired with its expected SHA-256 hex digest.
# The HELLO and EMPTY digests are hard-coded literals, so those tests compare
# against a fixed value rather than a second call into hashlib.
TEST_CONTENT_HELLO = b"Hello, World!"
TEST_HASH_HELLO = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
TEST_CONTENT_EMPTY = b""
TEST_HASH_EMPTY = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
# Every possible byte value (0x00-0xff) exactly once; the expected digest is
# derived with hashlib at import time instead of being hard-coded.
TEST_CONTENT_BINARY = bytes(range(256))
TEST_HASH_BINARY = hashlib.sha256(TEST_CONTENT_BINARY).hexdigest()
# Invalid hashes for testing: each one breaks a different property of a
# well-formed digest (64 hex characters).
INVALID_HASH_TOO_SHORT = "abcd1234"  # only 8 characters
INVALID_HASH_TOO_LONG = "a" * 65  # one character over the 64-character length
INVALID_HASH_NON_HEX = "zzzz" + "a" * 60  # 64 characters, but "z" is not a hex digit
INVALID_HASH_EMPTY = ""
# =============================================================================
# Unit Tests - SHA256 Computation
# =============================================================================
class TestComputeSHA256:
    """Exercise compute_sha256 on in-memory byte strings."""

    def test_known_content_matches_expected_hash(self):
        """The digest of the known vector equals its pre-computed hash."""
        assert compute_sha256(TEST_CONTENT_HELLO) == TEST_HASH_HELLO

    def test_returns_64_character_hex_string(self):
        """The digest is 64 characters long and uses only lowercase hex digits."""
        digest = compute_sha256(TEST_CONTENT_HELLO)
        hex_alphabet = "0123456789abcdef"
        assert len(digest) == 64
        assert not any(ch not in hex_alphabet for ch in digest)

    def test_returns_lowercase_hex(self):
        """Lower-casing the digest leaves it unchanged."""
        digest = compute_sha256(TEST_CONTENT_HELLO)
        assert digest.lower() == digest

    def test_empty_content_returns_empty_hash(self):
        """Zero-length input hashes to the well-known empty-content digest."""
        assert compute_sha256(TEST_CONTENT_EMPTY) == TEST_HASH_EMPTY

    def test_deterministic_same_input_same_output(self):
        """Hashing the same bytes twice yields the same digest."""
        payload = b"test content for determinism"
        first_pass = compute_sha256(payload)
        second_pass = compute_sha256(payload)
        assert second_pass == first_pass

    def test_different_content_different_hash(self):
        """Two distinct inputs do not share a digest."""
        digest_a, digest_b = (
            compute_sha256(payload) for payload in (b"content A", b"content B")
        )
        assert digest_a != digest_b

    def test_single_bit_change_different_hash(self):
        """Changing the lowest bit of the final byte changes the digest."""
        zeros = b"\x00" * 100
        # Same 100-byte buffer, with the last byte replaced by 0x01.
        flipped = zeros[:99] + b"\x01"
        assert compute_sha256(zeros) != compute_sha256(flipped)

    def test_binary_content(self):
        """Input containing every byte value hashes to the hashlib digest."""
        digest = compute_sha256(TEST_CONTENT_BINARY)
        assert digest == TEST_HASH_BINARY
        assert len(digest) == 64

    def test_large_content(self):
        """A 1 MiB payload hashes to the same digest hashlib produces."""
        payload = b"x" * (1 << 20)
        reference = hashlib.sha256(payload).hexdigest()
        assert compute_sha256(payload) == reference

    def test_none_content_raises_error(self):
        """Passing None is rejected with a ChecksumError."""
        with pytest.raises(ChecksumError, match="Cannot compute hash of None"):
            compute_sha256(None)
class TestComputeSHA256Stream:
    """Exercise compute_sha256_stream on file-like objects and chunk iterators."""

    def test_file_like_object(self):
        """A BytesIO source hashes to the known digest."""
        digest = compute_sha256_stream(io.BytesIO(TEST_CONTENT_HELLO))
        assert digest == TEST_HASH_HELLO

    def test_iterator(self):
        """A generator yielding the content in two pieces hashes to the known digest."""
        pieces = (piece for piece in (b"Hello, ", b"World!"))
        assert compute_sha256_stream(pieces) == TEST_HASH_HELLO

    def test_various_chunk_sizes_same_result(self):
        """The digest does not depend on the chunk size used for reading."""
        content = b"x" * 10000
        expected = hashlib.sha256(content).hexdigest()
        for chunk_size in (1, 10, 100, 1000, 8192):
            # A fresh stream per iteration: the previous one is exhausted.
            result = compute_sha256_stream(io.BytesIO(content), chunk_size=chunk_size)
            assert result == expected, f"Failed for chunk_size={chunk_size}"

    def test_single_byte_chunks(self):
        """Reading one byte at a time still produces the hashlib digest."""
        payload = b"ABC"
        reference = hashlib.sha256(payload).hexdigest()
        digest = compute_sha256_stream(io.BytesIO(payload), chunk_size=1)
        assert digest == reference

    def test_empty_stream(self):
        """An empty stream hashes to the empty-content digest."""
        assert compute_sha256_stream(io.BytesIO(b"")) == TEST_HASH_EMPTY
# =============================================================================
# Unit Tests - Hash Validation
# =============================================================================
class TestIsValidSHA256:
    """Exercise is_valid_sha256 with well-formed and malformed inputs."""

    def test_valid_hash_lowercase(self):
        """A lowercase 64-character hex digest is accepted."""
        verdict = is_valid_sha256(TEST_HASH_HELLO)
        assert verdict is True

    def test_valid_hash_uppercase(self):
        """An all-uppercase digest is accepted."""
        shouted = TEST_HASH_HELLO.upper()
        assert is_valid_sha256(shouted) is True

    def test_valid_hash_mixed_case(self):
        """A digest with an uppercase first half and lowercase second half is accepted."""
        upper_half = TEST_HASH_HELLO[:32].upper()
        lower_half = TEST_HASH_HELLO[32:].lower()
        verdict = is_valid_sha256(upper_half + lower_half)
        assert verdict is True

    def test_invalid_too_short(self):
        """A string shorter than a full digest is rejected."""
        verdict = is_valid_sha256(INVALID_HASH_TOO_SHORT)
        assert verdict is False

    def test_invalid_too_long(self):
        """A string longer than a full digest is rejected."""
        verdict = is_valid_sha256(INVALID_HASH_TOO_LONG)
        assert verdict is False

    def test_invalid_non_hex(self):
        """A string containing non-hex characters is rejected."""
        verdict = is_valid_sha256(INVALID_HASH_NON_HEX)
        assert verdict is False

    def test_invalid_empty(self):
        """The empty string is rejected."""
        verdict = is_valid_sha256(INVALID_HASH_EMPTY)
        assert verdict is False

    def test_invalid_none(self):
        """None is rejected with False."""
        verdict = is_valid_sha256(None)
        assert verdict is False
class TestSHA256ToBase64:
    """Tests for sha256_to_base64 function."""

    def test_converts_to_base64(self):
        """Test conversion to base64 round-trips to the original digest.

        Checking only the decoded length would accept any 32-byte payload,
        so the decoded bytes are also compared against the hex digest that
        was passed in.
        """
        # Imported locally because base64 is only needed by this test.
        import base64

        result = sha256_to_base64(TEST_HASH_HELLO)
        decoded = base64.b64decode(result)
        assert len(decoded) == 32  # SHA256 is 32 bytes
        # NOTE(review): assumes the function encodes the raw digest bytes
        # with the standard base64 alphabet — confirm against app.checksum.
        assert decoded.hex() == TEST_HASH_HELLO

    def test_invalid_hash_raises_error(self):
        """Test invalid hash raises InvalidHashFormatError."""
        with pytest.raises(InvalidHashFormatError):
            sha256_to_base64(INVALID_HASH_TOO_SHORT)
# =============================================================================
# Unit Tests - Verification Functions
# =============================================================================
class TestVerifyChecksum:
    """Exercise the boolean-returning verify_checksum helper."""

    def test_matching_checksum_returns_true(self):
        """Content checked against its own digest yields True."""
        assert verify_checksum(TEST_CONTENT_HELLO, TEST_HASH_HELLO) is True

    def test_mismatched_checksum_returns_false(self):
        """Content checked against a well-formed but wrong digest yields False."""
        bogus_digest = "a" * 64
        outcome = verify_checksum(TEST_CONTENT_HELLO, bogus_digest)
        assert outcome is False

    def test_case_insensitive_comparison(self):
        """An uppercase expected digest still matches."""
        shouted = TEST_HASH_HELLO.upper()
        assert verify_checksum(TEST_CONTENT_HELLO, shouted) is True

    def test_invalid_hash_format_raises_error(self):
        """A malformed expected digest raises InvalidHashFormatError."""
        with pytest.raises(InvalidHashFormatError):
            verify_checksum(TEST_CONTENT_HELLO, INVALID_HASH_TOO_SHORT)
class TestVerifyChecksumStrict:
    """Exercise the raising variant, verify_checksum_strict."""

    def test_matching_checksum_returns_none(self):
        """A matching digest completes without raising."""
        # The test passes simply by this call not raising.
        verify_checksum_strict(TEST_CONTENT_HELLO, TEST_HASH_HELLO)

    def test_mismatched_checksum_raises_error(self):
        """A wrong digest raises ChecksumMismatchError carrying both hashes and the size."""
        bogus_digest = "a" * 64
        with pytest.raises(ChecksumMismatchError) as caught:
            verify_checksum_strict(TEST_CONTENT_HELLO, bogus_digest)

        mismatch = caught.value
        assert mismatch.expected == bogus_digest.lower()
        assert mismatch.actual == TEST_HASH_HELLO
        assert mismatch.size == len(TEST_CONTENT_HELLO)

    def test_error_includes_context(self):
        """artifact_id and s3_key passed by the caller are exposed on the error."""
        context = {
            "artifact_id": "test-artifact-123",
            "s3_key": "fruits/ab/cd/abcd1234...",
        }
        with pytest.raises(ChecksumMismatchError) as caught:
            verify_checksum_strict(TEST_CONTENT_HELLO, "a" * 64, **context)

        mismatch = caught.value
        assert mismatch.artifact_id == context["artifact_id"]
        assert mismatch.s3_key == context["s3_key"]
# =============================================================================
# Unit Tests - HashingStreamWrapper
# =============================================================================
class TestHashingStreamWrapper:
    """Tests for HashingStreamWrapper class."""

    def test_computes_correct_hash(self):
        """Test wrapper computes correct hash once the stream is consumed."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)
        # Consume the stream; the chunks themselves are not needed here.
        list(wrapper)
        # Verify hash
        assert wrapper.get_hash() == TEST_HASH_HELLO

    def test_yields_correct_chunks(self):
        """Test wrapper yields all content."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)
        chunks = list(wrapper)
        content = b"".join(chunks)
        assert content == TEST_CONTENT_HELLO

    def test_tracks_bytes_read(self):
        """Test bytes_read property tracks correctly."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)
        assert wrapper.bytes_read == 0
        list(wrapper)  # Consume
        assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)

    def test_get_hash_before_iteration_consumes_stream(self):
        """Test get_hash() consumes stream if not already done."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)
        # Call get_hash without iterating
        hash_result = wrapper.get_hash()
        assert hash_result == TEST_HASH_HELLO
        assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)

    def test_get_hash_if_complete_before_iteration_returns_none(self):
        """Test get_hash_if_complete returns None before iteration."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)
        assert wrapper.get_hash_if_complete() is None

    def test_get_hash_if_complete_after_iteration_returns_hash(self):
        """Test get_hash_if_complete returns hash after iteration."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)
        list(wrapper)  # Consume
        assert wrapper.get_hash_if_complete() == TEST_HASH_HELLO

    def test_custom_chunk_size(self):
        """Test custom chunk size is respected.

        1000 bytes are read with chunk_size=100. No chunk may exceed 100
        bytes, every chunk but the last must be exactly 100 bytes, and the
        chunks must reassemble into the original content.
        """
        content = b"x" * 1000
        stream = io.BytesIO(content)
        wrapper = HashingStreamWrapper(stream, chunk_size=100)
        chunks = list(wrapper)
        # Upper bound on every chunk. Without it, a wrapper that ignored
        # chunk_size and yielded a single 1000-byte chunk would pass, because
        # the "all but last" check below would have nothing to inspect.
        assert all(len(chunk) <= 100 for chunk in chunks)
        # Every chunk except possibly the last must be completely filled.
        assert all(len(chunk) == 100 for chunk in chunks[:-1])
        # Total content should match
        assert b"".join(chunks) == content

    def test_iterator_interface(self):
        """Test wrapper supports iterator interface."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)
        # Should be able to use for loop; collect chunks and join once
        # instead of growing a bytes object with repeated +=.
        collected = []
        for chunk in wrapper:
            collected.append(chunk)
        assert b"".join(collected) == TEST_CONTENT_HELLO
# =============================================================================
# Unit Tests - VerifyingStreamWrapper
# =============================================================================
class TestVerifyingStreamWrapper:
    """Exercise VerifyingStreamWrapper: verification outcome, callbacks and error context."""

    def test_verify_success(self):
        """verify() returns True and is_verified is set when the digest matches."""
        wrapper = VerifyingStreamWrapper(io.BytesIO(TEST_CONTENT_HELLO), TEST_HASH_HELLO)
        # Drain the stream before asking for a verdict.
        list(wrapper)
        outcome = wrapper.verify()
        assert outcome is True
        assert wrapper.is_verified is True

    def test_verify_failure_raises_error(self):
        """verify() raises ChecksumMismatchError and is_verified stays False on mismatch."""
        bogus_digest = "a" * 64
        wrapper = VerifyingStreamWrapper(io.BytesIO(TEST_CONTENT_HELLO), bogus_digest)
        # Drain the stream before asking for a verdict.
        list(wrapper)
        with pytest.raises(ChecksumMismatchError):
            wrapper.verify()
        assert wrapper.is_verified is False

    def test_verify_silent_success(self):
        """verify_silent() reports a match as True."""
        wrapper = VerifyingStreamWrapper(io.BytesIO(TEST_CONTENT_HELLO), TEST_HASH_HELLO)
        list(wrapper)
        assert wrapper.verify_silent() is True

    def test_verify_silent_failure(self):
        """verify_silent() reports a mismatch as False instead of raising."""
        bogus_digest = "a" * 64
        wrapper = VerifyingStreamWrapper(io.BytesIO(TEST_CONTENT_HELLO), bogus_digest)
        list(wrapper)
        assert wrapper.verify_silent() is False

    def test_invalid_expected_hash_raises_error(self):
        """A malformed expected digest is rejected when the wrapper is constructed."""
        source = io.BytesIO(TEST_CONTENT_HELLO)
        with pytest.raises(InvalidHashFormatError):
            VerifyingStreamWrapper(source, INVALID_HASH_TOO_SHORT)

    def test_on_failure_callback(self):
        """The on_failure callback receives the mismatch error exactly once."""
        received = []
        wrapper = VerifyingStreamWrapper(
            io.BytesIO(TEST_CONTENT_HELLO),
            "a" * 64,
            on_failure=lambda error: received.append(error),
        )
        list(wrapper)
        with pytest.raises(ChecksumMismatchError):
            wrapper.verify()
        assert len(received) == 1
        assert isinstance(received[0], ChecksumMismatchError)

    def test_get_actual_hash_after_iteration(self):
        """get_actual_hash() is None until the stream is consumed, then the real digest."""
        wrapper = VerifyingStreamWrapper(io.BytesIO(TEST_CONTENT_HELLO), TEST_HASH_HELLO)
        hash_before = wrapper.get_actual_hash()
        assert hash_before is None
        list(wrapper)
        hash_after = wrapper.get_actual_hash()
        assert hash_after == TEST_HASH_HELLO

    def test_includes_context_in_error(self):
        """artifact_id and s3_key given at construction are exposed on the raised error."""
        context = {"artifact_id": "test-artifact", "s3_key": "test/key"}
        wrapper = VerifyingStreamWrapper(
            io.BytesIO(TEST_CONTENT_HELLO),
            "a" * 64,
            **context,
        )
        list(wrapper)
        with pytest.raises(ChecksumMismatchError) as caught:
            wrapper.verify()

        mismatch = caught.value
        assert mismatch.artifact_id == context["artifact_id"]
        assert mismatch.s3_key == context["s3_key"]
# =============================================================================
# Unit Tests - ChecksumMismatchError
# =============================================================================
class TestChecksumMismatchError:
    """Exercise ChecksumMismatchError serialization and message handling."""

    def test_to_dict(self):
        """to_dict() exposes the error code plus every constructor field."""
        fields = {
            "expected": "a" * 64,
            "actual": "b" * 64,
            "artifact_id": "test-123",
            "s3_key": "test/key",
            "size": 1024,
        }
        payload = ChecksumMismatchError(**fields).to_dict()

        assert payload["error"] == "checksum_mismatch"
        # Each constructor argument must come back under the same key.
        for key in ("expected", "actual", "artifact_id", "s3_key", "size"):
            assert payload[key] == fields[key]

    def test_message_format(self):
        """The default message mentions the failure and the expected value."""
        mismatch = ChecksumMismatchError(expected="a" * 64, actual="b" * 64)
        text = str(mismatch).lower()
        assert "verification failed" in text
        assert "expected" in text

    def test_custom_message(self):
        """An explicit message replaces the default one verbatim."""
        mismatch = ChecksumMismatchError(
            expected="a" * 64,
            actual="b" * 64,
            message="Custom error message",
        )
        assert str(mismatch) == "Custom error message"
# =============================================================================
# Corruption Simulation Tests
# =============================================================================
class TestCorruptionDetection:
    """Simulate several kinds of corruption and check that each is detected."""

    def test_detect_truncated_content(self):
        """Dropping the final byte changes the digest and fails verification."""
        intact = TEST_CONTENT_HELLO
        clipped = intact[:-1]
        intact_digest = compute_sha256(intact)
        assert compute_sha256(clipped) != intact_digest
        assert verify_checksum(clipped, intact_digest) is False

    def test_detect_extra_bytes(self):
        """Appending a null byte fails verification against the original digest."""
        intact_digest = compute_sha256(TEST_CONTENT_HELLO)
        padded = TEST_CONTENT_HELLO + b"\x00"
        assert verify_checksum(padded, intact_digest) is False

    def test_detect_single_bit_flip(self):
        """Flipping the lowest bit of the first byte fails verification."""
        intact_digest = compute_sha256(TEST_CONTENT_HELLO)
        scratch = bytearray(TEST_CONTENT_HELLO)
        scratch[0] ^= 0x01
        assert verify_checksum(bytes(scratch), intact_digest) is False

    def test_detect_wrong_content(self):
        """Unrelated content fails verification against the original digest."""
        intact_digest = compute_sha256(TEST_CONTENT_HELLO)
        outcome = verify_checksum(b"Something completely different", intact_digest)
        assert outcome is False

    def test_detect_empty_vs_nonempty(self):
        """Empty content fails verification against a non-empty content digest."""
        intact_digest = compute_sha256(TEST_CONTENT_HELLO)
        assert verify_checksum(b"", intact_digest) is False

    def test_streaming_detection_of_corruption(self):
        """VerifyingStreamWrapper raises when the streamed bytes differ from the hashed ones."""
        intact_digest = compute_sha256(b"Original content that will be corrupted")
        # Stream different bytes than the ones the digest was computed from.
        tampered = io.BytesIO(b"Corrupted content that is different")
        wrapper = VerifyingStreamWrapper(tampered, intact_digest)
        list(wrapper)
        with pytest.raises(ChecksumMismatchError):
            wrapper.verify()
# =============================================================================
# Edge Case Tests
# =============================================================================
class TestEdgeCases:
    """Boundary conditions for hashing, verification and the stream wrappers."""

    def test_null_bytes_in_content(self):
        """Content made only of null bytes verifies against its own digest."""
        payload = b"\x00\x00\x00"
        assert verify_checksum(payload, compute_sha256(payload)) is True

    def test_whitespace_only_content(self):
        """Content made only of whitespace verifies against its own digest."""
        payload = b" \t\n\r "
        assert verify_checksum(payload, compute_sha256(payload)) is True

    def test_large_content_streaming(self):
        """A 1 MiB stream verifies and is passed through unchanged."""
        payload = b"x" * (1 << 20)
        wrapper = VerifyingStreamWrapper(io.BytesIO(payload), compute_sha256(payload))
        # Drain first: verification is only asked for after full consumption.
        pieces = list(wrapper)
        assert wrapper.verify() is True
        assert b"".join(pieces) == payload

    def test_unicode_bytes_content(self):
        """UTF-8 encoded non-ASCII text verifies against its own digest."""
        payload = "Hello, 世界! 🌍".encode("utf-8")
        assert verify_checksum(payload, compute_sha256(payload)) is True

    def test_maximum_chunk_size_larger_than_content(self):
        """A chunk size above the content length yields the content as one chunk."""
        payload = b"small"
        wrapper = HashingStreamWrapper(io.BytesIO(payload), chunk_size=1024 * 1024)
        pieces = list(wrapper)
        assert len(pieces) == 1
        (only_piece,) = pieces
        assert only_piece == payload
        assert wrapper.get_hash() == compute_sha256(payload)