"""
Tests for checksum calculation, verification, and download verification.

This module tests:
- SHA256 hash computation (bytes and streams)
- HashingStreamWrapper incremental hashing
- VerifyingStreamWrapper with verification
- ChecksumMismatchError exception handling
- Download verification API endpoints
"""

import base64
import hashlib
import io
from typing import Generator

import pytest

from app.checksum import (
    compute_sha256,
    compute_sha256_stream,
    verify_checksum,
    verify_checksum_strict,
    is_valid_sha256,
    sha256_to_base64,
    HashingStreamWrapper,
    VerifyingStreamWrapper,
    ChecksumMismatchError,
    ChecksumError,
    InvalidHashFormatError,
    DEFAULT_CHUNK_SIZE,
)


# =============================================================================
# Test Data
# =============================================================================

# Known test vectors
TEST_CONTENT_HELLO = b"Hello, World!"
TEST_HASH_HELLO = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"

TEST_CONTENT_EMPTY = b""
TEST_HASH_EMPTY = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"

# Every possible byte value once; the reference hash comes from hashlib.
TEST_CONTENT_BINARY = bytes(range(256))
TEST_HASH_BINARY = hashlib.sha256(TEST_CONTENT_BINARY).hexdigest()

# Invalid hashes for testing
INVALID_HASH_TOO_SHORT = "abcd1234"
INVALID_HASH_TOO_LONG = "a" * 65
INVALID_HASH_NON_HEX = "zzzz" + "a" * 60
INVALID_HASH_EMPTY = ""


# =============================================================================
# Unit Tests - SHA256 Computation
# =============================================================================


class TestComputeSHA256:
    """Tests for compute_sha256 function."""

    def test_known_content_matches_expected_hash(self):
        """Test SHA256 of known content matches pre-computed hash."""
        result = compute_sha256(TEST_CONTENT_HELLO)
        assert result == TEST_HASH_HELLO

    def test_returns_64_character_hex_string(self):
        """Test result is exactly 64 lowercase hex characters."""
        result = compute_sha256(TEST_CONTENT_HELLO)
        assert len(result) == 64
        assert all(c in "0123456789abcdef" for c in result)

    def test_returns_lowercase_hex(self):
        """Test result is lowercase."""
        result = compute_sha256(TEST_CONTENT_HELLO)
        assert result == result.lower()

    def test_empty_content_returns_empty_hash(self):
        """Test empty bytes returns SHA256 of empty content."""
        result = compute_sha256(TEST_CONTENT_EMPTY)
        assert result == TEST_HASH_EMPTY

    def test_deterministic_same_input_same_output(self):
        """Test same input always produces same output."""
        content = b"test content for determinism"
        result1 = compute_sha256(content)
        result2 = compute_sha256(content)
        assert result1 == result2

    def test_different_content_different_hash(self):
        """Test different content produces different hash."""
        hash1 = compute_sha256(b"content A")
        hash2 = compute_sha256(b"content B")
        assert hash1 != hash2

    def test_single_bit_change_different_hash(self):
        """Test a single-bit change in the input produces a different hash."""
        content1 = b"\x00" * 100
        content2 = b"\x00" * 99 + b"\x01"
        hash1 = compute_sha256(content1)
        hash2 = compute_sha256(content2)
        assert hash1 != hash2

    def test_binary_content(self):
        """Test hashing binary content with all byte values."""
        result = compute_sha256(TEST_CONTENT_BINARY)
        assert result == TEST_HASH_BINARY
        assert len(result) == 64

    def test_large_content(self):
        """Test hashing larger content (1MB)."""
        large_content = b"x" * (1024 * 1024)
        result = compute_sha256(large_content)
        expected = hashlib.sha256(large_content).hexdigest()
        assert result == expected

    def test_none_content_raises_error(self):
        """Test None content raises ChecksumError."""
        with pytest.raises(ChecksumError, match="Cannot compute hash of None"):
            compute_sha256(None)


