Security fix and test reorganization

- Add sanitize_filename() to prevent Content-Disposition header injection
- Remove unused imports from models.py and artifact_cleanup.py
- Reorganize tests into unit/ and integration/ structure
- Add factories.py for test data generation
- Split old test files into focused test modules (143 tests)
This commit is contained in:
Mondo Diaz
2026-01-06 15:04:51 -06:00
parent a293432d2e
commit b81c69118f
20 changed files with 3007 additions and 2626 deletions

View File

@@ -1,8 +1,16 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import (
Column, String, Text, Boolean, Integer, BigInteger,
DateTime, ForeignKey, CheckConstraint, Index, JSON
Column,
String,
Text,
Boolean,
Integer,
BigInteger,
DateTime,
ForeignKey,
CheckConstraint,
Index,
JSON,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship, declarative_base
@@ -19,11 +27,17 @@ class Project(Base):
description = Column(Text)
is_public = Column(Boolean, default=True)
created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
updated_at = Column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow)
updated_at = Column(
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
)
created_by = Column(String(255), nullable=False)
packages = relationship("Package", back_populates="project", cascade="all, delete-orphan")
permissions = relationship("AccessPermission", back_populates="project", cascade="all, delete-orphan")
packages = relationship(
"Package", back_populates="project", cascade="all, delete-orphan"
)
permissions = relationship(
"AccessPermission", back_populates="project", cascade="all, delete-orphan"
)
__table_args__ = (
Index("idx_projects_name", "name"),
@@ -35,32 +49,44 @@ class Package(Base):
__tablename__ = "packages"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
project_id = Column(
UUID(as_uuid=True),
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
)
name = Column(String(255), nullable=False)
description = Column(Text)
format = Column(String(50), default="generic", nullable=False)
platform = Column(String(50), default="any", nullable=False)
created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
updated_at = Column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow)
updated_at = Column(
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
)
project = relationship("Project", back_populates="packages")
tags = relationship("Tag", back_populates="package", cascade="all, delete-orphan")
uploads = relationship("Upload", back_populates="package", cascade="all, delete-orphan")
consumers = relationship("Consumer", back_populates="package", cascade="all, delete-orphan")
uploads = relationship(
"Upload", back_populates="package", cascade="all, delete-orphan"
)
consumers = relationship(
"Consumer", back_populates="package", cascade="all, delete-orphan"
)
__table_args__ = (
Index("idx_packages_project_id", "project_id"),
Index("idx_packages_name", "name"),
Index("idx_packages_format", "format"),
Index("idx_packages_platform", "platform"),
Index("idx_packages_project_name", "project_id", "name", unique=True), # Composite unique index
Index(
"idx_packages_project_name", "project_id", "name", unique=True
), # Composite unique index
CheckConstraint(
"format IN ('generic', 'npm', 'pypi', 'docker', 'deb', 'rpm', 'maven', 'nuget', 'helm')",
name="check_package_format"
name="check_package_format",
),
CheckConstraint(
"platform IN ('any', 'linux', 'darwin', 'windows', 'linux-amd64', 'linux-arm64', 'darwin-amd64', 'darwin-arm64', 'windows-amd64')",
name="check_package_platform"
name="check_package_platform",
),
{"extend_existing": True},
)
@@ -76,7 +102,9 @@ class Artifact(Base):
checksum_md5 = Column(String(32)) # MD5 hash for additional verification
checksum_sha1 = Column(String(40)) # SHA1 hash for compatibility
s3_etag = Column(String(64)) # S3 ETag for verification
artifact_metadata = Column("metadata", JSON, default=dict) # Format-specific metadata (column name is 'metadata')
artifact_metadata = Column(
"metadata", JSON, default=dict
) # Format-specific metadata (column name is 'metadata')
created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
created_by = Column(String(255), nullable=False)
ref_count = Column(Integer, default=1)
@@ -113,22 +141,34 @@ class Tag(Base):
__tablename__ = "tags"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
package_id = Column(UUID(as_uuid=True), ForeignKey("packages.id", ondelete="CASCADE"), nullable=False)
package_id = Column(
UUID(as_uuid=True),
ForeignKey("packages.id", ondelete="CASCADE"),
nullable=False,
)
name = Column(String(255), nullable=False)
artifact_id = Column(String(64), ForeignKey("artifacts.id"), nullable=False)
created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
updated_at = Column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow)
updated_at = Column(
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
)
created_by = Column(String(255), nullable=False)
package = relationship("Package", back_populates="tags")
artifact = relationship("Artifact", back_populates="tags")
history = relationship("TagHistory", back_populates="tag", cascade="all, delete-orphan")
history = relationship(
"TagHistory", back_populates="tag", cascade="all, delete-orphan"
)
__table_args__ = (
Index("idx_tags_package_id", "package_id"),
Index("idx_tags_artifact_id", "artifact_id"),
Index("idx_tags_package_name", "package_id", "name", unique=True), # Composite unique index
Index("idx_tags_package_created_at", "package_id", "created_at"), # For recent tags queries
Index(
"idx_tags_package_name", "package_id", "name", unique=True
), # Composite unique index
Index(
"idx_tags_package_created_at", "package_id", "created_at"
), # For recent tags queries
)
@@ -136,7 +176,9 @@ class TagHistory(Base):
__tablename__ = "tag_history"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tag_id = Column(UUID(as_uuid=True), ForeignKey("tags.id", ondelete="CASCADE"), nullable=False)
tag_id = Column(
UUID(as_uuid=True), ForeignKey("tags.id", ondelete="CASCADE"), nullable=False
)
old_artifact_id = Column(String(64), ForeignKey("artifacts.id"))
new_artifact_id = Column(String(64), ForeignKey("artifacts.id"), nullable=False)
change_type = Column(String(20), nullable=False, default="update")
@@ -148,7 +190,9 @@ class TagHistory(Base):
__table_args__ = (
Index("idx_tag_history_tag_id", "tag_id"),
Index("idx_tag_history_changed_at", "changed_at"),
CheckConstraint("change_type IN ('create', 'update', 'delete')", name="check_change_type"),
CheckConstraint(
"change_type IN ('create', 'update', 'delete')", name="check_change_type"
),
)
@@ -184,7 +228,11 @@ class Consumer(Base):
__tablename__ = "consumers"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
package_id = Column(UUID(as_uuid=True), ForeignKey("packages.id", ondelete="CASCADE"), nullable=False)
package_id = Column(
UUID(as_uuid=True),
ForeignKey("packages.id", ondelete="CASCADE"),
nullable=False,
)
project_url = Column(String(2048), nullable=False)
last_access = Column(DateTime(timezone=True), default=datetime.utcnow)
created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
@@ -201,7 +249,11 @@ class AccessPermission(Base):
__tablename__ = "access_permissions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
project_id = Column(
UUID(as_uuid=True),
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
)
user_id = Column(String(255), nullable=False)
level = Column(String(20), nullable=False)
created_at = Column(DateTime(timezone=True), default=datetime.utcnow)

