import hashlib
import logging
from typing import BinaryIO, Tuple, Optional, Dict, Any, Generator

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

from .config import get_settings

settings = get_settings()
logger = logging.getLogger(__name__)

# Threshold for multipart upload (100MB)
MULTIPART_THRESHOLD = 100 * 1024 * 1024
# Chunk size for multipart upload (10MB)
MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
# Chunk size for streaming hash computation
HASH_CHUNK_SIZE = 8 * 1024 * 1024


class S3Storage:
    def __init__(self):
        config = Config(s3={"addressing_style": "path"} if settings.s3_use_path_style else {})
        self.client = boto3.client(
            "s3",
            endpoint_url=settings.s3_endpoint if settings.s3_endpoint else None,
            region_name=settings.s3_region,
            aws_access_key_id=settings.s3_access_key_id,
            aws_secret_access_key=settings.s3_secret_access_key,
            config=config,
        )
        self.bucket = settings.s3_bucket
        # Store active multipart uploads for resumable support
        self._active_uploads: Dict[str, Dict[str, Any]] = {}

    def store(self, file: BinaryIO, content_length: Optional[int] = None) -> Tuple[str, int, str]:
        """
        Store a file and return its SHA256 hash, size, and s3_key.
        Content-addressable: if the file already exists, just return the hash.
        Uses multipart upload for files larger than MULTIPART_THRESHOLD.
        """
        # For small files or unknown size, use the simple approach
        if content_length is None or content_length < MULTIPART_THRESHOLD:
            return self._store_simple(file)
        else:
            return self._store_multipart(file, content_length)

    def _store_simple(self, file: BinaryIO) -> Tuple[str, int, str]:
        """Store a small file using simple put_object"""
        # Read file and compute hash
        content = file.read()
        sha256_hash = hashlib.sha256(content).hexdigest()
        size = len(content)

        # Check if already exists
        s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"
        if not self._exists(s3_key):
            self.client.put_object(
                Bucket=self.bucket,
                Key=s3_key,
                Body=content,
            )
        return sha256_hash, size, s3_key

    def _store_multipart(self, file: BinaryIO, content_length: int) -> Tuple[str, int, str]:
        """Store a large file using S3 multipart upload with streaming hash computation"""
        # First pass: compute hash by streaming through file
        hasher = hashlib.sha256()
        size = 0

        # Read file in chunks to compute hash
        while True:
            chunk = file.read(HASH_CHUNK_SIZE)
            if not chunk:
                break
            hasher.update(chunk)
            size += len(chunk)

        sha256_hash = hasher.hexdigest()
        s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"

        # Check if already exists (deduplication)
        if self._exists(s3_key):
            return sha256_hash, size, s3_key

        # Seek back to start for upload
        file.seek(0)

        # Start multipart upload
        mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
        upload_id = mpu["UploadId"]

        try:
            parts = []
            part_number = 1

            while True:
                chunk = file.read(MULTIPART_CHUNK_SIZE)
                if not chunk:
                    break

                response = self.client.upload_part(
                    Bucket=self.bucket,
                    Key=s3_key,
                    UploadId=upload_id,
                    PartNumber=part_number,
                    Body=chunk,
                )
                parts.append({
                    "PartNumber": part_number,
                    "ETag": response["ETag"],
                })
                part_number += 1

            # Complete multipart upload
            self.client.complete_multipart_upload(
                Bucket=self.bucket,
                Key=s3_key,
                UploadId=upload_id,
                MultipartUpload={"Parts": parts},
            )
            return sha256_hash, size, s3_key
        except Exception as e:
            # Abort multipart upload on failure
            logger.error(f"Multipart upload failed: {e}")
            self.client.abort_multipart_upload(
                Bucket=self.bucket,
                Key=s3_key,
                UploadId=upload_id,
            )
            raise

    def store_streaming(
        self, chunks: Generator[bytes, None, None]
    ) -> Tuple[str, int, str]:
        """
        Store a file from a stream of chunks.
        First accumulates to compute hash, then uploads.
        For truly large files, consider using initiate_resumable_upload instead.
        """
        # Accumulate chunks and compute hash
        hasher = hashlib.sha256()
        all_chunks = []
        size = 0
        for chunk in chunks:
            hasher.update(chunk)
            all_chunks.append(chunk)
            size += len(chunk)

        sha256_hash = hasher.hexdigest()
        s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"

        # Check if already exists
        if self._exists(s3_key):
            return sha256_hash, size, s3_key

        # Upload based on size
        if size < MULTIPART_THRESHOLD:
            content = b"".join(all_chunks)
            self.client.put_object(Bucket=self.bucket, Key=s3_key, Body=content)
        else:
            # Use multipart for large files
            mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
            upload_id = mpu["UploadId"]
            try:
                parts = []
                part_number = 1
                buffer = b""
                for chunk in all_chunks:
                    buffer += chunk
                    while len(buffer) >= MULTIPART_CHUNK_SIZE:
                        part_data = buffer[:MULTIPART_CHUNK_SIZE]
                        buffer = buffer[MULTIPART_CHUNK_SIZE:]
                        response = self.client.upload_part(
                            Bucket=self.bucket,
                            Key=s3_key,
                            UploadId=upload_id,
                            PartNumber=part_number,
                            Body=part_data,
                        )
                        parts.append({
                            "PartNumber": part_number,
                            "ETag": response["ETag"],
                        })
                        part_number += 1

                # Upload remaining buffer
                if buffer:
                    response = self.client.upload_part(
                        Bucket=self.bucket,
                        Key=s3_key,
                        UploadId=upload_id,
                        PartNumber=part_number,
                        Body=buffer,
                    )
                    parts.append({
                        "PartNumber": part_number,
                        "ETag": response["ETag"],
                    })

                self.client.complete_multipart_upload(
                    Bucket=self.bucket,
                    Key=s3_key,
                    UploadId=upload_id,
                    MultipartUpload={"Parts": parts},
                )
            except Exception as e:
                logger.error(f"Streaming multipart upload failed: {e}")
                self.client.abort_multipart_upload(
                    Bucket=self.bucket,
                    Key=s3_key,
                    UploadId=upload_id,
                )
                raise

        return sha256_hash, size, s3_key

    def initiate_resumable_upload(self, expected_hash: str) -> Dict[str, Any]:
        """
        Initiate a resumable upload session.
        Returns upload session info including upload_id.
        """
        s3_key = f"fruits/{expected_hash[:2]}/{expected_hash[2:4]}/{expected_hash}"

        # Check if already exists
        if self._exists(s3_key):
            return {
                "upload_id": None,
                "s3_key": s3_key,
                "already_exists": True,
                "parts": [],
            }

        mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
        upload_id = mpu["UploadId"]
        session = {
            "upload_id": upload_id,
            "s3_key": s3_key,
            "already_exists": False,
            "parts": [],
            "expected_hash": expected_hash,
        }
        self._active_uploads[upload_id] = session
        return session

    def upload_part(self, upload_id: str, part_number: int, data: bytes) -> Dict[str, Any]:
        """
        Upload a part for a resumable upload.
        Returns part info including ETag.
        """
        session = self._active_uploads.get(upload_id)
        if not session:
            raise ValueError(f"Unknown upload session: {upload_id}")

        response = self.client.upload_part(
            Bucket=self.bucket,
            Key=session["s3_key"],
            UploadId=upload_id,
            PartNumber=part_number,
            Body=data,
        )
        part_info = {
            "PartNumber": part_number,
            "ETag": response["ETag"],
        }
        session["parts"].append(part_info)
        return part_info

    def complete_resumable_upload(self, upload_id: str) -> Tuple[str, str]:
        """
        Complete a resumable upload.
        Returns (sha256_hash, s3_key).
        """
        session = self._active_uploads.get(upload_id)
        if not session:
            raise ValueError(f"Unknown upload session: {upload_id}")

        # Sort parts by part number
        sorted_parts = sorted(session["parts"], key=lambda x: x["PartNumber"])
        self.client.complete_multipart_upload(
            Bucket=self.bucket,
            Key=session["s3_key"],
            UploadId=upload_id,
            MultipartUpload={"Parts": sorted_parts},
        )

        # Clean up session
        del self._active_uploads[upload_id]
        return session["expected_hash"], session["s3_key"]

    def abort_resumable_upload(self, upload_id: str):
        """Abort a resumable upload"""
        session = self._active_uploads.get(upload_id)
        if session:
            self.client.abort_multipart_upload(
                Bucket=self.bucket,
                Key=session["s3_key"],
                UploadId=upload_id,
            )
            del self._active_uploads[upload_id]

    def list_upload_parts(self, upload_id: str) -> list:
        """List uploaded parts for a resumable upload (for resume support)"""
        session = self._active_uploads.get(upload_id)
        if not session:
            raise ValueError(f"Unknown upload session: {upload_id}")

        response = self.client.list_parts(
            Bucket=self.bucket,
            Key=session["s3_key"],
            UploadId=upload_id,
        )
        return response.get("Parts", [])

    def get(self, s3_key: str) -> bytes:
        """Retrieve a file by its S3 key"""
        response = self.client.get_object(Bucket=self.bucket, Key=s3_key)
        return response["Body"].read()

    def get_stream(self, s3_key: str, range_header: Optional[str] = None):
        """
        Get a streaming response for a file.
        Supports range requests for partial downloads.
        Returns (stream, content_length, content_range).
        """
        kwargs = {"Bucket": self.bucket, "Key": s3_key}
        if range_header:
            kwargs["Range"] = range_header

        response = self.client.get_object(**kwargs)
        content_length = response.get("ContentLength", 0)
        content_range = response.get("ContentRange")
        return response["Body"], content_length, content_range

    def get_object_info(self, s3_key: str) -> Optional[Dict[str, Any]]:
        """Get object metadata without downloading content; returns None if missing"""
        try:
            response = self.client.head_object(Bucket=self.bucket, Key=s3_key)
            return {
                "size": response.get("ContentLength", 0),
                "content_type": response.get("ContentType"),
                "last_modified": response.get("LastModified"),
                "etag": response.get("ETag"),
            }
        except ClientError:
            return None

    def _exists(self, s3_key: str) -> bool:
        """Check if an object exists"""
        try:
            self.client.head_object(Bucket=self.bucket, Key=s3_key)
            return True
        except ClientError:
            return False

    def delete(self, s3_key: str) -> bool:
        """Delete an object"""
        try:
            self.client.delete_object(Bucket=self.bucket, Key=s3_key)
            return True
        except ClientError:
            return False


# Singleton instance
_storage: Optional[S3Storage] = None


def get_storage() -> S3Storage:
    global _storage
    if _storage is None:
        _storage = S3Storage()
    return _storage