class TestComputeSHA256Stream:
    """Tests for compute_sha256_stream function."""

    def test_file_like_object(self):
        """Test hashing from file-like object."""
        file_obj = io.BytesIO(TEST_CONTENT_HELLO)
        result = compute_sha256_stream(file_obj)
        assert result == TEST_HASH_HELLO

    def test_iterator(self):
        """Test hashing from iterator of chunks."""

        def chunk_iterator() -> Generator[bytes, None, None]:
            # The two chunks concatenate to TEST_CONTENT_HELLO.
            yield b"Hello, "
            yield b"World!"

        result = compute_sha256_stream(chunk_iterator())
        assert result == TEST_HASH_HELLO

    def test_various_chunk_sizes_same_result(self):
        """Test different chunk sizes produce same hash."""
        content = b"x" * 10000
        expected = hashlib.sha256(content).hexdigest()

        for chunk_size in [1, 10, 100, 1000, 8192]:
            # A fresh stream per iteration: the previous one is exhausted.
            file_obj = io.BytesIO(content)
            result = compute_sha256_stream(file_obj, chunk_size=chunk_size)
            assert result == expected, f"Failed for chunk_size={chunk_size}"

    def test_single_byte_chunks(self):
        """Test with 1-byte chunks (edge case)."""
        content = b"ABC"
        file_obj = io.BytesIO(content)
        result = compute_sha256_stream(file_obj, chunk_size=1)
        expected = hashlib.sha256(content).hexdigest()
        assert result == expected

    def test_empty_stream(self):
        """Test empty stream returns empty content hash."""
        file_obj = io.BytesIO(b"")
        result = compute_sha256_stream(file_obj)
        assert result == TEST_HASH_EMPTY


# =============================================================================
# Unit Tests - Hash Validation
# =============================================================================


class TestIsValidSHA256:
    """Tests for is_valid_sha256 function."""

    def test_valid_hash_lowercase(self):
        """Test valid lowercase hash."""
        assert is_valid_sha256(TEST_HASH_HELLO) is True

    def test_valid_hash_uppercase(self):
        """Test valid uppercase hash."""
        assert is_valid_sha256(TEST_HASH_HELLO.upper()) is True

    def test_valid_hash_mixed_case(self):
        """Test valid mixed case hash."""
        mixed = TEST_HASH_HELLO[:32].upper() + TEST_HASH_HELLO[32:].lower()
        assert is_valid_sha256(mixed) is True

    def test_invalid_too_short(self):
        """Test hash that's too short."""
        assert is_valid_sha256(INVALID_HASH_TOO_SHORT) is False

    def test_invalid_too_long(self):
        """Test hash that's too long."""
        assert is_valid_sha256(INVALID_HASH_TOO_LONG) is False

    def test_invalid_non_hex(self):
        """Test hash with non-hex characters."""
        assert is_valid_sha256(INVALID_HASH_NON_HEX) is False

    def test_invalid_empty(self):
        """Test empty string."""
        assert is_valid_sha256(INVALID_HASH_EMPTY) is False

    def test_invalid_none(self):
        """Test None value."""
        assert is_valid_sha256(None) is False


class TestSHA256ToBase64:
    """Tests for sha256_to_base64 function."""

    def test_converts_to_base64(self):
        """Test conversion to base64."""
        result = sha256_to_base64(TEST_HASH_HELLO)
        # Verify it decodes as base64
        decoded = base64.b64decode(result)
        assert len(decoded) == 32  # SHA256 is 32 bytes

    def test_invalid_hash_raises_error(self):
        """Test invalid hash raises InvalidHashFormatError."""
        with pytest.raises(InvalidHashFormatError):
            sha256_to_base64(INVALID_HASH_TOO_SHORT)


# =============================================================================
# Unit Tests - Verification Functions
# =============================================================================