View File

@@ -95,6 +95,18 @@ from .config import get_settings
router = APIRouter()
def sanitize_filename(filename: str) -> str:
    """Sanitize a filename for use in a quoted Content-Disposition parameter.

    Strips every character that could terminate or corrupt a quoted
    ``filename="..."`` value or inject additional headers:

    - Double quotes (") - would close the quoted string early
    - Backslashes - the quoted-string escape character; a trailing one would
      escape a closing quote placed after the value
    - ASCII control characters (0x00-0x1F and 0x7F), including carriage
      return and line feed - could split the header or inject new ones

    Args:
        filename: Untrusted, user-supplied file name.

    Returns:
        The filename with all of the characters above removed. The result
        may be empty if the input consisted solely of stripped characters.
    """
    # Function-scope import, as in the original, so the helper does not
    # depend on the module's top-level import block.
    import re

    # One character class covers C0 controls, DEL, the double quote and the
    # backslash; everything else is passed through untouched.
    return re.sub(r'[\x00-\x1f\x7f"\\]', "", filename)
def get_user_id(request: Request) -> str:
"""Extract user ID from request (simplified for now)"""
api_key = request.headers.get("X-Orchard-API-Key")
@@ -1553,7 +1565,7 @@ def download_artifact(
if not artifact:
raise HTTPException(status_code=404, detail="Artifact not found")
filename = artifact.original_name or f"{artifact.id}"
filename = sanitize_filename(artifact.original_name or f"{artifact.id}")
# Determine download mode (query param overrides server default)
download_mode = mode or settings.download_mode
@@ -1666,7 +1678,7 @@ def get_artifact_url(
if not artifact:
raise HTTPException(status_code=404, detail="Artifact not found")
filename = artifact.original_name or f"{artifact.id}"
filename = sanitize_filename(artifact.original_name or f"{artifact.id}")
url_expiry = expiry or settings.presigned_url_expiry
presigned_url = storage.generate_presigned_url(
@@ -1717,7 +1729,7 @@ def head_artifact(
if not artifact:
raise HTTPException(status_code=404, detail="Artifact not found")
filename = artifact.original_name or f"{artifact.id}"
filename = sanitize_filename(artifact.original_name or f"{artifact.id}")
return Response(
content=b"",

View File

@@ -6,7 +6,7 @@ from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
import logging
from ..models import Artifact, Tag, Upload, Package
from ..models import Artifact, Tag
from ..repositories.artifact import ArtifactRepository
from ..repositories.tag import TagRepository
from ..storage import S3Storage
@@ -40,10 +40,14 @@ class ArtifactCleanupService:
artifact = self.artifact_repo.get_by_sha256(artifact_id)
if artifact:
artifact = self.artifact_repo.decrement_ref_count(artifact)
logger.info(f"Decremented ref_count for artifact {artifact_id}: now {artifact.ref_count}")
logger.info(
f"Decremented ref_count for artifact {artifact_id}: now {artifact.ref_count}"
)
return artifact
def on_tag_updated(self, old_artifact_id: str, new_artifact_id: str) -> Tuple[Optional[Artifact], Optional[Artifact]]:
def on_tag_updated(
self, old_artifact_id: str, new_artifact_id: str
) -> Tuple[Optional[Artifact], Optional[Artifact]]:
"""
Called when a tag is updated to point to a different artifact.
Decrements ref_count for old artifact, increments for new (if different).
@@ -58,13 +62,17 @@ class ArtifactCleanupService:
old_artifact = self.artifact_repo.get_by_sha256(old_artifact_id)
if old_artifact:
old_artifact = self.artifact_repo.decrement_ref_count(old_artifact)
logger.info(f"Decremented ref_count for old artifact {old_artifact_id}: now {old_artifact.ref_count}")
logger.info(
f"Decremented ref_count for old artifact {old_artifact_id}: now {old_artifact.ref_count}"
)
# Increment new artifact ref_count
new_artifact = self.artifact_repo.get_by_sha256(new_artifact_id)
if new_artifact:
new_artifact = self.artifact_repo.increment_ref_count(new_artifact)
logger.info(f"Incremented ref_count for new artifact {new_artifact_id}: now {new_artifact.ref_count}")
logger.info(
f"Incremented ref_count for new artifact {new_artifact_id}: now {new_artifact.ref_count}"
)
return old_artifact, new_artifact
@@ -84,11 +92,15 @@ class ArtifactCleanupService:
if artifact:
self.artifact_repo.decrement_ref_count(artifact)
affected_artifacts.append(tag.artifact_id)
logger.info(f"Decremented ref_count for artifact {tag.artifact_id} (package delete)")
logger.info(
f"Decremented ref_count for artifact {tag.artifact_id} (package delete)"
)
return affected_artifacts
def cleanup_orphaned_artifacts(self, batch_size: int = 100, dry_run: bool = False) -> List[str]:
def cleanup_orphaned_artifacts(
self, batch_size: int = 100, dry_run: bool = False
) -> List[str]:
"""
Find and delete artifacts with ref_count = 0.
@@ -116,7 +128,9 @@ class ArtifactCleanupService:
# Then delete from database
self.artifact_repo.delete(artifact)
deleted_ids.append(artifact.id)
logger.info(f"Deleted orphaned artifact from database: {artifact.id}")
logger.info(
f"Deleted orphaned artifact from database: {artifact.id}"
)
except Exception as e:
logger.error(f"Failed to delete artifact {artifact.id}: {e}")
@@ -128,10 +142,12 @@ class ArtifactCleanupService:
def get_orphaned_count(self) -> int:
"""Get count of artifacts with ref_count = 0."""
from sqlalchemy import func
return (
self.db.query(func.count(Artifact.id))
.filter(Artifact.ref_count == 0)
.scalar() or 0
.scalar()
or 0
)
def verify_ref_counts(self, fix: bool = False) -> List[dict]:
@@ -173,7 +189,9 @@ class ArtifactCleanupService:
if fix:
artifact.ref_count = max(actual_count, 1)
logger.warning(f"Fixed ref_count for artifact {artifact.id}: {mismatch['stored_ref_count']} -> {artifact.ref_count}")
logger.warning(
f"Fixed ref_count for artifact {artifact.id}: {mismatch['stored_ref_count']} -> {artifact.ref_count}"
)
if fix and mismatches:
self.db.commit()