import hashlib
import logging
from typing import BinaryIO, Tuple, Optional, Dict, Any, Generator, NamedTuple

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

from .config import get_settings

settings = get_settings()
logger = logging.getLogger(__name__)

# Threshold for multipart upload (100MB)
MULTIPART_THRESHOLD = 100 * 1024 * 1024
# Chunk size for multipart upload (10MB)
MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
# Chunk size for streaming hash computation
HASH_CHUNK_SIZE = 8 * 1024 * 1024


class StorageResult(NamedTuple):
    """Result of storing a file with all computed checksums"""

    sha256: str
    size: int
    s3_key: str
    md5: Optional[str] = None
    sha1: Optional[str] = None
    s3_etag: Optional[str] = None


class S3Storage:
    def __init__(self):
        config = Config(s3={"addressing_style": "path"} if settings.s3_use_path_style else {})
        self.client = boto3.client(
            "s3",
            endpoint_url=settings.s3_endpoint if settings.s3_endpoint else None,
            region_name=settings.s3_region,
            aws_access_key_id=settings.s3_access_key_id,
            aws_secret_access_key=settings.s3_secret_access_key,
            config=config,
        )
        self.bucket = settings.s3_bucket
        # Store active multipart uploads for resumable support
        self._active_uploads: Dict[str, Dict[str, Any]] = {}

    def store(self, file: BinaryIO, content_length: Optional[int] = None) -> StorageResult:
        """
        Store a file and return StorageResult with all checksums.
        Content-addressable: if the file already exists, just return the hash.
        Uses multipart upload for files larger than MULTIPART_THRESHOLD.
        """
        # For small files or unknown size, use the simple approach
        if content_length is None or content_length < MULTIPART_THRESHOLD:
            return self._store_simple(file)
        else:
            return self._store_multipart(file, content_length)

    def _store_simple(self, file: BinaryIO) -> StorageResult:
        """Store a small file using simple put_object"""
        # Read file and compute all hashes
        content = file.read()
        sha256_hash = hashlib.sha256(content).hexdigest()
        md5_hash = hashlib.md5(content).hexdigest()
        sha1_hash = hashlib.sha1(content).hexdigest()
        size = len(content)

        # Check if already exists
        s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"
        s3_etag = None
        if not self._exists(s3_key):
            response = self.client.put_object(
                Bucket=self.bucket,
                Key=s3_key,
                Body=content,
            )
            s3_etag = response.get("ETag", "").strip('"')
        else:
            # Get existing ETag
            obj_info = self.get_object_info(s3_key)
            if obj_info:
                s3_etag = obj_info.get("etag", "").strip('"')

        return StorageResult(
            sha256=sha256_hash,
            size=size,
            s3_key=s3_key,
            md5=md5_hash,
            sha1=sha1_hash,
            s3_etag=s3_etag,
        )
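
    # Illustrative usage sketch (hypothetical caller; "example.bin" and the use
    # of os.path.getsize are not part of this module). Passing content_length
    # lets store() route large files to the multipart path below:
    #
    #     storage = get_storage()
    #     with open("example.bin", "rb") as fh:
    #         result = storage.store(fh, content_length=os.path.getsize("example.bin"))
    #     print(result.sha256, result.s3_key)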

    def _store_multipart(self, file: BinaryIO, content_length: int) -> StorageResult:
        """Store a large file using S3 multipart upload with streaming hash computation"""
        # First pass: compute all hashes by streaming through file
        sha256_hasher = hashlib.sha256()
        md5_hasher = hashlib.md5()
        sha1_hasher = hashlib.sha1()
        size = 0

        # Read file in chunks to compute hashes
        while True:
            chunk = file.read(HASH_CHUNK_SIZE)
            if not chunk:
                break
            sha256_hasher.update(chunk)
            md5_hasher.update(chunk)
            sha1_hasher.update(chunk)
            size += len(chunk)

        sha256_hash = sha256_hasher.hexdigest()
        md5_hash = md5_hasher.hexdigest()
        sha1_hash = sha1_hasher.hexdigest()

        s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"

        # Check if already exists (deduplication)
        if self._exists(s3_key):
            obj_info = self.get_object_info(s3_key)
            s3_etag = obj_info.get("etag", "").strip('"') if obj_info else None
            return StorageResult(
                sha256=sha256_hash,
                size=size,
                s3_key=s3_key,
                md5=md5_hash,
                sha1=sha1_hash,
                s3_etag=s3_etag,
            )

        # Seek back to start for upload
        file.seek(0)

        # Start multipart upload
        mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
        upload_id = mpu["UploadId"]

        try:
            parts = []
            part_number = 1
            while True:
                chunk = file.read(MULTIPART_CHUNK_SIZE)
                if not chunk:
                    break
                response = self.client.upload_part(
                    Bucket=self.bucket,
                    Key=s3_key,
                    UploadId=upload_id,
                    PartNumber=part_number,
                    Body=chunk,
                )
                parts.append({
                    "PartNumber": part_number,
                    "ETag": response["ETag"],
                })
                part_number += 1

            # Complete multipart upload
            complete_response = self.client.complete_multipart_upload(
                Bucket=self.bucket,
                Key=s3_key,
                UploadId=upload_id,
                MultipartUpload={"Parts": parts},
            )
            s3_etag = complete_response.get("ETag", "").strip('"')

            return StorageResult(
                sha256=sha256_hash,
                size=size,
                s3_key=s3_key,
                md5=md5_hash,
                sha1=sha1_hash,
                s3_etag=s3_etag,
            )
        except Exception as e:
            # Abort multipart upload on failure
            logger.error(f"Multipart upload failed: {e}")
            self.client.abort_multipart_upload(
                Bucket=self.bucket,
                Key=s3_key,
                UploadId=upload_id,
            )
            raise

    def store_streaming(self, chunks: Generator[bytes, None, None]) -> StorageResult:
        """
        Store a file from a stream of chunks.
        First accumulates to compute hash, then uploads.
        For truly large files, consider using initiate_resumable_upload instead.
        """
        # Accumulate chunks and compute all hashes
        sha256_hasher = hashlib.sha256()
        md5_hasher = hashlib.md5()
        sha1_hasher = hashlib.sha1()
        all_chunks = []
        size = 0

        for chunk in chunks:
            sha256_hasher.update(chunk)
            md5_hasher.update(chunk)
            sha1_hasher.update(chunk)
            all_chunks.append(chunk)
            size += len(chunk)

        sha256_hash = sha256_hasher.hexdigest()
        md5_hash = md5_hasher.hexdigest()
        sha1_hash = sha1_hasher.hexdigest()

        s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"
        s3_etag = None

        # Check if already exists
        if self._exists(s3_key):
            obj_info = self.get_object_info(s3_key)
            s3_etag = obj_info.get("etag", "").strip('"') if obj_info else None
            return StorageResult(
                sha256=sha256_hash,
                size=size,
                s3_key=s3_key,
                md5=md5_hash,
                sha1=sha1_hash,
                s3_etag=s3_etag,
            )

        # Upload based on size
        if size < MULTIPART_THRESHOLD:
            content = b"".join(all_chunks)
            response = self.client.put_object(Bucket=self.bucket, Key=s3_key, Body=content)
            s3_etag = response.get("ETag", "").strip('"')
        else:
            # Use multipart for large files
            mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
            upload_id = mpu["UploadId"]
            try:
                parts = []
                part_number = 1
                buffer = b""
                for chunk in all_chunks:
                    buffer += chunk
                    while len(buffer) >= MULTIPART_CHUNK_SIZE:
                        part_data = buffer[:MULTIPART_CHUNK_SIZE]
                        buffer = buffer[MULTIPART_CHUNK_SIZE:]
                        response = self.client.upload_part(
                            Bucket=self.bucket,
                            Key=s3_key,
                            UploadId=upload_id,
                            PartNumber=part_number,
                            Body=part_data,
                        )
                        parts.append({
                            "PartNumber": part_number,
                            "ETag": response["ETag"],
                        })
                        part_number += 1

                # Upload remaining buffer
                if buffer:
                    response = self.client.upload_part(
                        Bucket=self.bucket,
                        Key=s3_key,
                        UploadId=upload_id,
                        PartNumber=part_number,
                        Body=buffer,
                    )
                    parts.append({
                        "PartNumber": part_number,
                        "ETag": response["ETag"],
                    })

                complete_response = self.client.complete_multipart_upload(
                    Bucket=self.bucket,
                    Key=s3_key,
                    UploadId=upload_id,
                    MultipartUpload={"Parts": parts},
                )
                s3_etag = complete_response.get("ETag", "").strip('"')
            except Exception as e:
                logger.error(f"Streaming multipart upload failed: {e}")
                self.client.abort_multipart_upload(
                    Bucket=self.bucket,
                    Key=s3_key,
                    UploadId=upload_id,
                )
                raise

        return StorageResult(
            sha256=sha256_hash,
            size=size,
            s3_key=s3_key,
            md5=md5_hash,
            sha1=sha1_hash,
            s3_etag=s3_etag,
        )
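
    # Note: _store_multipart makes two passes over the data (hash pass, then
    # upload pass), so its `file` argument must be seekable, while
    # store_streaming buffers every chunk in memory before uploading. A minimal
    # sketch of feeding store_streaming from an open binary file (the helper
    # generator is hypothetical, not part of this module):
    #
    #     def iter_chunks(fh, chunk_size=HASH_CHUNK_SIZE):
    #         while True:
    #             chunk = fh.read(chunk_size)
    #             if not chunk:
    #                 break
    #             yield chunk
    #
    #     result = get_storage().store_streaming(iter_chunks(fh))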

    def initiate_resumable_upload(self, expected_hash: str) -> Dict[str, Any]:
        """
        Initiate a resumable upload session.
        Returns upload session info including upload_id.
        """
        s3_key = f"fruits/{expected_hash[:2]}/{expected_hash[2:4]}/{expected_hash}"

        # Check if already exists
        if self._exists(s3_key):
            return {
                "upload_id": None,
                "s3_key": s3_key,
                "already_exists": True,
                "parts": [],
            }

        mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
        upload_id = mpu["UploadId"]

        session = {
            "upload_id": upload_id,
            "s3_key": s3_key,
            "already_exists": False,
            "parts": [],
            "expected_hash": expected_hash,
        }
        self._active_uploads[upload_id] = session
        return session

    def upload_part(self, upload_id: str, part_number: int, data: bytes) -> Dict[str, Any]:
        """
        Upload a part for a resumable upload.
        Returns part info including ETag.
        """
        session = self._active_uploads.get(upload_id)
        if not session:
            raise ValueError(f"Unknown upload session: {upload_id}")

        response = self.client.upload_part(
            Bucket=self.bucket,
            Key=session["s3_key"],
            UploadId=upload_id,
            PartNumber=part_number,
            Body=data,
        )
        part_info = {
            "PartNumber": part_number,
            "ETag": response["ETag"],
        }
        session["parts"].append(part_info)
        return part_info

    def complete_resumable_upload(self, upload_id: str) -> Tuple[str, str]:
        """
        Complete a resumable upload.
        Returns (sha256_hash, s3_key).
        """
        session = self._active_uploads.get(upload_id)
        if not session:
            raise ValueError(f"Unknown upload session: {upload_id}")

        # Sort parts by part number
        sorted_parts = sorted(session["parts"], key=lambda x: x["PartNumber"])

        self.client.complete_multipart_upload(
            Bucket=self.bucket,
            Key=session["s3_key"],
            UploadId=upload_id,
            MultipartUpload={"Parts": sorted_parts},
        )

        # Clean up session
        del self._active_uploads[upload_id]
        return session["expected_hash"], session["s3_key"]

    def abort_resumable_upload(self, upload_id: str):
        """Abort a resumable upload"""
        session = self._active_uploads.get(upload_id)
        if session:
            self.client.abort_multipart_upload(
                Bucket=self.bucket,
                Key=session["s3_key"],
                UploadId=upload_id,
            )
            del self._active_uploads[upload_id]

    def list_upload_parts(self, upload_id: str) -> list:
        """List uploaded parts for a resumable upload (for resume support)"""
        session = self._active_uploads.get(upload_id)
        if not session:
            raise ValueError(f"Unknown upload session: {upload_id}")

        response = self.client.list_parts(
            Bucket=self.bucket,
            Key=session["s3_key"],
            UploadId=upload_id,
        )
        return response.get("Parts", [])
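
    # Illustrative resumable flow (client-side sketch; the caller is assumed to
    # know the SHA-256 of the full payload up front, since the S3 key is
    # derived from expected_hash before any bytes are sent, and `chunks` is a
    # hypothetical iterable of byte strings). Note that S3 requires every part
    # except the last to be at least 5 MiB:
    #
    #     storage = get_storage()
    #     session = storage.initiate_resumable_upload(expected_hash)
    #     if not session["already_exists"]:
    #         for number, chunk in enumerate(chunks, start=1):
    #             storage.upload_part(session["upload_id"], number, chunk)
    #         sha256, s3_key = storage.complete_resumable_upload(session["upload_id"])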

    def get(self, s3_key: str) -> bytes:
        """Retrieve a file by its S3 key"""
        response = self.client.get_object(Bucket=self.bucket, Key=s3_key)
        return response["Body"].read()

    def get_stream(self, s3_key: str, range_header: Optional[str] = None):
        """
        Get a streaming response for a file.
        Supports range requests for partial downloads.
        Returns (stream, content_length, content_range).
        """
        kwargs = {"Bucket": self.bucket, "Key": s3_key}
        if range_header:
            kwargs["Range"] = range_header

        response = self.client.get_object(**kwargs)
        content_length = response.get("ContentLength", 0)
        content_range = response.get("ContentRange")

        return response["Body"], content_length, content_range

    def get_object_info(self, s3_key: str) -> Optional[Dict[str, Any]]:
        """Get object metadata without downloading content"""
        try:
            response = self.client.head_object(Bucket=self.bucket, Key=s3_key)
            return {
                "size": response.get("ContentLength", 0),
                "content_type": response.get("ContentType"),
                "last_modified": response.get("LastModified"),
                "etag": response.get("ETag"),
            }
        except ClientError:
            return None

    def _exists(self, s3_key: str) -> bool:
        """Check if an object exists"""
        try:
            self.client.head_object(Bucket=self.bucket, Key=s3_key)
            return True
        except ClientError:
            return False

    def delete(self, s3_key: str) -> bool:
        """Delete an object"""
        try:
            self.client.delete_object(Bucket=self.bucket, Key=s3_key)
            return True
        except ClientError:
            return False


# Singleton instance
_storage = None


def get_storage() -> S3Storage:
    global _storage
    if _storage is None:
        _storage = S3Storage()
    return _storage
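

# Illustrative download sketch (hypothetical s3_key; the Range header asks for
# the first kilobyte only, which S3 answers with a partial-content response):
#
#     body, length, content_range = get_storage().get_stream(s3_key, range_header="bytes=0-1023")
#     first_kb = body.read()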