Implement backend upload/download API enhancements

- Add S3 multipart upload support for files > 100MB
- Add resumable upload API endpoints (init, upload part, complete, abort, status)
- Add HTTP range request support for partial downloads
- Add HEAD request endpoint for artifact metadata
- Add format-specific metadata extraction (deb, rpm, tar.gz, wheel, jar, zip)
- Add format_metadata column to artifacts table
- Add database migration for schema updates
- Add deduplication indicator in upload response
- Set Accept-Ranges header on downloads
- Return Content-Length header on all downloads
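
Example client flow for the resumable upload endpoints, as a hedged sketch only: the
endpoint paths and JSON field names come from this change, while the base URL, project
and package names, tag, and file path are illustrative placeholders.

import hashlib
import requests

BASE_URL = "http://localhost:8000"            # assumption: local dev server
PROJECT, PACKAGE = "demo", "app"              # assumption: existing project/package
PATH = "build/app-1.2.3.tar.gz"               # assumption: file to upload

# Hash the file up front; the init endpoint requires the SHA256 in advance
sha256, size = hashlib.sha256(), 0
with open(PATH, "rb") as f:
    for chunk in iter(lambda: f.read(8 * 1024 * 1024), b""):
        sha256.update(chunk)
        size += len(chunk)

init = requests.post(
    f"{BASE_URL}/api/v1/project/{PROJECT}/{PACKAGE}/upload/init",
    json={"expected_hash": sha256.hexdigest(), "filename": "app-1.2.3.tar.gz",
          "size": size, "tag": "v1.2.3"},
).json()

if init["already_exists"]:
    print("deduplicated, artifact:", init["artifact_id"])
else:
    upload_id, chunk_size = init["upload_id"], init["chunk_size"]
    with open(PATH, "rb") as f:
        part_number = 1
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            requests.put(
                f"{BASE_URL}/api/v1/project/{PROJECT}/{PACKAGE}"
                f"/upload/{upload_id}/part/{part_number}",
                data=chunk,
            ).raise_for_status()
            part_number += 1
    done = requests.post(
        f"{BASE_URL}/api/v1/project/{PROJECT}/{PACKAGE}/upload/{upload_id}/complete",
        json={"tag": "v1.2.3"},
    ).json()
    print("uploaded artifact:", done["artifact_id"])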
Mondo Diaz
2025-12-11 17:07:10 -06:00
parent cb3d62b02a
commit 6eb2f9db7b
6 changed files with 1118 additions and 20 deletions


@@ -1,20 +1,51 @@
from sqlalchemy import create_engine
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, Session
from typing import Generator
import logging
from .config import get_settings
from .models import Base
settings = get_settings()
logger = logging.getLogger(__name__)
engine = create_engine(settings.database_url, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db():
"""Create all tables"""
"""Create all tables and run migrations"""
Base.metadata.create_all(bind=engine)
# Run migrations for schema updates
_run_migrations()
def _run_migrations():
"""Run manual migrations for schema updates"""
migrations = [
# Add format_metadata column to artifacts table
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'artifacts' AND column_name = 'format_metadata'
) THEN
ALTER TABLE artifacts ADD COLUMN format_metadata JSONB DEFAULT '{}';
END IF;
END $$;
""",
]
with engine.connect() as conn:
for migration in migrations:
try:
conn.execute(text(migration))
conn.commit()
except Exception as e:
logger.warning(f"Migration failed (may already be applied): {e}")
def get_db() -> Generator[Session, None, None]:
"""Dependency for getting database sessions"""

backend/app/metadata.py (new file, 354 lines added)

@@ -0,0 +1,354 @@
"""
Format-specific metadata extraction for uploaded artifacts.
Supports extracting version info and other metadata from package formats.
"""
import struct
import gzip
import tarfile
import io
import re
import logging
from typing import Dict, Any, Optional, BinaryIO
logger = logging.getLogger(__name__)
def extract_metadata(file: BinaryIO, filename: str, content_type: Optional[str] = None) -> Dict[str, Any]:
"""
Extract format-specific metadata from an uploaded file.
Returns a dict with extracted metadata fields.
"""
metadata = {}
# Determine format from filename extension
lower_filename = filename.lower() if filename else ""
try:
if lower_filename.endswith(".deb"):
metadata = extract_deb_metadata(file)
elif lower_filename.endswith(".rpm"):
metadata = extract_rpm_metadata(file)
elif lower_filename.endswith(".tar.gz") or lower_filename.endswith(".tgz"):
metadata = extract_tarball_metadata(file, filename)
elif lower_filename.endswith(".whl"):
metadata = extract_wheel_metadata(file)
elif lower_filename.endswith(".jar"):
metadata = extract_jar_metadata(file)
elif lower_filename.endswith(".zip"):
metadata = extract_zip_metadata(file)
except Exception as e:
logger.warning(f"Failed to extract metadata from {filename}: {e}")
# Always seek back to start after reading
try:
file.seek(0)
except Exception:
pass
return metadata
def extract_deb_metadata(file: BinaryIO) -> Dict[str, Any]:
"""
Extract metadata from a Debian .deb package.
Deb files are ar archives containing control.tar.gz with package info.
"""
metadata = {}
# Read ar archive header
ar_magic = file.read(8)
if ar_magic != b"!<arch>\n":
return metadata
# Parse ar archive to find control.tar.gz or control.tar.xz
while True:
# Read ar entry header (60 bytes)
header = file.read(60)
if len(header) < 60:
break
name = header[0:16].decode("ascii").strip()
size_str = header[48:58].decode("ascii").strip()
try:
size = int(size_str)
except ValueError:
break
if name.startswith("control.tar"):
# Read control archive
control_data = file.read(size)
# Decompress and read control file
try:
if name.endswith(".gz"):
control_data = gzip.decompress(control_data)
# Parse tar archive
with tarfile.open(fileobj=io.BytesIO(control_data), mode="r:*") as tar:
for member in tar.getmembers():
if member.name in ("./control", "control"):
f = tar.extractfile(member)
if f:
control_content = f.read().decode("utf-8", errors="replace")
metadata = parse_deb_control(control_content)
break
except Exception as e:
logger.debug(f"Failed to parse deb control: {e}")
break
else:
# Skip to next entry (align to 2 bytes)
file.seek(size + (size % 2), 1)
return metadata
def parse_deb_control(content: str) -> Dict[str, Any]:
"""Parse Debian control file format"""
metadata = {}
current_key = None
current_value = []
for line in content.split("\n"):
if line.startswith(" ") or line.startswith("\t"):
# Continuation line
if current_key:
current_value.append(line.strip())
elif ":" in line:
# Save previous field
if current_key:
metadata[current_key] = "\n".join(current_value)
# Parse new field
key, value = line.split(":", 1)
current_key = key.strip().lower()
current_value = [value.strip()]
else:
# Empty line or malformed
if current_key:
metadata[current_key] = "\n".join(current_value)
current_key = None
current_value = []
# Don't forget the last field
if current_key:
metadata[current_key] = "\n".join(current_value)
# Extract key fields
result = {}
if "package" in metadata:
result["package_name"] = metadata["package"]
if "version" in metadata:
result["version"] = metadata["version"]
if "architecture" in metadata:
result["architecture"] = metadata["architecture"]
if "maintainer" in metadata:
result["maintainer"] = metadata["maintainer"]
if "description" in metadata:
result["description"] = metadata["description"].split("\n")[0] # First line only
if "depends" in metadata:
result["depends"] = metadata["depends"]
result["format"] = "deb"
return result
def extract_rpm_metadata(file: BinaryIO) -> Dict[str, Any]:
"""
Extract metadata from an RPM package.
RPM files have a lead, signature, and header with metadata.
"""
metadata = {"format": "rpm"}
# Read RPM lead (96 bytes)
lead = file.read(96)
if len(lead) < 96:
return metadata
# Check magic number
if lead[0:4] != b"\xed\xab\xee\xdb":
return metadata
# Read name from lead (offset 10, max 66 bytes)
name_bytes = lead[10:76]
null_idx = name_bytes.find(b"\x00")
if null_idx > 0:
metadata["package_name"] = name_bytes[:null_idx].decode("ascii", errors="replace")
# Skip the signature header, then parse the main header that follows it.
# Both headers share the same layout: 3-byte magic, 1-byte version,
# 4 reserved bytes, a 4-byte index entry count, and a 4-byte data size.
try:
    sig_preamble = file.read(16)
    if len(sig_preamble) == 16 and sig_preamble[0:3] == b"\x8e\xad\xe8":
        sig_count, sig_size = struct.unpack(">II", sig_preamble[8:16])
        # Skip the signature index entries and data (padded to an 8-byte boundary)
        file.seek(16 * sig_count + sig_size + ((8 - sig_size % 8) % 8), 1)
        header_preamble = file.read(16)
        if len(header_preamble) == 16 and header_preamble[0:3] == b"\x8e\xad\xe8":
            index_count, data_size = struct.unpack(">II", header_preamble[8:16])
            # Read header index entries
            entries = []
            for _ in range(index_count):
                entry = file.read(16)
                if len(entry) < 16:
                    break
                tag, type_, offset, count = struct.unpack(">IIII", entry)
                entries.append((tag, type_, offset, count))
            # Read header data section
            header_data = file.read(data_size)
            # Extract relevant tags
            # Tag 1000 = Name, Tag 1001 = Version, Tag 1002 = Release
            # Tag 1004 = Summary, Tag 1022 = Arch
            for tag, type_, offset, count in entries:
                if type_ in (6, 9):  # STRING or I18NSTRING
                    end = header_data.find(b"\x00", offset)
                    if end > offset:
                        value = header_data[offset:end].decode("utf-8", errors="replace")
                        if tag == 1000:
                            metadata["package_name"] = value
                        elif tag == 1001:
                            metadata["version"] = value
                        elif tag == 1002:
                            metadata["release"] = value
                        elif tag == 1004:
                            metadata["description"] = value
                        elif tag == 1022:
                            metadata["architecture"] = value
except Exception as e:
    logger.debug(f"Failed to parse RPM header: {e}")
return metadata
def extract_tarball_metadata(file: BinaryIO, filename: str) -> Dict[str, Any]:
"""Extract metadata from a tarball (name and version from filename)"""
metadata = {"format": "tarball"}
# Try to extract name and version from filename
# Common patterns: package-1.0.0.tar.gz, package_1.0.0.tar.gz
basename = filename
for suffix in [".tar.gz", ".tgz", ".tar.bz2", ".tar.xz"]:
if basename.lower().endswith(suffix):
basename = basename[:-len(suffix)]
break
# Try to split name and version
patterns = [
r"^(.+)-(\d+\.\d+(?:\.\d+)?(?:[-._]\w+)?)$", # name-version
r"^(.+)_(\d+\.\d+(?:\.\d+)?(?:[-._]\w+)?)$", # name_version
]
for pattern in patterns:
match = re.match(pattern, basename)
if match:
metadata["package_name"] = match.group(1)
metadata["version"] = match.group(2)
break
return metadata
def extract_wheel_metadata(file: BinaryIO) -> Dict[str, Any]:
"""Extract metadata from a Python wheel (.whl) file"""
import zipfile
metadata = {"format": "wheel"}
try:
with zipfile.ZipFile(file, "r") as zf:
# Find METADATA file in .dist-info directory
for name in zf.namelist():
if name.endswith("/METADATA") and ".dist-info/" in name:
with zf.open(name) as f:
content = f.read().decode("utf-8", errors="replace")
# Parse email-style headers
for line in content.split("\n"):
if line.startswith("Name:"):
metadata["package_name"] = line[5:].strip()
elif line.startswith("Version:"):
metadata["version"] = line[8:].strip()
elif line.startswith("Summary:"):
metadata["description"] = line[8:].strip()
elif line.startswith("Author:"):
metadata["author"] = line[7:].strip()
elif line == "":
break # End of headers
break
except Exception as e:
logger.debug(f"Failed to parse wheel: {e}")
return metadata
def extract_jar_metadata(file: BinaryIO) -> Dict[str, Any]:
"""Extract metadata from a Java JAR file"""
import zipfile
metadata = {"format": "jar"}
try:
with zipfile.ZipFile(file, "r") as zf:
# Look for MANIFEST.MF
if "META-INF/MANIFEST.MF" in zf.namelist():
with zf.open("META-INF/MANIFEST.MF") as f:
content = f.read().decode("utf-8", errors="replace")
for line in content.split("\n"):
line = line.strip()
if line.startswith("Implementation-Title:"):
metadata["package_name"] = line[21:].strip()
elif line.startswith("Implementation-Version:"):
metadata["version"] = line[23:].strip()
elif line.startswith("Bundle-Name:"):
metadata["bundle_name"] = line[12:].strip()
elif line.startswith("Bundle-Version:"):
metadata["bundle_version"] = line[15:].strip()
# Also look for pom.properties in Maven JARs
for name in zf.namelist():
if name.endswith("/pom.properties"):
with zf.open(name) as f:
content = f.read().decode("utf-8", errors="replace")
for line in content.split("\n"):
if line.startswith("artifactId="):
metadata["artifact_id"] = line[11:].strip()
elif line.startswith("groupId="):
metadata["group_id"] = line[8:].strip()
elif line.startswith("version="):
if "version" not in metadata:
metadata["version"] = line[8:].strip()
break
except Exception as e:
logger.debug(f"Failed to parse JAR: {e}")
return metadata
def extract_zip_metadata(file: BinaryIO) -> Dict[str, Any]:
"""Extract basic metadata from a ZIP file"""
import zipfile
metadata = {"format": "zip"}
try:
with zipfile.ZipFile(file, "r") as zf:
metadata["file_count"] = len(zf.namelist())
# Calculate total uncompressed size
total_size = sum(info.file_size for info in zf.infolist())
metadata["uncompressed_size"] = total_size
except Exception as e:
logger.debug(f"Failed to parse ZIP: {e}")
return metadata
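
Usage sketch for the extractor above, hedged: the import path assumes the module is
importable as app.metadata per the backend/app layout, and the wheel path is a
hypothetical local file.

from app.metadata import extract_metadata   # assumption: importable as app.metadata

with open("dist/example_pkg-1.0.0-py3-none-any.whl", "rb") as f:   # placeholder path
    info = extract_metadata(f, "example_pkg-1.0.0-py3-none-any.whl")
# e.g. {"format": "wheel", "package_name": "example-pkg", "version": "1.0.0", ...}
print(info.get("package_name"), info.get("version"))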


@@ -64,6 +64,7 @@ class Artifact(Base):
created_by = Column(String(255), nullable=False)
ref_count = Column(Integer, default=1)
s3_key = Column(String(1024), nullable=False)
format_metadata = Column(JSON, default=dict) # Format-specific metadata (version, etc.)
tags = relationship("Tag", back_populates="artifact")
uploads = relationship("Upload", back_populates="artifact")


@@ -1,12 +1,14 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Request
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Request, Header, Response
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import or_
from typing import List, Optional
import re
import io
import hashlib
from .database import get_db
from .storage import get_storage, S3Storage
from .storage import get_storage, S3Storage, MULTIPART_CHUNK_SIZE
from .models import Project, Package, Artifact, Tag, Upload, Consumer
from .schemas import (
ProjectCreate, ProjectResponse,
@@ -16,7 +18,14 @@ from .schemas import (
UploadResponse,
ConsumerResponse,
HealthResponse,
ResumableUploadInitRequest,
ResumableUploadInitResponse,
ResumableUploadPartResponse,
ResumableUploadCompleteRequest,
ResumableUploadCompleteResponse,
ResumableUploadStatusResponse,
)
from .metadata import extract_metadata
router = APIRouter()
@@ -118,6 +127,7 @@ def upload_artifact(
tag: Optional[str] = Form(None),
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
content_length: Optional[int] = Header(None, alias="Content-Length"),
):
user_id = get_user_id(request)
@@ -130,13 +140,36 @@ def upload_artifact(
if not package:
raise HTTPException(status_code=404, detail="Package not found")
# Store file
sha256_hash, size, s3_key = storage.store(file.file)
# Extract format-specific metadata before storing. extract_metadata reads
# from the upload's spooled temp file and seeks back to the start when done,
# so the payload does not have to be copied into memory first.
file_metadata = {}
if file.filename:
    file_metadata = extract_metadata(file.file, file.filename, file.content_type)
    file.file.seek(0)
# Store file (uses multipart for large files)
sha256_hash, size, s3_key = storage.store(file.file, content_length)
# Check if this is a deduplicated upload
deduplicated = False
# Create or update artifact record
artifact = db.query(Artifact).filter(Artifact.id == sha256_hash).first()
if artifact:
artifact.ref_count += 1
deduplicated = True
# Merge metadata if new metadata was extracted
if file_metadata and artifact.format_metadata:
artifact.format_metadata = {**artifact.format_metadata, **file_metadata}
elif file_metadata:
artifact.format_metadata = file_metadata
else:
artifact = Artifact(
id=sha256_hash,
@@ -145,6 +178,7 @@ def upload_artifact(
original_name=file.filename,
created_by=user_id,
s3_key=s3_key,
format_metadata=file_metadata or {},
)
db.add(artifact)
@@ -181,17 +215,265 @@ def upload_artifact(
project=project_name,
package=package_name,
tag=tag,
format_metadata=artifact.format_metadata,
deduplicated=deduplicated,
)
# Download artifact
# Resumable upload endpoints
@router.post("/api/v1/project/{project_name}/{package_name}/upload/init", response_model=ResumableUploadInitResponse)
def init_resumable_upload(
project_name: str,
package_name: str,
init_request: ResumableUploadInitRequest,
request: Request,
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
):
"""
Initialize a resumable upload session.
Client must provide the SHA256 hash of the file in advance.
"""
user_id = get_user_id(request)
# Validate project and package
project = db.query(Project).filter(Project.name == project_name).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first()
if not package:
raise HTTPException(status_code=404, detail="Package not found")
# Check if artifact already exists (deduplication)
existing_artifact = db.query(Artifact).filter(Artifact.id == init_request.expected_hash).first()
if existing_artifact:
# File already exists - increment ref count and return immediately
existing_artifact.ref_count += 1
# Record the upload
upload = Upload(
artifact_id=init_request.expected_hash,
package_id=package.id,
original_name=init_request.filename,
uploaded_by=user_id,
source_ip=request.client.host if request.client else None,
)
db.add(upload)
# Create tag if provided
if init_request.tag:
existing_tag = db.query(Tag).filter(
Tag.package_id == package.id, Tag.name == init_request.tag
).first()
if existing_tag:
existing_tag.artifact_id = init_request.expected_hash
existing_tag.created_by = user_id
else:
new_tag = Tag(
package_id=package.id,
name=init_request.tag,
artifact_id=init_request.expected_hash,
created_by=user_id,
)
db.add(new_tag)
db.commit()
return ResumableUploadInitResponse(
upload_id=None,
already_exists=True,
artifact_id=init_request.expected_hash,
chunk_size=MULTIPART_CHUNK_SIZE,
)
# Initialize resumable upload
session = storage.initiate_resumable_upload(init_request.expected_hash)
return ResumableUploadInitResponse(
upload_id=session["upload_id"],
already_exists=False,
artifact_id=None,
chunk_size=MULTIPART_CHUNK_SIZE,
)
@router.put("/api/v1/project/{project_name}/{package_name}/upload/{upload_id}/part/{part_number}")
async def upload_part(
project_name: str,
package_name: str,
upload_id: str,
part_number: int,
request: Request,
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
):
"""
Upload a part of a resumable upload.
Part numbers start at 1.
"""
# Validate project and package exist
project = db.query(Project).filter(Project.name == project_name).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first()
if not package:
raise HTTPException(status_code=404, detail="Package not found")
if part_number < 1:
raise HTTPException(status_code=400, detail="Part number must be >= 1")
# Read the raw part bytes from the request body. The endpoint is declared
# async so the body can be awaited directly; creating a private event loop
# inside a sync handler is unreliable under FastAPI.
data = await request.body()
if not data:
raise HTTPException(status_code=400, detail="No data in request body")
try:
part_info = storage.upload_part(upload_id, part_number, data)
return ResumableUploadPartResponse(
part_number=part_info["PartNumber"],
etag=part_info["ETag"],
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/api/v1/project/{project_name}/{package_name}/upload/{upload_id}/complete")
def complete_resumable_upload(
project_name: str,
package_name: str,
upload_id: str,
complete_request: ResumableUploadCompleteRequest,
request: Request,
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
):
"""Complete a resumable upload"""
user_id = get_user_id(request)
# Validate project and package
project = db.query(Project).filter(Project.name == project_name).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first()
if not package:
raise HTTPException(status_code=404, detail="Package not found")
try:
sha256_hash, s3_key = storage.complete_resumable_upload(upload_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
# Get file size from S3
obj_info = storage.get_object_info(s3_key)
size = obj_info["size"] if obj_info else 0
# Create artifact record
artifact = Artifact(
id=sha256_hash,
size=size,
s3_key=s3_key,
created_by=user_id,
format_metadata={},
)
db.add(artifact)
# Record upload
upload = Upload(
artifact_id=sha256_hash,
package_id=package.id,
uploaded_by=user_id,
source_ip=request.client.host if request.client else None,
)
db.add(upload)
# Create tag if provided
if complete_request.tag:
existing_tag = db.query(Tag).filter(
Tag.package_id == package.id, Tag.name == complete_request.tag
).first()
if existing_tag:
existing_tag.artifact_id = sha256_hash
existing_tag.created_by = user_id
else:
new_tag = Tag(
package_id=package.id,
name=complete_request.tag,
artifact_id=sha256_hash,
created_by=user_id,
)
db.add(new_tag)
db.commit()
return ResumableUploadCompleteResponse(
artifact_id=sha256_hash,
size=size,
project=project_name,
package=package_name,
tag=complete_request.tag,
)
@router.delete("/api/v1/project/{project_name}/{package_name}/upload/{upload_id}")
def abort_resumable_upload(
project_name: str,
package_name: str,
upload_id: str,
storage: S3Storage = Depends(get_storage),
):
"""Abort a resumable upload"""
try:
storage.abort_resumable_upload(upload_id)
return {"status": "aborted"}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.get("/api/v1/project/{project_name}/{package_name}/upload/{upload_id}/status")
def get_upload_status(
project_name: str,
package_name: str,
upload_id: str,
storage: S3Storage = Depends(get_storage),
):
"""Get status of a resumable upload"""
try:
parts = storage.list_upload_parts(upload_id)
uploaded_parts = [p["PartNumber"] for p in parts]
total_bytes = sum(p.get("Size", 0) for p in parts)
return ResumableUploadStatusResponse(
upload_id=upload_id,
uploaded_parts=uploaded_parts,
total_uploaded_bytes=total_bytes,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
# Download artifact with range request support
@router.get("/api/v1/project/{project_name}/{package_name}/+/{ref}")
def download_artifact(
project_name: str,
package_name: str,
ref: str,
request: Request,
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
range: Optional[str] = Header(None),
):
# Get project and package
project = db.query(Project).filter(Project.name == project_name).first()
@@ -226,15 +508,90 @@ def download_artifact(
if not artifact:
raise HTTPException(status_code=404, detail="Artifact not found")
# Stream from S3
stream = storage.get_stream(artifact.s3_key)
filename = artifact.original_name or f"{artifact.id}"
# Handle range requests
if range:
stream, content_length, content_range = storage.get_stream(artifact.s3_key, range)
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
"Content-Length": str(content_length),
}
if content_range:
headers["Content-Range"] = content_range
return StreamingResponse(
stream,
status_code=206, # Partial Content
media_type=artifact.content_type or "application/octet-stream",
headers=headers,
)
# Full download
stream, content_length, _ = storage.get_stream(artifact.s3_key)
return StreamingResponse(
stream,
media_type=artifact.content_type or "application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
"Content-Length": str(content_length),
},
)
# HEAD request for download (to check file info without downloading)
@router.head("/api/v1/project/{project_name}/{package_name}/+/{ref}")
def head_artifact(
project_name: str,
package_name: str,
ref: str,
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
):
# Get project and package
project = db.query(Project).filter(Project.name == project_name).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
package = db.query(Package).filter(Package.project_id == project.id, Package.name == package_name).first()
if not package:
raise HTTPException(status_code=404, detail="Package not found")
# Resolve reference to artifact (same logic as download)
artifact = None
if ref.startswith("artifact:"):
artifact_id = ref[9:]
artifact = db.query(Artifact).filter(Artifact.id == artifact_id).first()
elif ref.startswith("tag:") or ref.startswith("version:"):
tag_name = ref.split(":", 1)[1]
tag = db.query(Tag).filter(Tag.package_id == package.id, Tag.name == tag_name).first()
if tag:
artifact = db.query(Artifact).filter(Artifact.id == tag.artifact_id).first()
else:
tag = db.query(Tag).filter(Tag.package_id == package.id, Tag.name == ref).first()
if tag:
artifact = db.query(Artifact).filter(Artifact.id == tag.artifact_id).first()
else:
artifact = db.query(Artifact).filter(Artifact.id == ref).first()
if not artifact:
raise HTTPException(status_code=404, detail="Artifact not found")
filename = artifact.original_name or f"{artifact.id}"
return Response(
content=b"",
media_type=artifact.content_type or "application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Accept-Ranges": "bytes",
"Content-Length": str(artifact.size),
"X-Artifact-Id": artifact.id,
},
)
@@ -244,10 +601,12 @@ def download_artifact_compat(
project_name: str,
package_name: str,
ref: str,
request: Request,
db: Session = Depends(get_db),
storage: S3Storage = Depends(get_storage),
range: Optional[str] = Header(None),
):
return download_artifact(project_name, package_name, ref, db, storage)
return download_artifact(project_name, package_name, ref, request, db, storage, range)
# Tag routes
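
A hedged sketch of the new download behavior: HEAD exposes Content-Length and
Accept-Ranges without transferring the body, and a ranged GET returns 206 Partial
Content with a Content-Range header. The base URL, project, package, and tag are
placeholders.

import requests

BASE_URL = "http://localhost:8000"                      # assumption: local dev server
url = f"{BASE_URL}/api/v1/project/demo/app/+/v1.2.3"    # assumption: "v1.2.3" resolves as a tag

head = requests.head(url)
total_size = int(head.headers["Content-Length"])

# Fetch only the first 1 MiB of the artifact
partial = requests.get(url, headers={"Range": "bytes=0-1048575"})
assert partial.status_code == 206
print(partial.headers.get("Content-Range"), "of", total_size, "bytes")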


@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional, List
from typing import Optional, List, Dict, Any
from pydantic import BaseModel
from uuid import UUID
@@ -51,6 +51,7 @@ class ArtifactResponse(BaseModel):
created_at: datetime
created_by: str
ref_count: int
format_metadata: Optional[Dict[str, Any]] = None
class Config:
from_attributes = True
@@ -81,6 +82,53 @@ class UploadResponse(BaseModel):
project: str
package: str
tag: Optional[str]
format_metadata: Optional[Dict[str, Any]] = None
deduplicated: bool = False
# Resumable upload schemas
class ResumableUploadInitRequest(BaseModel):
"""Request to initiate a resumable upload"""
expected_hash: str # SHA256 hash of the file (client must compute)
filename: str
content_type: Optional[str] = None
size: int
tag: Optional[str] = None
class ResumableUploadInitResponse(BaseModel):
"""Response from initiating a resumable upload"""
upload_id: Optional[str] # None if file already exists
already_exists: bool
artifact_id: Optional[str] = None # Set if already_exists is True
chunk_size: int # Recommended chunk size for parts
class ResumableUploadPartResponse(BaseModel):
"""Response from uploading a part"""
part_number: int
etag: str
class ResumableUploadCompleteRequest(BaseModel):
"""Request to complete a resumable upload"""
tag: Optional[str] = None
class ResumableUploadCompleteResponse(BaseModel):
"""Response from completing a resumable upload"""
artifact_id: str
size: int
project: str
package: str
tag: Optional[str]
class ResumableUploadStatusResponse(BaseModel):
"""Status of a resumable upload"""
upload_id: str
uploaded_parts: List[int]
total_uploaded_bytes: int
# Consumer schemas
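
For illustration, one init exchange in the deduplicated case (all values made up,
truncated hashes marked with "..."): the client sends a ResumableUploadInitRequest

{"expected_hash": "9f86d081...", "filename": "app-1.2.3.tar.gz", "size": 734003200, "tag": "v1.2.3"}

and, because the artifact already exists, the ResumableUploadInitResponse short-circuits the upload:

{"upload_id": null, "already_exists": true, "artifact_id": "9f86d081...", "chunk_size": 10485760}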


@@ -1,5 +1,6 @@
import hashlib
from typing import BinaryIO, Tuple
import logging
from typing import BinaryIO, Tuple, Optional, Dict, Any, Generator
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
@@ -7,6 +8,14 @@ from botocore.exceptions import ClientError
from .config import get_settings
settings = get_settings()
logger = logging.getLogger(__name__)
# Threshold for multipart upload (100MB)
MULTIPART_THRESHOLD = 100 * 1024 * 1024
# Chunk size for multipart upload (10MB)
MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
# Chunk size for streaming hash computation
HASH_CHUNK_SIZE = 8 * 1024 * 1024
class S3Storage:
@@ -22,12 +31,23 @@ class S3Storage:
config=config,
)
self.bucket = settings.s3_bucket
# Store active multipart uploads for resumable support
self._active_uploads: Dict[str, Dict[str, Any]] = {}
def store(self, file: BinaryIO) -> Tuple[str, int]:
def store(self, file: BinaryIO, content_length: Optional[int] = None) -> Tuple[str, int, str]:
"""
Store a file and return its SHA256 hash and size.
Store a file and return its SHA256 hash, size, and s3_key.
Content-addressable: if the file already exists, just return the hash.
Uses multipart upload for files larger than MULTIPART_THRESHOLD.
"""
# For small files or unknown size, use the simple approach
if content_length is None or content_length < MULTIPART_THRESHOLD:
return self._store_simple(file)
else:
return self._store_multipart(file, content_length)
def _store_simple(self, file: BinaryIO) -> Tuple[str, int, str]:
"""Store a small file using simple put_object"""
# Read file and compute hash
content = file.read()
sha256_hash = hashlib.sha256(content).hexdigest()
@@ -45,15 +65,300 @@ class S3Storage:
return sha256_hash, size, s3_key
def _store_multipart(self, file: BinaryIO, content_length: int) -> Tuple[str, int, str]:
"""Store a large file using S3 multipart upload with streaming hash computation"""
# First pass: compute hash by streaming through file
hasher = hashlib.sha256()
size = 0
# Read file in chunks to compute hash
while True:
chunk = file.read(HASH_CHUNK_SIZE)
if not chunk:
break
hasher.update(chunk)
size += len(chunk)
sha256_hash = hasher.hexdigest()
s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"
# Check if already exists (deduplication)
if self._exists(s3_key):
return sha256_hash, size, s3_key
# Seek back to start for upload
file.seek(0)
# Start multipart upload
mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
upload_id = mpu["UploadId"]
try:
parts = []
part_number = 1
while True:
chunk = file.read(MULTIPART_CHUNK_SIZE)
if not chunk:
break
response = self.client.upload_part(
Bucket=self.bucket,
Key=s3_key,
UploadId=upload_id,
PartNumber=part_number,
Body=chunk,
)
parts.append({
"PartNumber": part_number,
"ETag": response["ETag"],
})
part_number += 1
# Complete multipart upload
self.client.complete_multipart_upload(
Bucket=self.bucket,
Key=s3_key,
UploadId=upload_id,
MultipartUpload={"Parts": parts},
)
return sha256_hash, size, s3_key
except Exception as e:
# Abort multipart upload on failure
logger.error(f"Multipart upload failed: {e}")
self.client.abort_multipart_upload(
Bucket=self.bucket,
Key=s3_key,
UploadId=upload_id,
)
raise
def store_streaming(self, chunks: Generator[bytes, None, None]) -> Tuple[str, int, str]:
"""
Store a file from a stream of chunks.
First accumulates to compute hash, then uploads.
For truly large files, consider using initiate_resumable_upload instead.
"""
# Accumulate chunks and compute hash
hasher = hashlib.sha256()
all_chunks = []
size = 0
for chunk in chunks:
hasher.update(chunk)
all_chunks.append(chunk)
size += len(chunk)
sha256_hash = hasher.hexdigest()
s3_key = f"fruits/{sha256_hash[:2]}/{sha256_hash[2:4]}/{sha256_hash}"
# Check if already exists
if self._exists(s3_key):
return sha256_hash, size, s3_key
# Upload based on size
if size < MULTIPART_THRESHOLD:
content = b"".join(all_chunks)
self.client.put_object(Bucket=self.bucket, Key=s3_key, Body=content)
else:
# Use multipart for large files
mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
upload_id = mpu["UploadId"]
try:
parts = []
part_number = 1
buffer = b""
for chunk in all_chunks:
buffer += chunk
while len(buffer) >= MULTIPART_CHUNK_SIZE:
part_data = buffer[:MULTIPART_CHUNK_SIZE]
buffer = buffer[MULTIPART_CHUNK_SIZE:]
response = self.client.upload_part(
Bucket=self.bucket,
Key=s3_key,
UploadId=upload_id,
PartNumber=part_number,
Body=part_data,
)
parts.append({
"PartNumber": part_number,
"ETag": response["ETag"],
})
part_number += 1
# Upload remaining buffer
if buffer:
response = self.client.upload_part(
Bucket=self.bucket,
Key=s3_key,
UploadId=upload_id,
PartNumber=part_number,
Body=buffer,
)
parts.append({
"PartNumber": part_number,
"ETag": response["ETag"],
})
self.client.complete_multipart_upload(
Bucket=self.bucket,
Key=s3_key,
UploadId=upload_id,
MultipartUpload={"Parts": parts},
)
except Exception as e:
logger.error(f"Streaming multipart upload failed: {e}")
self.client.abort_multipart_upload(
Bucket=self.bucket,
Key=s3_key,
UploadId=upload_id,
)
raise
return sha256_hash, size, s3_key
def initiate_resumable_upload(self, expected_hash: str) -> Dict[str, Any]:
"""
Initiate a resumable upload session.
Returns upload session info including upload_id.
"""
s3_key = f"fruits/{expected_hash[:2]}/{expected_hash[2:4]}/{expected_hash}"
# Check if already exists
if self._exists(s3_key):
return {
"upload_id": None,
"s3_key": s3_key,
"already_exists": True,
"parts": [],
}
mpu = self.client.create_multipart_upload(Bucket=self.bucket, Key=s3_key)
upload_id = mpu["UploadId"]
session = {
"upload_id": upload_id,
"s3_key": s3_key,
"already_exists": False,
"parts": [],
"expected_hash": expected_hash,
}
self._active_uploads[upload_id] = session
return session
def upload_part(self, upload_id: str, part_number: int, data: bytes) -> Dict[str, Any]:
"""
Upload a part for a resumable upload.
Returns part info including ETag.
"""
session = self._active_uploads.get(upload_id)
if not session:
raise ValueError(f"Unknown upload session: {upload_id}")
response = self.client.upload_part(
Bucket=self.bucket,
Key=session["s3_key"],
UploadId=upload_id,
PartNumber=part_number,
Body=data,
)
part_info = {
"PartNumber": part_number,
"ETag": response["ETag"],
}
session["parts"].append(part_info)
return part_info
def complete_resumable_upload(self, upload_id: str) -> Tuple[str, str]:
"""
Complete a resumable upload.
Returns (sha256_hash, s3_key).
"""
session = self._active_uploads.get(upload_id)
if not session:
raise ValueError(f"Unknown upload session: {upload_id}")
# Sort parts by part number
sorted_parts = sorted(session["parts"], key=lambda x: x["PartNumber"])
self.client.complete_multipart_upload(
Bucket=self.bucket,
Key=session["s3_key"],
UploadId=upload_id,
MultipartUpload={"Parts": sorted_parts},
)
# Clean up session
del self._active_uploads[upload_id]
return session["expected_hash"], session["s3_key"]
def abort_resumable_upload(self, upload_id: str):
"""Abort a resumable upload"""
session = self._active_uploads.get(upload_id)
if session:
self.client.abort_multipart_upload(
Bucket=self.bucket,
Key=session["s3_key"],
UploadId=upload_id,
)
del self._active_uploads[upload_id]
def list_upload_parts(self, upload_id: str) -> list:
"""List uploaded parts for a resumable upload (for resume support)"""
session = self._active_uploads.get(upload_id)
if not session:
raise ValueError(f"Unknown upload session: {upload_id}")
response = self.client.list_parts(
Bucket=self.bucket,
Key=session["s3_key"],
UploadId=upload_id,
)
return response.get("Parts", [])
def get(self, s3_key: str) -> bytes:
"""Retrieve a file by its S3 key"""
response = self.client.get_object(Bucket=self.bucket, Key=s3_key)
return response["Body"].read()
def get_stream(self, s3_key: str):
"""Get a streaming response for a file"""
response = self.client.get_object(Bucket=self.bucket, Key=s3_key)
return response["Body"]
def get_stream(self, s3_key: str, range_header: Optional[str] = None):
"""
Get a streaming response for a file.
Supports range requests for partial downloads.
Returns (stream, content_length, content_range)
"""
kwargs = {"Bucket": self.bucket, "Key": s3_key}
if range_header:
kwargs["Range"] = range_header
response = self.client.get_object(**kwargs)
content_length = response.get("ContentLength", 0)
content_range = response.get("ContentRange")
return response["Body"], content_length, content_range
def get_object_info(self, s3_key: str) -> Optional[Dict[str, Any]]:
"""Get object metadata without downloading content"""
try:
response = self.client.head_object(Bucket=self.bucket, Key=s3_key)
return {
"size": response.get("ContentLength", 0),
"content_type": response.get("ContentType"),
"last_modified": response.get("LastModified"),
"etag": response.get("ETag"),
}
except ClientError:
return None
def _exists(self, s3_key: str) -> bool:
"""Check if an object exists"""