Add comprehensive upload/download tests and streaming enhancements (#38, #40, #42, #43)

Mondo Diaz
2026-01-21 09:35:12 -06:00
parent f7ffc1c877
commit 584acd1e90
23 changed files with 5385 additions and 405 deletions


@@ -170,6 +170,62 @@ def _run_migrations():
END IF;
END $$;
""",
# Create ref_count trigger functions for tags (ensures triggers exist even if initial migration wasn't run)
"""
CREATE OR REPLACE FUNCTION increment_artifact_ref_count()
RETURNS TRIGGER AS $$
BEGIN
UPDATE artifacts SET ref_count = ref_count + 1 WHERE id = NEW.artifact_id;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""",
"""
CREATE OR REPLACE FUNCTION decrement_artifact_ref_count()
RETURNS TRIGGER AS $$
BEGIN
UPDATE artifacts SET ref_count = ref_count - 1 WHERE id = OLD.artifact_id;
RETURN OLD;
END;
$$ LANGUAGE plpgsql;
""",
"""
CREATE OR REPLACE FUNCTION update_artifact_ref_count()
RETURNS TRIGGER AS $$
BEGIN
IF OLD.artifact_id != NEW.artifact_id THEN
UPDATE artifacts SET ref_count = ref_count - 1 WHERE id = OLD.artifact_id;
UPDATE artifacts SET ref_count = ref_count + 1 WHERE id = NEW.artifact_id;
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""",
# Create triggers for tags ref_count management
"""
DO $$
BEGIN
-- Drop and recreate triggers to ensure they're current
DROP TRIGGER IF EXISTS tags_ref_count_insert_trigger ON tags;
CREATE TRIGGER tags_ref_count_insert_trigger
AFTER INSERT ON tags
FOR EACH ROW
EXECUTE FUNCTION increment_artifact_ref_count();
DROP TRIGGER IF EXISTS tags_ref_count_delete_trigger ON tags;
CREATE TRIGGER tags_ref_count_delete_trigger
AFTER DELETE ON tags
FOR EACH ROW
EXECUTE FUNCTION decrement_artifact_ref_count();
DROP TRIGGER IF EXISTS tags_ref_count_update_trigger ON tags;
CREATE TRIGGER tags_ref_count_update_trigger
AFTER UPDATE ON tags
FOR EACH ROW
WHEN (OLD.artifact_id IS DISTINCT FROM NEW.artifact_id)
EXECUTE FUNCTION update_artifact_ref_count();
END $$;
""",
# Create ref_count trigger functions for package_versions
"""
CREATE OR REPLACE FUNCTION increment_version_ref_count()
@@ -210,7 +266,7 @@ def _run_migrations():
END $$;
""",
# Migrate existing semver tags to package_versions
"""
r"""
DO $$
BEGIN
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'package_versions') THEN
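The ref_count trigger migrations above keep `artifacts.ref_count` (and, analogously, the version ref counts) in sync with rows in `tags`. Below is a minimal pytest-style sketch of how that behavior could be checked; the SQLAlchemy `session` fixture, the `artifact_id` argument, and the two-column INSERT are illustrative assumptions only (a real `tags` row will likely require additional columns).

```python
# Hypothetical test sketch, not part of this commit. Assumes a SQLAlchemy
# `session` bound to a database where this migration has run and an existing
# row in `artifacts`; a real `tags` INSERT may need more columns than shown.
from sqlalchemy import text

def ref_count(session, artifact_id):
    # Read the ref_count value that the triggers maintain
    return session.execute(
        text("SELECT ref_count FROM artifacts WHERE id = :id"),
        {"id": artifact_id},
    ).scalar_one()

def test_tag_insert_and_delete_adjust_ref_count(session, artifact_id):
    before = ref_count(session, artifact_id)
    session.execute(
        text("INSERT INTO tags (name, artifact_id) VALUES ('tmp-tag', :id)"),
        {"id": artifact_id},
    )
    assert ref_count(session, artifact_id) == before + 1  # AFTER INSERT trigger fired
    session.execute(
        text("DELETE FROM tags WHERE name = 'tmp-tag' AND artifact_id = :id"),
        {"id": artifact_id},
    )
    assert ref_count(session, artifact_id) == before      # AFTER DELETE trigger fired
```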


@@ -82,6 +82,7 @@ from .schemas import (
ResumableUploadCompleteRequest,
ResumableUploadCompleteResponse,
ResumableUploadStatusResponse,
UploadProgressResponse,
GlobalSearchResponse,
SearchResultProject,
SearchResultPackage,
@@ -143,6 +144,31 @@ def sanitize_filename(filename: str) -> str:
return re.sub(r'[\r\n"]', "", filename)
def build_content_disposition(filename: str) -> str:
"""Build a Content-Disposition header value with proper encoding.
For ASCII filenames, uses the simple form: attachment; filename="name"
For non-ASCII filenames, uses RFC 5987 encoding with UTF-8 (filename* parameter plus an ASCII fallback).
"""
from urllib.parse import quote
sanitized = sanitize_filename(filename)
# Check if filename is pure ASCII
try:
sanitized.encode('ascii')
# Pure ASCII - simple format
return f'attachment; filename="{sanitized}"'
except UnicodeEncodeError:
# Non-ASCII - use RFC 5987 encoding
# Provide both filename (ASCII fallback) and filename* (UTF-8 encoded)
ascii_fallback = sanitized.encode('ascii', errors='replace').decode('ascii')
# RFC 5987: filename*=charset'language'encoded_value
# We use UTF-8 encoding and percent-encode non-ASCII chars
encoded = quote(sanitized, safe='')
return f'attachment; filename="{ascii_fallback}"; filename*=UTF-8\'\'{encoded}'
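For reference, the two shapes this helper produces (filenames below are made up; the expected values follow from the implementation above):

```python
# Illustrative calls; expected output derived from the implementation above.
print(build_content_disposition("report.tar.gz"))
# -> attachment; filename="report.tar.gz"

print(build_content_disposition("résumé.pdf"))
# -> attachment; filename="r?sum?.pdf"; filename*=UTF-8''r%C3%A9sum%C3%A9.pdf
#    ('?' comes from the errors='replace' ASCII fallback; filename* carries percent-encoded UTF-8)
```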
def get_user_id_from_request(
request: Request,
db: Session,
@@ -2258,10 +2284,56 @@ def upload_artifact(
"""
Upload an artifact to a package.
Headers:
- X-Checksum-SHA256: Optional client-provided SHA256 for verification
- User-Agent: Captured for audit purposes
- Authorization: Bearer <api-key> for authentication
**Size Limits:**
- Minimum: 1 byte (empty files rejected)
- Maximum: 10GB (configurable via ORCHARD_MAX_FILE_SIZE)
- Files > 100MB automatically use S3 multipart upload
**Headers:**
- `X-Checksum-SHA256`: Optional SHA256 hash for server-side verification
- `Content-Length`: File size (required for early rejection of oversized files)
- `Authorization`: Bearer <api-key> for authentication
**Deduplication:**
Content-addressable storage automatically deduplicates identical files.
If the same content is uploaded multiple times, only one copy is stored.
**Response Metrics:**
- `duration_ms`: Upload duration in milliseconds
- `throughput_mbps`: Upload throughput in MB/s
- `deduplicated`: True if content already existed
**Example (curl):**
```bash
curl -X POST "http://localhost:8080/api/v1/project/myproject/mypackage/upload" \\
-H "Authorization: Bearer <api-key>" \\
-F "file=@myfile.tar.gz" \\
-F "tag=v1.0.0"
```
**Example (Python requests):**
```python
import requests
with open('myfile.tar.gz', 'rb') as f:
response = requests.post(
'http://localhost:8080/api/v1/project/myproject/mypackage/upload',
files={'file': f},
data={'tag': 'v1.0.0'},
headers={'Authorization': 'Bearer <api-key>'}
)
```
**Example (JavaScript fetch):**
```javascript
const formData = new FormData();
formData.append('file', fileInput.files[0]);
formData.append('tag', 'v1.0.0');
const response = await fetch('/api/v1/project/myproject/mypackage/upload', {
method: 'POST',
headers: { 'Authorization': 'Bearer <api-key>' },
body: formData
});
```
"""
start_time = time.time()
settings = get_settings()
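The docstring examples above don't exercise the optional `X-Checksum-SHA256` header; a hedged sketch of a client that pre-computes the hash and also reads the new response metrics (host, project, package, and API key are placeholders taken from the docstring examples):

```python
# Sketch: compute the SHA256 locally so the server can verify it on upload.
# Endpoint path and form fields follow the docstring above; host/key are placeholders.
import hashlib
import requests

path = "myfile.tar.gz"
sha256 = hashlib.sha256(open(path, "rb").read()).hexdigest()

with open(path, "rb") as f:
    resp = requests.post(
        "http://localhost:8080/api/v1/project/myproject/mypackage/upload",
        files={"file": f},
        data={"tag": "v1.0.0"},
        headers={
            "Authorization": "Bearer <api-key>",
            "X-Checksum-SHA256": sha256,  # server-side verification of the uploaded bytes
        },
    )
resp.raise_for_status()
body = resp.json()
print(body.get("duration_ms"), body.get("throughput_mbps"), body.get("deduplicated"))
```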
@@ -2363,6 +2435,30 @@ def upload_artifact(
except StorageError as e:
logger.error(f"Storage error during upload: {e}")
raise HTTPException(status_code=500, detail="Internal storage error")
except (ConnectionResetError, BrokenPipeError) as e:
# Client disconnected during upload
logger.warning(
f"Client disconnected during upload: project={project_name} "
f"package={package_name} filename={file.filename} error={e}"
)
raise HTTPException(
status_code=499, # Client Closed Request (nginx convention)
detail="Client disconnected during upload",
)
except Exception as e:
# Catch-all for unexpected errors including client disconnects
error_str = str(e).lower()
if "connection" in error_str or "broken pipe" in error_str or "reset" in error_str:
logger.warning(
f"Client connection error during upload: project={project_name} "
f"package={package_name} filename={file.filename} error={e}"
)
raise HTTPException(
status_code=499,
detail="Client connection error during upload",
)
logger.error(f"Unexpected error during upload: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error during upload")
# Verify client-provided checksum if present
checksum_verified = True
@@ -2555,6 +2651,12 @@ def upload_artifact(
detail="Failed to save upload record. Please retry.",
)
# Calculate throughput
throughput_mbps = None
if duration_ms > 0:
duration_seconds = duration_ms / 1000.0
throughput_mbps = round((storage_result.size / (1024 * 1024)) / duration_seconds, 2)
return UploadResponse(
artifact_id=storage_result.sha256,
sha256=storage_result.sha256,
@@ -2574,6 +2676,8 @@ def upload_artifact(
content_type=artifact.content_type,
original_name=artifact.original_name,
created_at=artifact.created_at,
duration_ms=duration_ms,
throughput_mbps=throughput_mbps,
)
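A worked instance of the throughput formula above, with illustrative numbers:

```python
# 100 MiB uploaded in 2500 ms -> 40.0 MB/s (same formula as above).
size_bytes = 104_857_600
duration_ms = 2_500
throughput_mbps = round((size_bytes / (1024 * 1024)) / (duration_ms / 1000.0), 2)
assert throughput_mbps == 40.0
```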
@@ -2591,8 +2695,46 @@ def init_resumable_upload(
storage: S3Storage = Depends(get_storage),
):
"""
Initialize a resumable upload session.
Client must provide the SHA256 hash of the file in advance.
Initialize a resumable upload session for large files.
Resumable uploads allow uploading large files in chunks, with the ability
to resume after interruption. The client must compute the SHA256 hash
of the entire file before starting.
**Workflow:**
1. POST /upload/init - Initialize upload session (this endpoint)
2. PUT /upload/{upload_id}/part/{part_number} - Upload each part
3. GET /upload/{upload_id}/progress - Check upload progress (optional)
4. POST /upload/{upload_id}/complete - Finalize upload
5. DELETE /upload/{upload_id} - Abort upload (if needed)
**Chunk Size:**
Use the `chunk_size` returned in the response (10MB default).
Each part except the last must be exactly this size.
**Deduplication:**
If the expected_hash already exists in storage, the response will include
`already_exists: true` and no upload session is created.
**Example (curl):**
```bash
# Step 1: Initialize
curl -X POST "http://localhost:8080/api/v1/project/myproject/mypackage/upload/init" \\
-H "Authorization: Bearer <api-key>" \\
-H "Content-Type: application/json" \\
-d '{"expected_hash": "<sha256>", "filename": "large.tar.gz", "size": 104857600}'
# Step 2: Upload parts
curl -X PUT "http://localhost:8080/api/v1/project/myproject/mypackage/upload/<upload_id>/part/1" \\
-H "Authorization: Bearer <api-key>" \\
--data-binary @part1.bin
# Step 3: Complete
curl -X POST "http://localhost:8080/api/v1/project/myproject/mypackage/upload/<upload_id>/complete" \\
-H "Authorization: Bearer <api-key>" \\
-H "Content-Type: application/json" \\
-d '{"tag": "v1.0.0"}'
```
"""
user_id = get_user_id(request)
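A compact Python sketch of the workflow described in the docstring; the field names (`expected_hash`, `upload_id`, `chunk_size`, `already_exists`) follow the docstring and schemas above, while the host, project, package, and key are placeholders:

```python
# Sketch of the init -> parts -> complete flow; reads the whole file into
# memory for brevity, which a real client would avoid for very large files.
import hashlib
import requests

BASE = "http://localhost:8080/api/v1/project/myproject/mypackage"
HEADERS = {"Authorization": "Bearer <api-key>"}
path = "large.tar.gz"

data = open(path, "rb").read()
init = requests.post(
    f"{BASE}/upload/init",
    headers=HEADERS,
    json={"expected_hash": hashlib.sha256(data).hexdigest(),
          "filename": path, "size": len(data)},
).json()

if not init.get("already_exists"):            # deduplication short-circuit
    upload_id, chunk = init["upload_id"], init["chunk_size"]
    for i in range(0, len(data), chunk):      # parts are 1-indexed
        requests.put(f"{BASE}/upload/{upload_id}/part/{i // chunk + 1}",
                     headers=HEADERS, data=data[i:i + chunk])
    requests.post(f"{BASE}/upload/{upload_id}/complete",
                  headers=HEADERS, json={"tag": "v1.0.0"})
```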
@@ -2686,6 +2828,10 @@ def init_resumable_upload(
# Initialize resumable upload
session = storage.initiate_resumable_upload(init_request.expected_hash)
# Set expected size for progress tracking
if session["upload_id"] and init_request.size:
storage.set_upload_expected_size(session["upload_id"], init_request.size)
return ResumableUploadInitResponse(
upload_id=session["upload_id"],
already_exists=False,
@@ -2752,6 +2898,64 @@ def upload_part(
raise HTTPException(status_code=404, detail=str(e))
@router.get(
"/api/v1/project/{project_name}/{package_name}/upload/{upload_id}/progress",
response_model=UploadProgressResponse,
)
def get_upload_progress(
project_name: str,
package_name: str,
upload_id: str,
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
):
"""
Get progress information for an in-flight resumable upload.
Returns progress metrics including bytes uploaded, percent complete,
elapsed time, and throughput.
"""
# Validate project and package exist
project = db.query(Project).filter(Project.name == project_name).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
package = (
db.query(Package)
.filter(Package.project_id == project.id, Package.name == package_name)
.first()
)
if not package:
raise HTTPException(status_code=404, detail="Package not found")
progress = storage.get_upload_progress(upload_id)
if not progress:
# Return not_found status instead of 404 to allow polling
return UploadProgressResponse(
upload_id=upload_id,
status="not_found",
bytes_uploaded=0,
)
from datetime import datetime, timezone
started_at_dt = None
if progress.get("started_at"):
started_at_dt = datetime.fromtimestamp(progress["started_at"], tz=timezone.utc)
return UploadProgressResponse(
upload_id=upload_id,
status=progress.get("status", "in_progress"),
bytes_uploaded=progress.get("bytes_uploaded", 0),
bytes_total=progress.get("bytes_total"),
percent_complete=progress.get("percent_complete"),
parts_uploaded=progress.get("parts_uploaded", 0),
parts_total=progress.get("parts_total"),
started_at=started_at_dt,
elapsed_seconds=progress.get("elapsed_seconds"),
throughput_mbps=progress.get("throughput_mbps"),
)
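A hedged example of polling this endpoint from a client; the field names match `UploadProgressResponse`, everything else is a placeholder:

```python
# Hypothetical polling loop against the progress endpoint above.
import time
import requests

def wait_for_upload(base_url, upload_id, headers, poll_seconds=2):
    while True:
        p = requests.get(f"{base_url}/upload/{upload_id}/progress", headers=headers).json()
        if p["status"] in ("completed", "failed", "not_found"):
            return p
        print(f"{p.get('percent_complete')}% "
              f"({p['bytes_uploaded']} bytes, {p.get('throughput_mbps')} MB/s)")
        time.sleep(poll_seconds)
```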
@router.post(
"/api/v1/project/{project_name}/{package_name}/upload/{upload_id}/complete"
)
@@ -2947,6 +3151,8 @@ def download_artifact(
storage: S3Storage = Depends(get_storage),
current_user: Optional[User] = Depends(get_current_user_optional),
range: Optional[str] = Header(None),
if_none_match: Optional[str] = Header(None, alias="If-None-Match"),
if_modified_since: Optional[str] = Header(None, alias="If-Modified-Since"),
mode: Optional[Literal["proxy", "redirect", "presigned"]] = Query(
default=None,
description="Download mode: proxy (stream through backend), redirect (302 to presigned URL), presigned (return JSON with URL)",
@@ -2963,6 +3169,15 @@ def download_artifact(
"""
Download an artifact by reference (tag name, artifact:hash, tag:name).
Supports conditional requests:
- If-None-Match: Returns 304 Not Modified if ETag matches
- If-Modified-Since: Returns 304 Not Modified if not modified since date
Supports range requests for partial downloads and resuming interrupted transfers:
- Range: bytes=0-1023 (first 1KB)
- Range: bytes=-1024 (last 1KB)
- Returns 206 Partial Content with Content-Range header
Verification modes:
- verify=false (default): No verification, maximum performance
- verify=true&verify_mode=stream: Compute hash while streaming, verify after completion.
@@ -2975,6 +3190,9 @@ def download_artifact(
- X-Content-Length: File size in bytes
- ETag: Artifact ID (SHA256)
- Digest: RFC 3230 format sha-256 hash
- Last-Modified: Artifact creation timestamp
- Cache-Control: Immutable caching for content-addressable storage
- Accept-Ranges: bytes (advertises range request support)
When verify=true:
- X-Verified: 'true' if verified, 'false' if verification failed
@@ -2999,6 +3217,52 @@ def download_artifact(
filename = sanitize_filename(artifact.original_name or f"{artifact.id}")
# Format Last-Modified header (RFC 7231 format)
last_modified = None
last_modified_str = None
if artifact.created_at:
last_modified = artifact.created_at
if last_modified.tzinfo is None:
last_modified = last_modified.replace(tzinfo=timezone.utc)
last_modified_str = last_modified.strftime("%a, %d %b %Y %H:%M:%S GMT")
# Handle conditional requests (If-None-Match, If-Modified-Since)
# Return 304 Not Modified if content hasn't changed
artifact_etag = f'"{artifact.id}"'
if if_none_match:
# Strip quotes and compare with artifact ETag
client_etag = if_none_match.strip().strip('"')
if client_etag == artifact.id or if_none_match == artifact_etag:
return Response(
status_code=304,
headers={
"ETag": artifact_etag,
"Cache-Control": "public, max-age=31536000, immutable",
**({"Last-Modified": last_modified_str} if last_modified_str else {}),
},
)
if if_modified_since and last_modified:
try:
# Parse If-Modified-Since header
from email.utils import parsedate_to_datetime
client_date = parsedate_to_datetime(if_modified_since)
if client_date.tzinfo is None:
client_date = client_date.replace(tzinfo=timezone.utc)
# If artifact hasn't been modified since client's date, return 304
if last_modified <= client_date:
return Response(
status_code=304,
headers={
"ETag": artifact_etag,
"Cache-Control": "public, max-age=31536000, immutable",
**({"Last-Modified": last_modified_str} if last_modified_str else {}),
},
)
except (ValueError, TypeError):
pass # Invalid date format, ignore and continue with download
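Because artifacts are content-addressable, the ETag for a given artifact never changes, so clients can revalidate cheaply against the handling above. A sketch (the download URL and reference are placeholders; only the header names come from the code):

```python
# Conditional re-download: a matching ETag should yield 304 with no body.
import requests

url = "http://localhost:8080/api/v1/project/myproject/mypackage/download/v1.0.0"
first = requests.get(url)
etag = first.headers.get("ETag")

second = requests.get(url, headers={"If-None-Match": etag})
assert second.status_code in (200, 304)
```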
# Audit log download
user_id = get_user_id(request)
_log_audit(
@@ -3016,22 +3280,28 @@ def download_artifact(
)
db.commit()
# Build common checksum headers (always included)
checksum_headers = {
# Build common headers (always included)
common_headers = {
"X-Checksum-SHA256": artifact.id,
"X-Content-Length": str(artifact.size),
"ETag": f'"{artifact.id}"',
"ETag": artifact_etag,
# Cache-Control: content-addressable storage is immutable
"Cache-Control": "public, max-age=31536000, immutable",
}
# Add Last-Modified header
if last_modified_str:
common_headers["Last-Modified"] = last_modified_str
# Add RFC 3230 Digest header
try:
digest_base64 = sha256_to_base64(artifact.id)
checksum_headers["Digest"] = f"sha-256={digest_base64}"
common_headers["Digest"] = f"sha-256={digest_base64}"
except Exception:
pass # Skip if conversion fails
# Add MD5 checksum if available
if artifact.checksum_md5:
checksum_headers["X-Checksum-MD5"] = artifact.checksum_md5
common_headers["X-Checksum-MD5"] = artifact.checksum_md5
# Determine download mode (query param overrides server default)
download_mode = mode or settings.download_mode
@@ -3071,15 +3341,29 @@ def download_artifact(
# Proxy mode (default fallback) - stream through backend
# Handle range requests (verification not supported for partial downloads)
if range:
stream, content_length, content_range = storage.get_stream(
artifact.s3_key, range
)
try:
stream, content_length, content_range = storage.get_stream(
artifact.s3_key, range
)
except Exception as e:
# S3 returns InvalidRange error for unsatisfiable ranges
error_str = str(e).lower()
if "invalidrange" in error_str or "range" in error_str:
raise HTTPException(
status_code=416,
detail="Range Not Satisfiable",
headers={
"Content-Range": f"bytes */{artifact.size}",
"Accept-Ranges": "bytes",
},
)
raise
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Disposition": build_content_disposition(filename),
"Accept-Ranges": "bytes",
"Content-Length": str(content_length),
**checksum_headers,
**common_headers,
}
if content_range:
headers["Content-Range"] = content_range
@@ -3094,9 +3378,9 @@ def download_artifact(
# Full download with optional verification
base_headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Disposition": build_content_disposition(filename),
"Accept-Ranges": "bytes",
**checksum_headers,
**common_headers,
}
# Pre-verification mode: verify before streaming
@@ -3164,11 +3448,42 @@ def download_artifact(
},
)
# No verification - direct streaming
# No verification - direct streaming with completion logging
stream, content_length, _ = storage.get_stream(artifact.s3_key)
def logged_stream():
"""Generator that yields chunks and logs completion/disconnection."""
import time
start_time = time.time()
bytes_sent = 0
try:
for chunk in stream:
bytes_sent += len(chunk)
yield chunk
# Download completed successfully
duration = time.time() - start_time
throughput_mbps = (bytes_sent / (1024 * 1024)) / duration if duration > 0 else 0
logger.info(
f"Download completed: artifact={artifact.id[:16]}... "
f"bytes={bytes_sent} duration={duration:.2f}s throughput={throughput_mbps:.2f}MB/s"
)
except GeneratorExit:
# Client disconnected before download completed
duration = time.time() - start_time
logger.warning(
f"Download interrupted: artifact={artifact.id[:16]}... "
f"bytes_sent={bytes_sent}/{content_length} duration={duration:.2f}s"
)
except Exception as e:
duration = time.time() - start_time
logger.error(
f"Download error: artifact={artifact.id[:16]}... "
f"bytes_sent={bytes_sent} duration={duration:.2f}s error={e}"
)
raise
return StreamingResponse(
stream,
logged_stream(),
media_type=artifact.content_type or "application/octet-stream",
headers={
**base_headers,
@@ -3276,7 +3591,7 @@ def head_artifact(
# Build headers with checksum information
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Disposition": build_content_disposition(filename),
"Accept-Ranges": "bytes",
"Content-Length": str(artifact.size),
"X-Artifact-Id": artifact.id,


@@ -412,6 +412,9 @@ class UploadResponse(BaseModel):
content_type: Optional[str] = None
original_name: Optional[str] = None
created_at: Optional[datetime] = None
# Upload metrics (Issue #43)
duration_ms: Optional[int] = None # Upload duration in milliseconds
throughput_mbps: Optional[float] = None # Upload throughput in MB/s
# Resumable upload schemas
@@ -478,6 +481,21 @@ class ResumableUploadStatusResponse(BaseModel):
total_uploaded_bytes: int
class UploadProgressResponse(BaseModel):
"""Progress information for an in-flight upload"""
upload_id: str
status: str # 'in_progress', 'completed', 'failed', 'not_found'
bytes_uploaded: int = 0
bytes_total: Optional[int] = None
percent_complete: Optional[float] = None
parts_uploaded: int = 0
parts_total: Optional[int] = None
started_at: Optional[datetime] = None
elapsed_seconds: Optional[float] = None
throughput_mbps: Optional[float] = None
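An illustrative instance of this model, with made-up but mutually consistent values (10 MB parts, half of a 100 MiB upload done):

```python
# Made-up values for illustration only.
UploadProgressResponse(
    upload_id="example-upload-id",
    status="in_progress",
    bytes_uploaded=52_428_800,    # 50 MiB so far
    bytes_total=104_857_600,      # 100 MiB expected
    percent_complete=50.0,
    parts_uploaded=5,
    parts_total=10,
    elapsed_seconds=12.5,
    throughput_mbps=4.0,          # 50 MiB / 12.5 s
)
```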
# Consumer schemas
class ConsumerResponse(BaseModel):
id: UUID


@@ -378,10 +378,16 @@ class S3Storage:
"""
# First pass: compute all hashes by streaming through file
try:
import time
sha256_hasher = hashlib.sha256()
md5_hasher = hashlib.md5()
sha1_hasher = hashlib.sha1()
size = 0
hash_start_time = time.time()
last_log_time = hash_start_time
log_interval_seconds = 5 # Log progress every 5 seconds
logger.info(f"Computing hashes for large file: expected_size={content_length}")
# Read file in chunks to compute hashes
while True:
@@ -393,6 +399,18 @@ class S3Storage:
sha1_hasher.update(chunk)
size += len(chunk)
# Log hash computation progress periodically
current_time = time.time()
if current_time - last_log_time >= log_interval_seconds:
elapsed = current_time - hash_start_time
percent = (size / content_length) * 100 if content_length > 0 else 0
throughput = (size / (1024 * 1024)) / elapsed if elapsed > 0 else 0
logger.info(
f"Hash computation progress: bytes={size}/{content_length} ({percent:.1f}%) "
f"throughput={throughput:.2f}MB/s"
)
last_log_time = current_time
# Enforce file size limit during streaming (protection against spoofing)
if size > settings.max_file_size:
raise FileSizeExceededError(
@@ -405,6 +423,14 @@ class S3Storage:
sha256_hash = sha256_hasher.hexdigest()
md5_hash = md5_hasher.hexdigest()
sha1_hash = sha1_hasher.hexdigest()
# Log hash computation completion
hash_elapsed = time.time() - hash_start_time
hash_throughput = (size / (1024 * 1024)) / hash_elapsed if hash_elapsed > 0 else 0
logger.info(
f"Hash computation completed: hash={sha256_hash[:16]}... "
f"size={size} duration={hash_elapsed:.2f}s throughput={hash_throughput:.2f}MB/s"
)
except (HashComputationError, FileSizeExceededError):
raise
except Exception as e:
@@ -458,8 +484,19 @@ class S3Storage:
upload_id = mpu["UploadId"]
try:
import time
parts = []
part_number = 1
bytes_uploaded = 0
upload_start_time = time.time()
last_log_time = upload_start_time
log_interval_seconds = 5 # Log progress every 5 seconds
total_parts = (content_length + MULTIPART_CHUNK_SIZE - 1) // MULTIPART_CHUNK_SIZE
logger.info(
f"Starting multipart upload: hash={sha256_hash[:16]}... "
f"size={content_length} parts={total_parts}"
)
while True:
chunk = file.read(MULTIPART_CHUNK_SIZE)
@@ -479,8 +516,32 @@ class S3Storage:
"ETag": response["ETag"],
}
)
bytes_uploaded += len(chunk)
# Log progress periodically
current_time = time.time()
if current_time - last_log_time >= log_interval_seconds:
elapsed = current_time - upload_start_time
percent = (bytes_uploaded / content_length) * 100
throughput = (bytes_uploaded / (1024 * 1024)) / elapsed if elapsed > 0 else 0
logger.info(
f"Upload progress: hash={sha256_hash[:16]}... "
f"part={part_number}/{total_parts} "
f"bytes={bytes_uploaded}/{content_length} ({percent:.1f}%) "
f"throughput={throughput:.2f}MB/s"
)
last_log_time = current_time
part_number += 1
# Log completion
total_elapsed = time.time() - upload_start_time
final_throughput = (content_length / (1024 * 1024)) / total_elapsed if total_elapsed > 0 else 0
logger.info(
f"Multipart upload completed: hash={sha256_hash[:16]}... "
f"size={content_length} duration={total_elapsed:.2f}s throughput={final_throughput:.2f}MB/s"
)
# Complete multipart upload
complete_response = self.client.complete_multipart_upload(
Bucket=self.bucket,
@@ -502,12 +563,28 @@ class S3Storage:
except Exception as e:
# Abort multipart upload on failure
logger.error(f"Multipart upload failed: {e}")
self.client.abort_multipart_upload(
Bucket=self.bucket,
Key=s3_key,
UploadId=upload_id,
error_str = str(e).lower()
is_client_disconnect = (
isinstance(e, (ConnectionResetError, BrokenPipeError)) or
"connection" in error_str or "broken pipe" in error_str or "reset" in error_str
)
if is_client_disconnect:
logger.warning(
f"Multipart upload aborted (client disconnect): hash={sha256_hash[:16]}... "
f"parts_uploaded={len(parts)} bytes_uploaded={bytes_uploaded}"
)
else:
logger.error(f"Multipart upload failed: hash={sha256_hash[:16]}... error={e}")
try:
self.client.abort_multipart_upload(
Bucket=self.bucket,
Key=s3_key,
UploadId=upload_id,
)
logger.info(f"Multipart upload aborted and cleaned up: upload_id={upload_id[:16]}...")
except Exception as abort_error:
logger.error(f"Failed to abort multipart upload: {abort_error}")
raise
def initiate_resumable_upload(self, expected_hash: str) -> Dict[str, Any]:
@@ -529,12 +606,17 @@ class S3Storage:
mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
upload_id = mpu["UploadId"]
import time
session = {
"upload_id": upload_id,
"s3_key": s3_key,
"already_exists": False,
"parts": [],
"expected_hash": expected_hash,
"started_at": time.time(),
"bytes_uploaded": 0,
"expected_size": None, # Set when init provides size
"status": "in_progress",
}
self._active_uploads[upload_id] = session
return session
@@ -561,10 +643,57 @@ class S3Storage:
part_info = {
"PartNumber": part_number,
"ETag": response["ETag"],
"size": len(data),
}
session["parts"].append(part_info)
session["bytes_uploaded"] = session.get("bytes_uploaded", 0) + len(data)
return part_info
def get_upload_progress(self, upload_id: str) -> Optional[Dict[str, Any]]:
"""
Get progress information for a resumable upload.
Returns None if upload not found.
"""
import time
session = self._active_uploads.get(upload_id)
if not session:
return None
bytes_uploaded = session.get("bytes_uploaded", 0)
expected_size = session.get("expected_size")
started_at = session.get("started_at")
progress = {
"upload_id": upload_id,
"status": session.get("status", "in_progress"),
"bytes_uploaded": bytes_uploaded,
"bytes_total": expected_size,
"parts_uploaded": len(session.get("parts", [])),
"parts_total": None,
"started_at": started_at,
"elapsed_seconds": None,
"percent_complete": None,
"throughput_mbps": None,
}
if expected_size and expected_size > 0:
progress["percent_complete"] = round((bytes_uploaded / expected_size) * 100, 2)
progress["parts_total"] = (expected_size + MULTIPART_CHUNK_SIZE - 1) // MULTIPART_CHUNK_SIZE
if started_at:
elapsed = time.time() - started_at
progress["elapsed_seconds"] = round(elapsed, 2)
if elapsed > 0 and bytes_uploaded > 0:
progress["throughput_mbps"] = round((bytes_uploaded / (1024 * 1024)) / elapsed, 2)
return progress
def set_upload_expected_size(self, upload_id: str, size: int):
"""Set the expected size for an upload (for progress tracking)."""
session = self._active_uploads.get(upload_id)
if session:
session["expected_size"] = size
def complete_resumable_upload(self, upload_id: str) -> Tuple[str, str]:
"""
Complete a resumable upload.