class TestVerifyChecksum:
    """Tests for verify_checksum function."""

    def test_matching_checksum_returns_true(self):
        """Test matching checksum returns True."""
        result = verify_checksum(TEST_CONTENT_HELLO, TEST_HASH_HELLO)
        assert result is True

    def test_mismatched_checksum_returns_false(self):
        """Test mismatched checksum returns False."""
        wrong_hash = "a" * 64
        result = verify_checksum(TEST_CONTENT_HELLO, wrong_hash)
        assert result is False

    def test_case_insensitive_comparison(self):
        """Test comparison is case-insensitive."""
        result = verify_checksum(TEST_CONTENT_HELLO, TEST_HASH_HELLO.upper())
        assert result is True

    def test_invalid_hash_format_raises_error(self):
        """Test invalid hash format raises error."""
        with pytest.raises(InvalidHashFormatError):
            verify_checksum(TEST_CONTENT_HELLO, INVALID_HASH_TOO_SHORT)


class TestVerifyChecksumStrict:
    """Tests for verify_checksum_strict function."""

    def test_matching_checksum_returns_none(self):
        """Test matching checksum doesn't raise."""
        # Should not raise
        verify_checksum_strict(TEST_CONTENT_HELLO, TEST_HASH_HELLO)

    def test_mismatched_checksum_raises_error(self):
        """Test mismatched checksum raises ChecksumMismatchError."""
        wrong_hash = "a" * 64

        with pytest.raises(ChecksumMismatchError) as exc_info:
            verify_checksum_strict(TEST_CONTENT_HELLO, wrong_hash)

        # Inspect the exception only after the `with` block has captured it.
        error = exc_info.value
        assert error.expected == wrong_hash.lower()
        assert error.actual == TEST_HASH_HELLO
        assert error.size == len(TEST_CONTENT_HELLO)

    def test_error_includes_context(self):
        """Test error includes artifact_id and s3_key context."""
        wrong_hash = "a" * 64

        with pytest.raises(ChecksumMismatchError) as exc_info:
            verify_checksum_strict(
                TEST_CONTENT_HELLO,
                wrong_hash,
                artifact_id="test-artifact-123",
                s3_key="fruits/ab/cd/abcd1234...",
            )

        error = exc_info.value
        assert error.artifact_id == "test-artifact-123"
        assert error.s3_key == "fruits/ab/cd/abcd1234..."


# =============================================================================
# Unit Tests - HashingStreamWrapper
# =============================================================================


class TestHashingStreamWrapper:
    """Tests for HashingStreamWrapper class."""

    def test_computes_correct_hash(self):
        """Test wrapper computes correct hash."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)

        # Consume the stream; the yielded chunks are not needed here.
        list(wrapper)

        # Verify hash
        assert wrapper.get_hash() == TEST_HASH_HELLO

    def test_yields_correct_chunks(self):
        """Test wrapper yields all content."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)

        chunks = list(wrapper)
        content = b"".join(chunks)

        assert content == TEST_CONTENT_HELLO

    def test_tracks_bytes_read(self):
        """Test bytes_read property tracks correctly."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)

        assert wrapper.bytes_read == 0
        list(wrapper)  # Consume
        assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)

    def test_get_hash_before_iteration_consumes_stream(self):
        """Test get_hash() consumes stream if not already done."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)

        # Call get_hash without iterating
        hash_result = wrapper.get_hash()

        assert hash_result == TEST_HASH_HELLO
        assert wrapper.bytes_read == len(TEST_CONTENT_HELLO)

    def test_get_hash_if_complete_before_iteration_returns_none(self):
        """Test get_hash_if_complete returns None before iteration."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)

        assert wrapper.get_hash_if_complete() is None

    def test_get_hash_if_complete_after_iteration_returns_hash(self):
        """Test get_hash_if_complete returns hash after iteration."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)

        list(wrapper)  # Consume
        assert wrapper.get_hash_if_complete() == TEST_HASH_HELLO

    def test_custom_chunk_size(self):
        """Test custom chunk size is respected."""
        content = b"x" * 1000
        stream = io.BytesIO(content)
        wrapper = HashingStreamWrapper(stream, chunk_size=100)

        chunks = list(wrapper)

        # Every chunk except the last must be exactly chunk_size bytes;
        # the last one is excluded because it may be shorter.
        for chunk in chunks[:-1]:
            assert len(chunk) == 100

        # Total content should match
        assert b"".join(chunks) == content

    def test_iterator_interface(self):
        """Test wrapper supports iterator interface."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = HashingStreamWrapper(stream)

        # Should be able to use for loop
        result = b""
        for chunk in wrapper:
            result += chunk

        assert result == TEST_CONTENT_HELLO


# =============================================================================
# Unit Tests - VerifyingStreamWrapper
# =============================================================================


class TestVerifyingStreamWrapper:
    """Tests for VerifyingStreamWrapper class."""

    def test_verify_success(self):
        """Test verification succeeds for matching content."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)

        # Consume stream
        list(wrapper)

        # Verify should succeed
        result = wrapper.verify()
        assert result is True
        assert wrapper.is_verified is True

    def test_verify_failure_raises_error(self):
        """Test verification failure raises ChecksumMismatchError."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrong_hash = "a" * 64
        wrapper = VerifyingStreamWrapper(stream, wrong_hash)

        # Consume stream
        list(wrapper)

        # Verify should fail
        with pytest.raises(ChecksumMismatchError):
            wrapper.verify()

        # Checked outside the `with` block so the assertion actually runs.
        assert wrapper.is_verified is False

    def test_verify_silent_success(self):
        """Test verify_silent returns True on success."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)

        list(wrapper)

        result = wrapper.verify_silent()
        assert result is True

    def test_verify_silent_failure(self):
        """Test verify_silent returns False on failure."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrong_hash = "a" * 64
        wrapper = VerifyingStreamWrapper(stream, wrong_hash)

        list(wrapper)

        result = wrapper.verify_silent()
        assert result is False

    def test_invalid_expected_hash_raises_error(self):
        """Test invalid expected hash raises error at construction."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)

        with pytest.raises(InvalidHashFormatError):
            VerifyingStreamWrapper(stream, INVALID_HASH_TOO_SHORT)

    def test_on_failure_callback(self):
        """Test on_failure callback is called on verification failure."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrong_hash = "a" * 64

        # Records every error object passed to the callback.
        callback_called = []

        def on_failure(error):
            callback_called.append(error)

        wrapper = VerifyingStreamWrapper(stream, wrong_hash, on_failure=on_failure)

        list(wrapper)

        with pytest.raises(ChecksumMismatchError):
            wrapper.verify()

        assert len(callback_called) == 1
        assert isinstance(callback_called[0], ChecksumMismatchError)

    def test_get_actual_hash_after_iteration(self):
        """Test get_actual_hash returns hash after iteration."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrapper = VerifyingStreamWrapper(stream, TEST_HASH_HELLO)

        # Before iteration
        assert wrapper.get_actual_hash() is None

        list(wrapper)

        # After iteration
        assert wrapper.get_actual_hash() == TEST_HASH_HELLO

    def test_includes_context_in_error(self):
        """Test error includes artifact_id and s3_key."""
        stream = io.BytesIO(TEST_CONTENT_HELLO)
        wrong_hash = "a" * 64
        wrapper = VerifyingStreamWrapper(
            stream,
            wrong_hash,
            artifact_id="test-artifact",
            s3_key="test/key",
        )

        list(wrapper)

        with pytest.raises(ChecksumMismatchError) as exc_info:
            wrapper.verify()

        error = exc_info.value
        assert error.artifact_id == "test-artifact"
        assert error.s3_key == "test/key"


# =============================================================================
# Unit Tests - ChecksumMismatchError
# =============================================================================


class TestChecksumMismatchError:
    """Tests for ChecksumMismatchError class."""

    def test_to_dict(self):
        """Test to_dict returns proper dictionary."""
        error = ChecksumMismatchError(
            expected="a" * 64,
            actual="b" * 64,
            artifact_id="test-123",
            s3_key="test/key",
            size=1024,
        )

        result = error.to_dict()

        assert result["error"] == "checksum_mismatch"
        assert result["expected"] == "a" * 64
        assert result["actual"] == "b" * 64
        assert result["artifact_id"] == "test-123"
        assert result["s3_key"] == "test/key"
        assert result["size"] == 1024

    def test_message_format(self):
        """Test error message format."""
        error = ChecksumMismatchError(
            expected="a" * 64,
            actual="b" * 64,
        )

        assert "verification failed" in str(error).lower()
        assert "expected" in str(error).lower()

    def test_custom_message(self):
        """Test custom message is used."""
        error = ChecksumMismatchError(
            expected="a" * 64,
            actual="b" * 64,
            message="Custom error message",
        )

        assert str(error) == "Custom error message"


# =============================================================================
# Corruption Simulation Tests
# =============================================================================


class TestCorruptionDetection:
    """Tests for detecting corrupted content."""

    def test_detect_truncated_content(self):
        """Test detection of truncated content."""
        original = TEST_CONTENT_HELLO
        truncated = original[:-1]  # Remove last byte

        original_hash = compute_sha256(original)
        truncated_hash = compute_sha256(truncated)

        assert original_hash != truncated_hash
        assert verify_checksum(truncated, original_hash) is False

    def test_detect_extra_bytes(self):
        """Test detection of content with extra bytes."""
        original = TEST_CONTENT_HELLO
        extended = original + b"\x00"  # Add null byte

        original_hash = compute_sha256(original)

        assert verify_checksum(extended, original_hash) is False

    def test_detect_single_bit_flip(self):
        """Test detection of single bit flip."""
        original = TEST_CONTENT_HELLO
        # Flip the lowest bit of the first byte
        corrupted = bytes([original[0] ^ 0x01]) + original[1:]

        original_hash = compute_sha256(original)

        assert verify_checksum(corrupted, original_hash) is False

    def test_detect_wrong_content(self):
        """Test detection of completely different content."""
        original = TEST_CONTENT_HELLO
        different = b"Something completely different"

        original_hash = compute_sha256(original)

        assert verify_checksum(different, original_hash) is False

    def test_detect_empty_vs_nonempty(self):
        """Test detection of empty content vs non-empty."""
        original = TEST_CONTENT_HELLO
        empty = b""

        original_hash = compute_sha256(original)

        assert verify_checksum(empty, original_hash) is False

    def test_streaming_detection_of_corruption(self):
        """Test VerifyingStreamWrapper detects corruption."""
        original = b"Original content that will be corrupted"
        original_hash = compute_sha256(original)

        # Corrupt the content
        corrupted = b"Corrupted content that is different"
        stream = io.BytesIO(corrupted)

        wrapper = VerifyingStreamWrapper(stream, original_hash)

        list(wrapper)  # Consume

        with pytest.raises(ChecksumMismatchError):
            wrapper.verify()


# =============================================================================
# Edge Case Tests
# =============================================================================


class TestEdgeCases:
    """Tests for edge cases and boundary conditions."""

    def test_null_bytes_in_content(self):
        """Test content with null bytes."""
        content = b"\x00\x00\x00"
        hash_result = compute_sha256(content)

        assert verify_checksum(content, hash_result) is True

    def test_whitespace_only_content(self):
        """Test content with only whitespace."""
        content = b" \t\n\r "
        hash_result = compute_sha256(content)

        assert verify_checksum(content, hash_result) is True

    def test_large_content_streaming(self):
        """Test streaming verification of large content."""
        # 1MB of content
        large_content = b"x" * (1024 * 1024)
        expected_hash = compute_sha256(large_content)

        stream = io.BytesIO(large_content)
        wrapper = VerifyingStreamWrapper(stream, expected_hash)

        # Consume and verify
        chunks = list(wrapper)

        assert wrapper.verify() is True
        assert b"".join(chunks) == large_content

    def test_unicode_bytes_content(self):
        """Test content with unicode bytes."""
        # Escapes keep the non-ASCII text (two CJK characters and an emoji)
        # intact even if this file is re-encoded.
        content = "Hello, \u4e16\u754c! \U0001f30d".encode("utf-8")
        hash_result = compute_sha256(content)

        assert verify_checksum(content, hash_result) is True

    def test_maximum_chunk_size_larger_than_content(self):
        """Test chunk size larger than content."""
        content = b"small"
        stream = io.BytesIO(content)
        wrapper = HashingStreamWrapper(stream, chunk_size=1024 * 1024)

        chunks = list(wrapper)

        assert len(chunks) == 1
        assert chunks[0] == content
        assert wrapper.get_hash() == compute_sha256(content)