Add upstream caching infrastructure and refactor CI pipeline
backend/app/cache.py (new file, 316 lines)
@@ -0,0 +1,316 @@
"""
Cache service for upstream artifact caching.

Provides URL parsing, system project management, and caching logic
for the upstream caching feature.
"""

import logging
import re
from dataclasses import dataclass
from typing import Optional
from urllib.parse import urlparse, unquote

logger = logging.getLogger(__name__)


# System project names for each source type
SYSTEM_PROJECT_NAMES = {
    "npm": "_npm",
    "pypi": "_pypi",
    "maven": "_maven",
    "docker": "_docker",
    "helm": "_helm",
    "nuget": "_nuget",
    "deb": "_deb",
    "rpm": "_rpm",
    "generic": "_generic",
}

# System project descriptions
SYSTEM_PROJECT_DESCRIPTIONS = {
    "npm": "System cache for npm packages",
    "pypi": "System cache for PyPI packages",
    "maven": "System cache for Maven packages",
    "docker": "System cache for Docker images",
    "helm": "System cache for Helm charts",
    "nuget": "System cache for NuGet packages",
    "deb": "System cache for Debian packages",
    "rpm": "System cache for RPM packages",
    "generic": "System cache for generic artifacts",
}


@dataclass
class ParsedUrl:
    """Parsed URL information for caching."""

    package_name: str
    version: Optional[str] = None
    filename: Optional[str] = None


def parse_npm_url(url: str) -> Optional[ParsedUrl]:
    """
    Parse npm registry URL to extract package name and version.

    Formats:
    - https://registry.npmjs.org/{package}/-/{package}-{version}.tgz
    - https://registry.npmjs.org/@{scope}/{package}/-/{package}-{version}.tgz

    Examples:
    - https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz
    - https://registry.npmjs.org/@types/node/-/node-18.0.0.tgz
    """
    parsed = urlparse(url)
    path = unquote(parsed.path)

    # Pattern for scoped packages: /@scope/package/-/package-version.tgz
    scoped_pattern = r"^/@([^/]+)/([^/]+)/-/\2-(.+)\.tgz$"
    match = re.match(scoped_pattern, path)
    if match:
        scope, name, version = match.groups()
        return ParsedUrl(
            package_name=f"@{scope}/{name}",
            version=version,
            filename=f"{name}-{version}.tgz",
        )

    # Pattern for unscoped packages: /package/-/package-version.tgz
    unscoped_pattern = r"^/([^/@]+)/-/\1-(.+)\.tgz$"
    match = re.match(unscoped_pattern, path)
    if match:
        name, version = match.groups()
        return ParsedUrl(
            package_name=name,
            version=version,
            filename=f"{name}-{version}.tgz",
        )

    return None


def parse_pypi_url(url: str) -> Optional[ParsedUrl]:
    """
    Parse PyPI URL to extract package name and version.

    Formats:
    - https://files.pythonhosted.org/packages/.../package-version.tar.gz
    - https://files.pythonhosted.org/packages/.../package-version-py3-none-any.whl
    - https://pypi.org/packages/.../package-version.tar.gz

    Examples:
    - https://files.pythonhosted.org/packages/ab/cd/requests-2.28.0.tar.gz
    - https://files.pythonhosted.org/packages/ab/cd/requests-2.28.0-py3-none-any.whl
    """
    parsed = urlparse(url)
    path = unquote(parsed.path)

    # Get the filename from the path
    filename = path.split("/")[-1]
    if not filename:
        return None

    # Handle wheel files: package-version-py3-none-any.whl
    wheel_pattern = r"^([a-zA-Z0-9_-]+)-(\d+[^-]*)-.*\.whl$"
    match = re.match(wheel_pattern, filename)
    if match:
        name, version = match.groups()
        # Normalize the package name (wheel filenames use underscores; normalized PyPI names use hyphens)
        name = name.replace("_", "-").lower()
        return ParsedUrl(
            package_name=name,
            version=version,
            filename=filename,
        )

    # Handle source distributions: package-version.tar.gz or package-version.zip
    sdist_pattern = r"^([a-zA-Z0-9_-]+)-(\d+(?:\.\d+)*(?:[a-zA-Z0-9_.+-]*)?)(?:\.tar\.gz|\.zip|\.tar\.bz2)$"
    match = re.match(sdist_pattern, filename)
    if match:
        name, version = match.groups()
        name = name.replace("_", "-").lower()
        return ParsedUrl(
            package_name=name,
            version=version,
            filename=filename,
        )

    return None


def parse_maven_url(url: str) -> Optional[ParsedUrl]:
    """
    Parse Maven repository URL to extract artifact info.

    Format:
    - https://repo1.maven.org/maven2/{group}/{artifact}/{version}/{artifact}-{version}.jar

    Examples:
    - https://repo1.maven.org/maven2/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar
    - https://repo1.maven.org/maven2/com/google/guava/guava/31.1-jre/guava-31.1-jre.jar
    """
    parsed = urlparse(url)
    path = unquote(parsed.path)

    # Find /maven2/ or similar repository path
    maven2_idx = path.find("/maven2/")
    if maven2_idx >= 0:
        path = path[maven2_idx + 8:]  # Remove /maven2/
    elif path.startswith("/"):
        path = path[1:]

    parts = path.split("/")
    if len(parts) < 4:
        return None

    # Last part is filename, before that is version, before that is artifact
    filename = parts[-1]
    version = parts[-2]
    artifact = parts[-3]
    group = ".".join(parts[:-3])

    # Verify filename matches expected pattern
    if not filename.startswith(f"{artifact}-{version}"):
        return None

    return ParsedUrl(
        package_name=f"{group}:{artifact}",
        version=version,
        filename=filename,
    )


def parse_docker_url(url: str) -> Optional[ParsedUrl]:
    """
    Parse Docker registry URL to extract image info.

    Note: Docker registries are more complex (manifests, blobs, etc.)
    This handles basic blob/manifest URLs.

    Examples:
    - https://registry-1.docker.io/v2/library/nginx/blobs/sha256:abc123
    - https://registry-1.docker.io/v2/myuser/myimage/manifests/latest
    """
    parsed = urlparse(url)
    path = unquote(parsed.path)

    # Pattern: /v2/{namespace}/{image}/blobs/{digest} or /manifests/{tag}
    pattern = r"^/v2/([^/]+(?:/[^/]+)?)/([^/]+)/(blobs|manifests)/(.+)$"
    match = re.match(pattern, path)
    if match:
        namespace, image, artifact_type, reference = match.groups()
        if namespace == "library":
            package_name = image
        else:
            package_name = f"{namespace}/{image}"

        # For manifests, the reference is the tag
        version = reference if artifact_type == "manifests" else None

        return ParsedUrl(
            package_name=package_name,
            version=version,
            filename=f"{image}-{reference}" if version else reference,
        )

    return None


def parse_generic_url(url: str) -> ParsedUrl:
    """
    Parse a generic URL to extract filename.

    Attempts to extract meaningful package name and version from filename.

    Examples:
    - https://example.com/downloads/myapp-1.2.3.tar.gz
    - https://github.com/user/repo/releases/download/v1.0/release.zip
    """
    parsed = urlparse(url)
    path = unquote(parsed.path)
    filename = path.split("/")[-1] or "artifact"

    # List of known compound and simple extensions
    known_extensions = [
        ".tar.gz", ".tar.bz2", ".tar.xz",
        ".zip", ".tgz", ".gz", ".jar", ".war", ".deb", ".rpm"
    ]

    # Strip extension from filename first
    base_name = filename
    matched_ext = None
    for ext in known_extensions:
        if filename.endswith(ext):
            base_name = filename[:-len(ext)]
            matched_ext = ext
            break

    if matched_ext is None:
        # Unknown extension, return filename as package name
        return ParsedUrl(
            package_name=filename,
            version=None,
            filename=filename,
        )

    # Try to extract version from base_name
    # Pattern: name-version or name_version
    # Version starts with digit(s) and can include dots, dashes, and alphanumeric suffixes
    version_pattern = r"^(.+?)[-_](v?\d+(?:\.\d+)*(?:[-_][a-zA-Z0-9]+)?)$"
    match = re.match(version_pattern, base_name)
    if match:
        name, version = match.groups()
        return ParsedUrl(
            package_name=name,
            version=version,
            filename=filename,
        )

    # No version found, use base_name as package name
    return ParsedUrl(
        package_name=base_name,
        version=None,
        filename=filename,
    )


def parse_url(url: str, source_type: str) -> ParsedUrl:
    """
    Parse URL to extract package name and version based on source type.

    Args:
        url: The URL to parse.
        source_type: The source type (npm, pypi, maven, docker, etc.)

    Returns:
        ParsedUrl with extracted information.
    """
    parsed = None

    if source_type == "npm":
        parsed = parse_npm_url(url)
    elif source_type == "pypi":
        parsed = parse_pypi_url(url)
    elif source_type == "maven":
        parsed = parse_maven_url(url)
    elif source_type == "docker":
        parsed = parse_docker_url(url)

    # Fall back to generic parsing if type-specific parsing fails
    if parsed is None:
        parsed = parse_generic_url(url)

    return parsed


def get_system_project_name(source_type: str) -> str:
    """Get the system project name for a source type."""
    return SYSTEM_PROJECT_NAMES.get(source_type, "_generic")


def get_system_project_description(source_type: str) -> str:
    """Get the system project description for a source type."""
    return SYSTEM_PROJECT_DESCRIPTIONS.get(
        source_type, "System cache for artifacts"
    )
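A quick sanity check of the parsers above (a sketch; the app.cache import path is assumed from the file location):

from app.cache import parse_url, get_system_project_name

npm = parse_url("https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", "npm")
assert npm.package_name == "lodash" and npm.version == "4.17.21"

scoped = parse_url("https://registry.npmjs.org/@types/node/-/node-18.0.0.tgz", "npm")
assert scoped.package_name == "@types/node"

mvn = parse_url(
    "https://repo1.maven.org/maven2/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar",
    "maven",
)
assert mvn.package_name == "org.apache.commons:commons-lang3"

# Unparseable or unknown-type URLs fall back to the generic filename parser
gen = parse_url("https://example.com/downloads/myapp-1.2.3.tar.gz", "generic")
assert gen.package_name == "myapp" and gen.version == "1.2.3"
assert get_system_project_name("npm") == "_npm"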
@@ -1,5 +1,8 @@
from pydantic_settings import BaseSettings
from functools import lru_cache
from typing import Optional
import os
import re


class Settings(BaseSettings):
@@ -56,6 +59,12 @@ class Settings(BaseSettings):
    # Initial admin user settings
    admin_password: str = ""  # Initial admin password (if empty, uses 'changeme123')

    # Cache settings
    cache_encryption_key: str = ""  # Fernet key for encrypting upstream credentials (auto-generated if empty)
    # Global cache settings overrides (None = use DB value, True/False = override DB)
    cache_allow_public_internet: Optional[bool] = None  # Override allow_public_internet (air-gap mode)
    cache_auto_create_system_projects: Optional[bool] = None  # Override auto_create_system_projects

    # JWT Authentication settings (optional, for external identity providers)
    jwt_enabled: bool = False  # Enable JWT token validation
    jwt_secret: str = ""  # Secret key for HS256, or leave empty for RS256 with JWKS
@@ -88,3 +97,113 @@ class Settings(BaseSettings):
@lru_cache()
def get_settings() -> Settings:
    return Settings()


class EnvUpstreamSource:
    """Represents an upstream source defined via environment variables."""

    def __init__(
        self,
        name: str,
        url: str,
        source_type: str = "generic",
        enabled: bool = True,
        is_public: bool = True,
        auth_type: str = "none",
        username: Optional[str] = None,
        password: Optional[str] = None,
        priority: int = 100,
    ):
        self.name = name
        self.url = url
        self.source_type = source_type
        self.enabled = enabled
        self.is_public = is_public
        self.auth_type = auth_type
        self.username = username
        self.password = password
        self.priority = priority
        self.source = "env"  # Mark as env-defined


def parse_upstream_sources_from_env() -> list[EnvUpstreamSource]:
    """
    Parse upstream sources from environment variables.

    Uses double underscore (__) as separator to allow source names with single underscores.
    Pattern: ORCHARD_UPSTREAM__{NAME}__FIELD

    Example:
        ORCHARD_UPSTREAM__NPM_PRIVATE__URL=https://npm.corp.com
        ORCHARD_UPSTREAM__NPM_PRIVATE__TYPE=npm
        ORCHARD_UPSTREAM__NPM_PRIVATE__ENABLED=true
        ORCHARD_UPSTREAM__NPM_PRIVATE__AUTH_TYPE=basic
        ORCHARD_UPSTREAM__NPM_PRIVATE__USERNAME=reader
        ORCHARD_UPSTREAM__NPM_PRIVATE__PASSWORD=secret

    Returns:
        List of EnvUpstreamSource objects parsed from environment variables.
    """
    # Pattern: ORCHARD_UPSTREAM__{NAME}__{FIELD}
    pattern = re.compile(r"^ORCHARD_UPSTREAM__([A-Z0-9_]+)__([A-Z_]+)$", re.IGNORECASE)

    # Collect all env vars matching the pattern, grouped by source name
    sources_data: dict[str, dict[str, str]] = {}

    for key, value in os.environ.items():
        match = pattern.match(key)
        if match:
            source_name = match.group(1).lower()  # Normalize to lowercase
            field = match.group(2).upper()
            if source_name not in sources_data:
                sources_data[source_name] = {}
            sources_data[source_name][field] = value

    # Build source objects from collected data
    sources: list[EnvUpstreamSource] = []

    for name, data in sources_data.items():
        # URL is required
        url = data.get("URL")
        if not url:
            continue  # Skip sources without URL

        # Parse boolean fields
        def parse_bool(val: Optional[str], default: bool) -> bool:
            if val is None:
                return default
            return val.lower() in ("true", "1", "yes", "on")

        # Parse integer fields
        def parse_int(val: Optional[str], default: int) -> int:
            if val is None:
                return default
            try:
                return int(val)
            except ValueError:
                return default

        source = EnvUpstreamSource(
            name=name.replace("_", "-"),  # Convert underscores to hyphens for readability
            url=url,
            source_type=data.get("TYPE", "generic").lower(),
            enabled=parse_bool(data.get("ENABLED"), True),
            is_public=parse_bool(data.get("IS_PUBLIC"), True),
            auth_type=data.get("AUTH_TYPE", "none").lower(),
            username=data.get("USERNAME"),
            password=data.get("PASSWORD"),
            priority=parse_int(data.get("PRIORITY"), 100),
        )
        sources.append(source)

    return sources


@lru_cache()
def get_env_upstream_sources() -> tuple[EnvUpstreamSource, ...]:
    """
    Get cached list of upstream sources from environment variables.

    Returns a tuple for hashability (required by lru_cache).
    """
    return tuple(parse_upstream_sources_from_env())
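How the environment parsing behaves, sketched (values illustrative; the app.config import path is an assumption):

import os
from app.config import parse_upstream_sources_from_env

os.environ["ORCHARD_UPSTREAM__NPM_PRIVATE__URL"] = "https://npm.corp.example"
os.environ["ORCHARD_UPSTREAM__NPM_PRIVATE__TYPE"] = "npm"
os.environ["ORCHARD_UPSTREAM__NPM_PRIVATE__PRIORITY"] = "10"

sources = parse_upstream_sources_from_env()
src = next(s for s in sources if s.name == "npm-private")  # NAME is lowercased, _ becomes -
assert src.source_type == "npm" and src.priority == 10
assert src.enabled and src.is_public  # both default to True when unset

Note that get_env_upstream_sources() is wrapped in lru_cache, so environment variables set after the first call are not picked up; the uncached parse function is used here for that reason.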
@@ -9,6 +9,7 @@ import hashlib

from .config import get_settings
from .models import Base
from .purge_seed_data import should_purge_seed_data, purge_seed_data

settings = get_settings()
logger = logging.getLogger(__name__)
@@ -80,6 +81,14 @@ def init_db():
    # Run migrations for schema updates
    _run_migrations()

    # Purge seed data if requested (for transitioning to production-like environment)
    if should_purge_seed_data():
        db = SessionLocal()
        try:
            purge_seed_data(db)
        finally:
            db.close()


def _ensure_migrations_table(conn) -> None:
    """Create the migrations tracking table if it doesn't exist."""
@@ -429,6 +438,99 @@ def _run_migrations():
                END $$;
            """,
        ),
        Migration(
            name="016_add_is_system_to_projects",
            sql="""
                DO $$
                BEGIN
                    IF NOT EXISTS (
                        SELECT 1 FROM information_schema.columns
                        WHERE table_name = 'projects' AND column_name = 'is_system'
                    ) THEN
                        ALTER TABLE projects ADD COLUMN is_system BOOLEAN NOT NULL DEFAULT FALSE;
                        CREATE INDEX IF NOT EXISTS idx_projects_is_system ON projects(is_system);
                    END IF;
                END $$;
            """,
        ),
        Migration(
            name="017_create_upstream_sources",
            sql="""
                CREATE TABLE IF NOT EXISTS upstream_sources (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    name VARCHAR(255) NOT NULL UNIQUE,
                    source_type VARCHAR(50) NOT NULL DEFAULT 'generic',
                    url VARCHAR(2048) NOT NULL,
                    enabled BOOLEAN NOT NULL DEFAULT FALSE,
                    is_public BOOLEAN NOT NULL DEFAULT TRUE,
                    auth_type VARCHAR(20) NOT NULL DEFAULT 'none',
                    username VARCHAR(255),
                    password_encrypted BYTEA,
                    headers_encrypted BYTEA,
                    priority INTEGER NOT NULL DEFAULT 100,
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    CONSTRAINT check_source_type CHECK (
                        source_type IN ('npm', 'pypi', 'maven', 'docker', 'helm', 'nuget', 'deb', 'rpm', 'generic')
                    ),
                    CONSTRAINT check_auth_type CHECK (
                        auth_type IN ('none', 'basic', 'bearer', 'api_key')
                    ),
                    CONSTRAINT check_priority_positive CHECK (priority > 0)
                );
                CREATE INDEX IF NOT EXISTS idx_upstream_sources_enabled ON upstream_sources(enabled);
                CREATE INDEX IF NOT EXISTS idx_upstream_sources_source_type ON upstream_sources(source_type);
                CREATE INDEX IF NOT EXISTS idx_upstream_sources_is_public ON upstream_sources(is_public);
                CREATE INDEX IF NOT EXISTS idx_upstream_sources_priority ON upstream_sources(priority);
            """,
        ),
        Migration(
            name="018_create_cache_settings",
            sql="""
                CREATE TABLE IF NOT EXISTS cache_settings (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    allow_public_internet BOOLEAN NOT NULL DEFAULT TRUE,
                    auto_create_system_projects BOOLEAN NOT NULL DEFAULT TRUE,
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    CONSTRAINT check_cache_settings_singleton CHECK (id = 1)
                );
                INSERT INTO cache_settings (id, allow_public_internet, auto_create_system_projects)
                VALUES (1, TRUE, TRUE)
                ON CONFLICT (id) DO NOTHING;
            """,
        ),
        Migration(
            name="019_create_cached_urls",
            sql="""
                CREATE TABLE IF NOT EXISTS cached_urls (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    url VARCHAR(4096) NOT NULL,
                    url_hash VARCHAR(64) NOT NULL UNIQUE,
                    artifact_id VARCHAR(64) NOT NULL REFERENCES artifacts(id),
                    source_id UUID REFERENCES upstream_sources(id) ON DELETE SET NULL,
                    fetched_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                    response_headers JSONB DEFAULT '{}',
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                );
                CREATE INDEX IF NOT EXISTS idx_cached_urls_url_hash ON cached_urls(url_hash);
                CREATE INDEX IF NOT EXISTS idx_cached_urls_artifact_id ON cached_urls(artifact_id);
                CREATE INDEX IF NOT EXISTS idx_cached_urls_source_id ON cached_urls(source_id);
                CREATE INDEX IF NOT EXISTS idx_cached_urls_fetched_at ON cached_urls(fetched_at);
            """,
        ),
        Migration(
            name="020_seed_default_upstream_sources",
            sql="""
                INSERT INTO upstream_sources (id, name, source_type, url, enabled, is_public, auth_type, priority)
                VALUES
                    (gen_random_uuid(), 'npm-public', 'npm', 'https://registry.npmjs.org', FALSE, TRUE, 'none', 100),
                    (gen_random_uuid(), 'pypi-public', 'pypi', 'https://pypi.org/simple', FALSE, TRUE, 'none', 100),
                    (gen_random_uuid(), 'maven-central', 'maven', 'https://repo1.maven.org/maven2', FALSE, TRUE, 'none', 100),
                    (gen_random_uuid(), 'docker-hub', 'docker', 'https://registry-1.docker.io', FALSE, TRUE, 'none', 100)
                ON CONFLICT (name) DO NOTHING;
            """,
        ),
    ]

    with engine.connect() as conn:
backend/app/encryption.py (new file, 160 lines)
@@ -0,0 +1,160 @@
"""
Encryption utilities for sensitive data storage.

Uses Fernet symmetric encryption for credentials like upstream passwords.
The encryption key is sourced from the ORCHARD_CACHE_ENCRYPTION_KEY environment variable.
If not set, a random key is generated on startup (with a warning).
"""

import base64
import logging
import os
from functools import lru_cache
from typing import Optional

from cryptography.fernet import Fernet, InvalidToken

logger = logging.getLogger(__name__)

# Module-level storage for auto-generated key (only used if env var not set)
_generated_key: Optional[bytes] = None


def _get_key_from_env() -> Optional[bytes]:
    """Get encryption key from environment variable."""
    key_str = os.environ.get("ORCHARD_CACHE_ENCRYPTION_KEY", "")
    if not key_str:
        return None

    # Support both raw base64 and url-safe base64 formats
    try:
        # Try to decode as-is (Fernet keys are url-safe base64)
        key_bytes = key_str.encode("utf-8")
        # Validate it's a valid Fernet key by trying to create a Fernet instance
        Fernet(key_bytes)
        return key_bytes
    except Exception:
        pass

    # Try base64 decoding if it's a raw 32-byte key encoded as base64
    try:
        decoded = base64.urlsafe_b64decode(key_str)
        if len(decoded) == 32:
            # Re-encode as url-safe base64 for Fernet
            key_bytes = base64.urlsafe_b64encode(decoded)
            Fernet(key_bytes)
            return key_bytes
    except Exception:
        pass

    logger.error(
        "ORCHARD_CACHE_ENCRYPTION_KEY is set but invalid. "
        "Must be a valid Fernet key (32 bytes, url-safe base64 encoded). "
        "Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
    )
    return None


def get_encryption_key() -> bytes:
    """
    Get the Fernet encryption key.

    Returns the key from ORCHARD_CACHE_ENCRYPTION_KEY if set and valid,
    otherwise generates a random key (with a warning logged).

    The generated key is cached for the lifetime of the process.
    """
    global _generated_key

    # Try to get from environment
    env_key = _get_key_from_env()
    if env_key:
        return env_key

    # Generate a new key if needed
    if _generated_key is None:
        _generated_key = Fernet.generate_key()
        logger.warning(
            "ORCHARD_CACHE_ENCRYPTION_KEY not set - using auto-generated key. "
            "Encrypted credentials will be lost on restart! "
            "Set ORCHARD_CACHE_ENCRYPTION_KEY for persistent encryption. "
            "Generate a key with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
        )

    return _generated_key


@lru_cache(maxsize=1)
def _get_fernet() -> Fernet:
    """Get a cached Fernet instance."""
    return Fernet(get_encryption_key())


def encrypt_value(plaintext: str) -> bytes:
    """
    Encrypt a string value using Fernet.

    Args:
        plaintext: The string to encrypt

    Returns:
        Encrypted bytes (includes Fernet token with timestamp)
    """
    if not plaintext:
        raise ValueError("Cannot encrypt empty value")

    fernet = _get_fernet()
    return fernet.encrypt(plaintext.encode("utf-8"))


def decrypt_value(ciphertext: bytes) -> str:
    """
    Decrypt a Fernet-encrypted value.

    Args:
        ciphertext: The encrypted bytes

    Returns:
        Decrypted string

    Raises:
        InvalidToken: If decryption fails (wrong key or corrupted data)
    """
    if not ciphertext:
        raise ValueError("Cannot decrypt empty value")

    fernet = _get_fernet()
    return fernet.decrypt(ciphertext).decode("utf-8")


def can_decrypt(ciphertext: bytes) -> bool:
    """
    Check if a value can be decrypted with the current key.

    Useful for checking if credentials are still valid after key rotation.

    Args:
        ciphertext: The encrypted bytes

    Returns:
        True if decryption succeeds, False otherwise
    """
    if not ciphertext:
        return False

    try:
        decrypt_value(ciphertext)
        return True
    except (InvalidToken, ValueError):
        return False


def generate_key() -> str:
    """
    Generate a new Fernet encryption key.

    Returns:
        A valid Fernet key as a string (url-safe base64 encoded)
    """
    return Fernet.generate_key().decode("utf-8")
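Round-trip behavior of the helpers above, as a sketch (the app.encryption import path is assumed):

from app.encryption import encrypt_value, decrypt_value, can_decrypt, generate_key

# Without ORCHARD_CACHE_ENCRYPTION_KEY set, a process-local key is generated and
# a warning is logged, so tokens only survive for the current process.
token = encrypt_value("s3cret-registry-token")
assert decrypt_value(token) == "s3cret-registry-token"
assert can_decrypt(token)

# For persistent deployments, generate a key once and export it:
print(generate_key())  # url-safe base64, suitable for ORCHARD_CACHE_ENCRYPTION_KEY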
@@ -12,6 +12,7 @@ from sqlalchemy import (
    Index,
    JSON,
    ARRAY,
    LargeBinary,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship, declarative_base
@@ -27,6 +28,7 @@ class Project(Base):
    name = Column(String(255), unique=True, nullable=False)
    description = Column(Text)
    is_public = Column(Boolean, default=True)
    is_system = Column(Boolean, default=False, nullable=False)
    created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
    updated_at = Column(
        DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
@@ -46,6 +48,7 @@ class Project(Base):
        Index("idx_projects_name", "name"),
        Index("idx_projects_created_by", "created_by"),
        Index("idx_projects_team_id", "team_id"),
        Index("idx_projects_is_system", "is_system"),
    )
@@ -637,3 +640,169 @@ class TeamMembership(Base):
            name="check_team_role",
        ),
    )


# =============================================================================
# Upstream Caching Models
# =============================================================================

# Valid source types for upstream registries
SOURCE_TYPES = ["npm", "pypi", "maven", "docker", "helm", "nuget", "deb", "rpm", "generic"]

# Valid authentication types
AUTH_TYPES = ["none", "basic", "bearer", "api_key"]


class UpstreamSource(Base):
    """Configuration for an upstream artifact registry.

    Stores connection details and authentication for upstream registries
    like npm, PyPI, Maven Central, or private Artifactory instances.
    """

    __tablename__ = "upstream_sources"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String(255), unique=True, nullable=False)
    source_type = Column(String(50), default="generic", nullable=False)
    url = Column(String(2048), nullable=False)
    enabled = Column(Boolean, default=False, nullable=False)
    is_public = Column(Boolean, default=True, nullable=False)
    auth_type = Column(String(20), default="none", nullable=False)
    username = Column(String(255))
    password_encrypted = Column(LargeBinary)
    headers_encrypted = Column(LargeBinary)
    priority = Column(Integer, default=100, nullable=False)
    created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
    updated_at = Column(
        DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
    )

    # Relationships
    cached_urls = relationship("CachedUrl", back_populates="source")

    __table_args__ = (
        Index("idx_upstream_sources_enabled", "enabled"),
        Index("idx_upstream_sources_source_type", "source_type"),
        Index("idx_upstream_sources_is_public", "is_public"),
        Index("idx_upstream_sources_priority", "priority"),
        CheckConstraint(
            "source_type IN ('npm', 'pypi', 'maven', 'docker', 'helm', 'nuget', 'deb', 'rpm', 'generic')",
            name="check_source_type",
        ),
        CheckConstraint(
            "auth_type IN ('none', 'basic', 'bearer', 'api_key')",
            name="check_auth_type",
        ),
        CheckConstraint("priority > 0", name="check_priority_positive"),
    )

    def set_password(self, password: str) -> None:
        """Encrypt and store a password/token."""
        from .encryption import encrypt_value

        if password:
            self.password_encrypted = encrypt_value(password)
        else:
            self.password_encrypted = None

    def get_password(self) -> str | None:
        """Decrypt and return the stored password/token."""
        from .encryption import decrypt_value

        if self.password_encrypted:
            try:
                return decrypt_value(self.password_encrypted)
            except Exception:
                return None
        return None

    def has_password(self) -> bool:
        """Check if a password/token is stored."""
        return self.password_encrypted is not None

    def set_headers(self, headers: dict) -> None:
        """Encrypt and store custom headers as JSON."""
        from .encryption import encrypt_value
        import json

        if headers:
            self.headers_encrypted = encrypt_value(json.dumps(headers))
        else:
            self.headers_encrypted = None

    def get_headers(self) -> dict | None:
        """Decrypt and return custom headers."""
        from .encryption import decrypt_value
        import json

        if self.headers_encrypted:
            try:
                return json.loads(decrypt_value(self.headers_encrypted))
            except Exception:
                return None
        return None


class CacheSettings(Base):
    """Global cache settings (singleton table).

    Controls behavior of the upstream caching system including air-gap mode.
    """

    __tablename__ = "cache_settings"

    id = Column(Integer, primary_key=True, default=1)
    allow_public_internet = Column(Boolean, default=True, nullable=False)
    auto_create_system_projects = Column(Boolean, default=True, nullable=False)
    created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
    updated_at = Column(
        DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow
    )

    __table_args__ = (
        CheckConstraint("id = 1", name="check_cache_settings_singleton"),
    )


class CachedUrl(Base):
    """Tracks URL to artifact mappings for provenance.

    Records which URLs have been cached and maps them to their stored artifacts.
    Enables "is this URL already cached?" lookups and audit trails.
    """

    __tablename__ = "cached_urls"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    url = Column(String(4096), nullable=False)
    url_hash = Column(String(64), unique=True, nullable=False)
    artifact_id = Column(
        String(64), ForeignKey("artifacts.id"), nullable=False
    )
    source_id = Column(
        UUID(as_uuid=True),
        ForeignKey("upstream_sources.id", ondelete="SET NULL"),
    )
    fetched_at = Column(DateTime(timezone=True), default=datetime.utcnow, nullable=False)
    response_headers = Column(JSON, default=dict)
    created_at = Column(DateTime(timezone=True), default=datetime.utcnow)

    # Relationships
    artifact = relationship("Artifact")
    source = relationship("UpstreamSource", back_populates="cached_urls")

    __table_args__ = (
        Index("idx_cached_urls_url_hash", "url_hash"),
        Index("idx_cached_urls_artifact_id", "artifact_id"),
        Index("idx_cached_urls_source_id", "source_id"),
        Index("idx_cached_urls_fetched_at", "fetched_at"),
    )

    @staticmethod
    def compute_url_hash(url: str) -> str:
        """Compute SHA256 hash of a URL for fast lookups."""
        import hashlib
        return hashlib.sha256(url.encode("utf-8")).hexdigest()
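Intended usage of the new models, sketched (the app.models import path is assumed):

from sqlalchemy.orm import Session
from app.models import UpstreamSource, CachedUrl

def lookup_cached(db: Session, url: str):
    """Return the CachedUrl row for a URL, if previously cached."""
    # Cache-hit lookups key on the SHA256 of the URL rather than the raw URL
    return db.query(CachedUrl).filter(
        CachedUrl.url_hash == CachedUrl.compute_url_hash(url)
    ).first()

src = UpstreamSource(name="npm-corp", source_type="npm", url="https://npm.corp.example")
src.set_password("token-123")  # stored Fernet-encrypted in password_encrypted
assert src.has_password()
assert src.get_password() == "token-123"  # returns None if the key has rotated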
backend/app/purge_seed_data.py (new file, 211 lines)
@@ -0,0 +1,211 @@
"""
Purge seed/demo data from the database.

This is used when transitioning an environment from dev/test to production-like.
Triggered by setting the ORCHARD_PURGE_SEED_DATA=true environment variable.
"""
import logging
import os
from sqlalchemy.orm import Session

from .models import (
    Project,
    Package,
    Artifact,
    Tag,
    Upload,
    PackageVersion,
    ArtifactDependency,
    Team,
    TeamMembership,
    User,
    AccessPermission,
)
from .storage import get_storage

logger = logging.getLogger(__name__)

# Seed data identifiers (from seed.py)
SEED_PROJECT_NAMES = [
    "frontend-libs",
    "backend-services",
    "mobile-apps",
    "internal-tools",
]

SEED_TEAM_SLUG = "demo-team"

SEED_USERNAMES = [
    "alice",
    "bob",
    "charlie",
    "diana",
    "eve",
    "frank",
]


def should_purge_seed_data() -> bool:
    """Check if seed data should be purged based on environment variable."""
    return os.environ.get("ORCHARD_PURGE_SEED_DATA", "").lower() == "true"


def purge_seed_data(db: Session) -> dict:
    """
    Purge all seed/demo data from the database.

    Returns a dict with counts of deleted items.
    """
    logger.warning("PURGING SEED DATA - This will delete demo projects, users, and teams")

    results = {
        "dependencies_deleted": 0,
        "tags_deleted": 0,
        "versions_deleted": 0,
        "uploads_deleted": 0,
        "artifacts_deleted": 0,
        "packages_deleted": 0,
        "projects_deleted": 0,
        "permissions_deleted": 0,
        "team_memberships_deleted": 0,
        "users_deleted": 0,
        "teams_deleted": 0,
        "s3_objects_deleted": 0,
    }

    storage = get_storage()

    # Find seed projects
    seed_projects = db.query(Project).filter(Project.name.in_(SEED_PROJECT_NAMES)).all()
    seed_project_ids = [p.id for p in seed_projects]

    if not seed_projects:
        logger.info("No seed projects found, nothing to purge")
        return results

    logger.info(f"Found {len(seed_projects)} seed projects to purge")

    # Find packages in seed projects
    seed_packages = db.query(Package).filter(Package.project_id.in_(seed_project_ids)).all()
    seed_package_ids = [p.id for p in seed_packages]

    # Find artifacts in seed packages (via uploads)
    seed_uploads = db.query(Upload).filter(Upload.package_id.in_(seed_package_ids)).all()
    seed_artifact_ids = list(set(u.artifact_id for u in seed_uploads))

    # Delete in order (respecting foreign keys)

    # 1. Delete artifact dependencies
    if seed_artifact_ids:
        count = db.query(ArtifactDependency).filter(
            ArtifactDependency.artifact_id.in_(seed_artifact_ids)
        ).delete(synchronize_session=False)
        results["dependencies_deleted"] = count
        logger.info(f"Deleted {count} artifact dependencies")

    # 2. Delete tags
    if seed_package_ids:
        count = db.query(Tag).filter(Tag.package_id.in_(seed_package_ids)).delete(
            synchronize_session=False
        )
        results["tags_deleted"] = count
        logger.info(f"Deleted {count} tags")

    # 3. Delete package versions
    if seed_package_ids:
        count = db.query(PackageVersion).filter(
            PackageVersion.package_id.in_(seed_package_ids)
        ).delete(synchronize_session=False)
        results["versions_deleted"] = count
        logger.info(f"Deleted {count} package versions")

    # 4. Delete uploads
    if seed_package_ids:
        count = db.query(Upload).filter(Upload.package_id.in_(seed_package_ids)).delete(
            synchronize_session=False
        )
        results["uploads_deleted"] = count
        logger.info(f"Deleted {count} uploads")

    # 5. Delete S3 objects for seed artifacts
    if seed_artifact_ids:
        seed_artifacts = db.query(Artifact).filter(Artifact.id.in_(seed_artifact_ids)).all()
        for artifact in seed_artifacts:
            if artifact.s3_key:
                try:
                    storage.client.delete_object(Bucket=storage.bucket, Key=artifact.s3_key)
                    results["s3_objects_deleted"] += 1
                except Exception as e:
                    logger.warning(f"Failed to delete S3 object {artifact.s3_key}: {e}")
        logger.info(f"Deleted {results['s3_objects_deleted']} S3 objects")

    # 6. Delete artifacts (only those with ref_count that would be 0 after our deletions)
    # Since we deleted all tags/versions pointing to these artifacts, we can delete them
    if seed_artifact_ids:
        count = db.query(Artifact).filter(Artifact.id.in_(seed_artifact_ids)).delete(
            synchronize_session=False
        )
        results["artifacts_deleted"] = count
        logger.info(f"Deleted {count} artifacts")

    # 7. Delete packages
    if seed_package_ids:
        count = db.query(Package).filter(Package.id.in_(seed_package_ids)).delete(
            synchronize_session=False
        )
        results["packages_deleted"] = count
        logger.info(f"Deleted {count} packages")

    # 8. Delete access permissions for seed projects
    if seed_project_ids:
        count = db.query(AccessPermission).filter(
            AccessPermission.project_id.in_(seed_project_ids)
        ).delete(synchronize_session=False)
        results["permissions_deleted"] = count
        logger.info(f"Deleted {count} access permissions")

    # 9. Delete seed projects
    count = db.query(Project).filter(Project.name.in_(SEED_PROJECT_NAMES)).delete(
        synchronize_session=False
    )
    results["projects_deleted"] = count
    logger.info(f"Deleted {count} projects")

    # 10. Find and delete seed team
    seed_team = db.query(Team).filter(Team.slug == SEED_TEAM_SLUG).first()
    if seed_team:
        # Delete team memberships first
        count = db.query(TeamMembership).filter(
            TeamMembership.team_id == seed_team.id
        ).delete(synchronize_session=False)
        results["team_memberships_deleted"] = count
        logger.info(f"Deleted {count} team memberships")

        # Delete the team
        db.delete(seed_team)
        results["teams_deleted"] = 1
        logger.info(f"Deleted team: {SEED_TEAM_SLUG}")

    # 11. Delete seed users (but NOT admin)
    seed_users = db.query(User).filter(User.username.in_(SEED_USERNAMES)).all()
    for user in seed_users:
        # Delete any remaining team memberships for this user
        db.query(TeamMembership).filter(TeamMembership.user_id == user.id).delete(
            synchronize_session=False
        )
        # Delete any access permissions for this user
        db.query(AccessPermission).filter(AccessPermission.user_id == user.id).delete(
            synchronize_session=False
        )
        db.delete(user)
        results["users_deleted"] += 1

    if results["users_deleted"] > 0:
        logger.info(f"Deleted {results['users_deleted']} seed users")

    db.commit()

    logger.warning("SEED DATA PURGE COMPLETE")
    logger.info(f"Purge results: {results}")

    return results
File diff suppressed because it is too large
@@ -1196,3 +1196,246 @@ class TeamMemberResponse(BaseModel):
    class Config:
        from_attributes = True


# =============================================================================
# Upstream Caching Schemas
# =============================================================================

# Valid source types
SOURCE_TYPES = ["npm", "pypi", "maven", "docker", "helm", "nuget", "deb", "rpm", "generic"]

# Valid auth types
AUTH_TYPES = ["none", "basic", "bearer", "api_key"]


class UpstreamSourceCreate(BaseModel):
    """Create a new upstream source"""
    name: str
    source_type: str = "generic"
    url: str
    enabled: bool = False
    is_public: bool = True
    auth_type: str = "none"
    username: Optional[str] = None
    password: Optional[str] = None  # Write-only
    headers: Optional[dict] = None  # Write-only, custom headers
    priority: int = 100

    @field_validator('name')
    @classmethod
    def validate_name(cls, v: str) -> str:
        v = v.strip()
        if not v:
            raise ValueError("name cannot be empty")
        if len(v) > 255:
            raise ValueError("name must be 255 characters or less")
        return v

    @field_validator('source_type')
    @classmethod
    def validate_source_type(cls, v: str) -> str:
        if v not in SOURCE_TYPES:
            raise ValueError(f"source_type must be one of: {', '.join(SOURCE_TYPES)}")
        return v

    @field_validator('url')
    @classmethod
    def validate_url(cls, v: str) -> str:
        v = v.strip()
        if not v:
            raise ValueError("url cannot be empty")
        if not (v.startswith('http://') or v.startswith('https://')):
            raise ValueError("url must start with http:// or https://")
        if len(v) > 2048:
            raise ValueError("url must be 2048 characters or less")
        return v

    @field_validator('auth_type')
    @classmethod
    def validate_auth_type(cls, v: str) -> str:
        if v not in AUTH_TYPES:
            raise ValueError(f"auth_type must be one of: {', '.join(AUTH_TYPES)}")
        return v

    @field_validator('priority')
    @classmethod
    def validate_priority(cls, v: int) -> int:
        if v <= 0:
            raise ValueError("priority must be greater than 0")
        return v


class UpstreamSourceUpdate(BaseModel):
    """Update an upstream source (partial)"""
    name: Optional[str] = None
    source_type: Optional[str] = None
    url: Optional[str] = None
    enabled: Optional[bool] = None
    is_public: Optional[bool] = None
    auth_type: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None  # Write-only, None = keep existing, empty string = clear
    headers: Optional[dict] = None  # Write-only
    priority: Optional[int] = None

    @field_validator('name')
    @classmethod
    def validate_name(cls, v: Optional[str]) -> Optional[str]:
        if v is not None:
            v = v.strip()
            if not v:
                raise ValueError("name cannot be empty")
            if len(v) > 255:
                raise ValueError("name must be 255 characters or less")
        return v

    @field_validator('source_type')
    @classmethod
    def validate_source_type(cls, v: Optional[str]) -> Optional[str]:
        if v is not None and v not in SOURCE_TYPES:
            raise ValueError(f"source_type must be one of: {', '.join(SOURCE_TYPES)}")
        return v

    @field_validator('url')
    @classmethod
    def validate_url(cls, v: Optional[str]) -> Optional[str]:
        if v is not None:
            v = v.strip()
            if not v:
                raise ValueError("url cannot be empty")
            if not (v.startswith('http://') or v.startswith('https://')):
                raise ValueError("url must start with http:// or https://")
            if len(v) > 2048:
                raise ValueError("url must be 2048 characters or less")
        return v

    @field_validator('auth_type')
    @classmethod
    def validate_auth_type(cls, v: Optional[str]) -> Optional[str]:
        if v is not None and v not in AUTH_TYPES:
            raise ValueError(f"auth_type must be one of: {', '.join(AUTH_TYPES)}")
        return v

    @field_validator('priority')
    @classmethod
    def validate_priority(cls, v: Optional[int]) -> Optional[int]:
        if v is not None and v <= 0:
            raise ValueError("priority must be greater than 0")
        return v


class UpstreamSourceResponse(BaseModel):
    """Upstream source response (credentials never included)"""
    id: UUID
    name: str
    source_type: str
    url: str
    enabled: bool
    is_public: bool
    auth_type: str
    username: Optional[str]
    has_password: bool  # True if password is set
    has_headers: bool  # True if custom headers are set
    priority: int
    source: str = "database"  # "database" or "env" (env = defined via environment variables)
    created_at: Optional[datetime] = None  # May be None for legacy/env data
    updated_at: Optional[datetime] = None  # May be None for legacy/env data

    class Config:
        from_attributes = True


class CacheSettingsResponse(BaseModel):
    """Global cache settings response"""
    allow_public_internet: bool
    auto_create_system_projects: bool
    allow_public_internet_env_override: Optional[bool] = None  # Set if overridden by env var
    auto_create_system_projects_env_override: Optional[bool] = None  # Set if overridden by env var
    created_at: Optional[datetime] = None  # May be None for legacy data
    updated_at: Optional[datetime] = None  # May be None for legacy data

    class Config:
        from_attributes = True


class CacheSettingsUpdate(BaseModel):
    """Update cache settings (partial)"""
    allow_public_internet: Optional[bool] = None
    auto_create_system_projects: Optional[bool] = None


class CachedUrlResponse(BaseModel):
    """Cached URL response"""
    id: UUID
    url: str
    url_hash: str
    artifact_id: str
    source_id: Optional[UUID]
    source_name: Optional[str] = None  # Populated from join
    fetched_at: datetime
    created_at: datetime

    class Config:
        from_attributes = True


class CacheRequest(BaseModel):
    """Request to cache an artifact from an upstream URL"""
    url: str
    source_type: str
    package_name: Optional[str] = None  # Auto-derived from URL if not provided
    tag: Optional[str] = None  # Auto-derived from URL if not provided
    user_project: Optional[str] = None  # Cross-reference to user project
    user_package: Optional[str] = None
    user_tag: Optional[str] = None
    expected_hash: Optional[str] = None  # Verify downloaded content

    @field_validator('url')
    @classmethod
    def validate_url(cls, v: str) -> str:
        v = v.strip()
        if not v:
            raise ValueError("url cannot be empty")
        if not (v.startswith('http://') or v.startswith('https://')):
            raise ValueError("url must start with http:// or https://")
        if len(v) > 4096:
            raise ValueError("url must be 4096 characters or less")
        return v

    @field_validator('source_type')
    @classmethod
    def validate_source_type(cls, v: str) -> str:
        if v not in SOURCE_TYPES:
            raise ValueError(f"source_type must be one of: {', '.join(SOURCE_TYPES)}")
        return v

    @field_validator('expected_hash')
    @classmethod
    def validate_expected_hash(cls, v: Optional[str]) -> Optional[str]:
        if v is not None:
            v = v.strip().lower()
            # Remove sha256: prefix if present
            if v.startswith('sha256:'):
                v = v[7:]
            # Validate hex format
            if len(v) != 64 or not all(c in '0123456789abcdef' for c in v):
                raise ValueError("expected_hash must be a 64-character hex string (SHA256)")
        return v


class CacheResponse(BaseModel):
    """Response from caching an artifact"""
    artifact_id: str
    sha256: str
    size: int
    content_type: Optional[str]
    already_cached: bool
    source_url: str
    source_name: Optional[str]
    system_project: str
    system_package: str
    system_tag: Optional[str]
    user_reference: Optional[str] = None  # e.g., "my-app/npm-deps:lodash-4.17.21"
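Validator behavior for the request schema, sketched (the app.schemas import path is assumed):

from pydantic import ValidationError
from app.schemas import CacheRequest

req = CacheRequest(
    url="https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
    source_type="npm",
    expected_hash="sha256:" + "ab" * 32,  # prefix stripped, hex lowercased
)
assert req.expected_hash == "ab" * 32

try:
    CacheRequest(url="ftp://mirror.example/pkg.tgz", source_type="npm")
except ValidationError:
    pass  # url must start with http:// or https://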
backend/app/upstream.py (new file, 586 lines)
@@ -0,0 +1,586 @@
"""
HTTP client for fetching artifacts from upstream sources.

Provides streaming downloads with SHA256 computation, authentication support,
and automatic source matching based on URL prefixes.
"""

from __future__ import annotations

import hashlib
import logging
import tempfile
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import BinaryIO, Optional, TYPE_CHECKING
from urllib.parse import urlparse

import httpx

if TYPE_CHECKING:
    from .models import CacheSettings, UpstreamSource

logger = logging.getLogger(__name__)


class UpstreamError(Exception):
    """Base exception for upstream client errors."""

    pass


class UpstreamConnectionError(UpstreamError):
    """Connection to upstream failed (network error, DNS, etc.)."""

    pass


class UpstreamTimeoutError(UpstreamError):
    """Request to upstream timed out."""

    pass


class UpstreamHTTPError(UpstreamError):
    """Upstream returned an HTTP error response."""

    def __init__(self, message: str, status_code: int, response_headers: Optional[dict] = None):
        super().__init__(message)
        self.status_code = status_code
        self.response_headers = response_headers or {}


class UpstreamSSLError(UpstreamError):
    """SSL/TLS error when connecting to upstream."""

    pass


class AirGapError(UpstreamError):
    """Request blocked due to air-gap mode."""

    pass


class FileSizeExceededError(UpstreamError):
    """File size exceeds the maximum allowed."""

    def __init__(self, message: str, content_length: int, max_size: int):
        super().__init__(message)
        self.content_length = content_length
        self.max_size = max_size


class SourceNotFoundError(UpstreamError):
    """No matching upstream source found for URL."""

    pass


class SourceDisabledError(UpstreamError):
    """The matching upstream source is disabled."""

    pass


@dataclass
class FetchResult:
    """Result of fetching an artifact from upstream."""

    content: BinaryIO  # File-like object with content
    sha256: str  # SHA256 hash of content
    size: int  # Size in bytes
    content_type: Optional[str]  # Content-Type header
    response_headers: dict  # All response headers for provenance
    source_name: Optional[str] = None  # Name of matched upstream source
    temp_path: Optional[Path] = None  # Path to temp file (for cleanup)

    def close(self):
        """Close and clean up resources."""
        if self.content:
            try:
                self.content.close()
            except Exception:
                pass
        if self.temp_path and self.temp_path.exists():
            try:
                self.temp_path.unlink()
            except Exception:
                pass


@dataclass
class UpstreamClientConfig:
    """Configuration for the upstream client."""

    connect_timeout: float = 30.0  # Connection timeout in seconds
    read_timeout: float = 300.0  # Read timeout in seconds (5 minutes for large files)
    max_retries: int = 3  # Maximum number of retry attempts
    retry_backoff_base: float = 1.0  # Base delay for exponential backoff
    retry_backoff_max: float = 30.0  # Maximum delay between retries
    follow_redirects: bool = True  # Whether to follow redirects
    max_redirects: int = 5  # Maximum number of redirects to follow
    max_file_size: Optional[int] = None  # Maximum file size (None = unlimited)
    verify_ssl: bool = True  # Verify SSL certificates
    user_agent: str = "Orchard-UpstreamClient/1.0"


class UpstreamClient:
    """
    HTTP client for fetching artifacts from upstream sources.

    Supports streaming downloads, multiple authentication methods,
    automatic source matching, and air-gap mode enforcement.
    """

    def __init__(
        self,
        sources: Optional[list[UpstreamSource]] = None,
        cache_settings: Optional[CacheSettings] = None,
        config: Optional[UpstreamClientConfig] = None,
):
|
||||
"""
|
||||
Initialize the upstream client.
|
||||
|
||||
Args:
|
||||
sources: List of upstream sources for URL matching and auth.
|
||||
Should be sorted by priority (lowest first).
|
||||
cache_settings: Global cache settings including air-gap mode.
|
||||
config: Client configuration options.
|
||||
"""
|
||||
self.sources = sources or []
|
||||
self.cache_settings = cache_settings
|
||||
self.config = config or UpstreamClientConfig()
|
||||
|
||||
# Sort sources by priority (lower = higher priority)
|
||||
self.sources = sorted(self.sources, key=lambda s: s.priority)
|
||||
|
||||
def _get_allow_public_internet(self) -> bool:
|
||||
"""Get the allow_public_internet setting."""
|
||||
if self.cache_settings is None:
|
||||
return True # Default to allowing if no settings provided
|
||||
return self.cache_settings.allow_public_internet
|
||||
|
||||
def _match_source(self, url: str) -> Optional[UpstreamSource]:
|
||||
"""
|
||||
Find the upstream source that matches the given URL.
|
||||
|
||||
Matches by URL prefix, returns the highest priority match.
|
||||
|
||||
Args:
|
||||
url: The URL to match.
|
||||
|
||||
Returns:
|
||||
The matching UpstreamSource or None if no match.
|
||||
"""
|
||||
for source in self.sources:
|
||||
# Check if URL starts with source URL (prefix match)
|
||||
if url.startswith(source.url.rstrip("/")):
|
||||
return source
|
||||
|
||||
return None
|
||||
|
||||
def _build_auth_headers(self, source: UpstreamSource) -> dict:
|
||||
"""
|
||||
Build authentication headers for the given source.
|
||||
|
||||
Args:
|
||||
source: The upstream source with auth configuration.
|
||||
|
||||
Returns:
|
||||
Dictionary of headers to add to the request.
|
||||
"""
|
||||
headers = {}
|
||||
|
||||
if source.auth_type == "none":
|
||||
pass
|
||||
elif source.auth_type == "basic":
|
||||
# httpx handles basic auth via auth parameter, but we can also
|
||||
# do it manually if needed. We'll use the auth parameter instead.
|
||||
pass
|
||||
elif source.auth_type == "bearer":
|
||||
password = source.get_password()
|
||||
if password:
|
||||
headers["Authorization"] = f"Bearer {password}"
|
||||
elif source.auth_type == "api_key":
|
||||
# API key auth uses custom headers
|
||||
custom_headers = source.get_headers()
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
return headers
|
||||
|
||||
def _get_basic_auth(self, source: UpstreamSource) -> Optional[tuple[str, str]]:
|
||||
"""
|
||||
Get basic auth credentials if applicable.
|
||||
|
||||
Args:
|
||||
source: The upstream source.
|
||||
|
||||
Returns:
|
||||
Tuple of (username, password) or None.
|
||||
"""
|
||||
if source.auth_type == "basic" and source.username:
|
||||
password = source.get_password() or ""
|
||||
return (source.username, password)
|
||||
return None
|
||||
|
||||
    def _should_retry(self, error: Exception, attempt: int) -> bool:
        """
        Determine whether a request should be retried.

        Args:
            error: The exception that occurred.
            attempt: Current attempt number (0-indexed).

        Returns:
            True if the request should be retried.
        """
        if attempt >= self.config.max_retries - 1:
            return False

        # Retry on connection errors and on connect/read timeouts
        if isinstance(error, (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout)):
            return True

        # Retry on transient HTTP errors (502, 503, 504)
        if isinstance(error, httpx.HTTPStatusError):
            return error.response.status_code in (502, 503, 504)

        return False

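    # Example (illustrative): with max_retries=3, attempts 0 and 1 may be
    # retried; attempt 2 (the third and final try) is never retried, and a
    # 404 response is never retried regardless of the attempt number.
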
    def _calculate_backoff(self, attempt: int) -> float:
        """
        Calculate backoff delay for a retry.

        Uses exponential backoff with jitter.

        Args:
            attempt: Current attempt number (0-indexed).

        Returns:
            Delay in seconds.
        """
        import random

        delay = self.config.retry_backoff_base * (2**attempt)
        # Add jitter (±25%) to avoid synchronized retries
        delay *= 0.75 + random.random() * 0.5
        return min(delay, self.config.retry_backoff_max)

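    # Example (illustrative; assumes retry_backoff_base=1.0 and
    # retry_backoff_max=30.0, which are not confirmed defaults): attempts
    # 0, 1, 2, 3 yield base delays of 1s, 2s, 4s, 8s, each jittered to
    # somewhere within ±25% of that value and capped at 30s.
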
    def fetch(self, url: str, expected_hash: Optional[str] = None) -> FetchResult:
        """
        Fetch an artifact from the given URL.

        Streams the response to a temp file while computing the SHA256 hash.
        Handles authentication, retries, and error cases.

        Args:
            url: The URL to fetch.
            expected_hash: Optional expected SHA256 hash for verification.

        Returns:
            FetchResult with content, hash, size, and headers.

        Raises:
            AirGapError: If air-gap mode blocks the request.
            SourceDisabledError: If the matching source is disabled.
            UpstreamConnectionError: On connection failures.
            UpstreamTimeoutError: On timeout.
            UpstreamHTTPError: On HTTP error responses.
            UpstreamSSLError: On SSL/TLS errors.
            FileSizeExceededError: If the reported or streamed size exceeds
                max_file_size.
            UpstreamError: If the downloaded hash does not match expected_hash.
        """
        start_time = time.time()

        # Match URL to source
        source = self._match_source(url)

        # Enforce air-gap mode: only configured, non-public sources are allowed
        if not self._get_allow_public_internet():
            if source is None:
                raise AirGapError(
                    f"Air-gap mode enabled: URL does not match any configured upstream source: {url}"
                )
            if source.is_public:
                raise AirGapError(
                    f"Air-gap mode enabled: Cannot fetch from public source '{source.name}'"
                )

        # Check that the matched source (if any) is enabled
        if source is not None and not source.enabled:
            raise SourceDisabledError(
                f"Upstream source '{source.name}' is disabled"
            )

        source_name = source.name if source else None
        logger.info(f"Fetching URL: {url} (source: {source_name or 'none'})")

        # Build request parameters
        headers = {"User-Agent": self.config.user_agent}
        auth = None

        if source:
            headers.update(self._build_auth_headers(source))
            auth = self._get_basic_auth(source)

        timeout = httpx.Timeout(
            connect=self.config.connect_timeout,
            read=self.config.read_timeout,
            write=30.0,
            pool=10.0,
        )

        # Attempt fetch with retries
        last_error: Optional[Exception] = None
        for attempt in range(self.config.max_retries):
            try:
                return self._do_fetch(
                    url=url,
                    headers=headers,
                    auth=auth,
                    timeout=timeout,
                    source_name=source_name,
                    start_time=start_time,
                    expected_hash=expected_hash,
                )
            except (
                httpx.ConnectError,
                httpx.ConnectTimeout,
                httpx.ReadTimeout,
                httpx.HTTPStatusError,
            ) as e:
                last_error = e
                if self._should_retry(e, attempt):
                    delay = self._calculate_backoff(attempt)
                    logger.warning(
                        f"Fetch failed (attempt {attempt + 1}/{self.config.max_retries}), "
                        f"retrying in {delay:.1f}s: {e}"
                    )
                    time.sleep(delay)
                else:
                    break

        # Convert the final error to one of our exception types (always raises)
        self._raise_upstream_error(last_error, url)

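    # Usage sketch (illustrative; the caller owns the returned file handle and
    # temp file, so both should be cleaned up when done):
    #
    #   result = client.fetch(
    #       "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
    #       expected_hash="0b1f2c...",  # hypothetical digest, truncated
    #   )
    #   try:
    #       store_artifact(result.content)  # hypothetical consumer
    #   finally:
    #       result.content.close()
    #       result.temp_path.unlink(missing_ok=True)
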
    def _do_fetch(
        self,
        url: str,
        headers: dict[str, str],
        auth: Optional[tuple[str, str]],
        timeout: httpx.Timeout,
        source_name: Optional[str],
        start_time: float,
        expected_hash: Optional[str] = None,
    ) -> FetchResult:
        """
        Perform the actual fetch operation.

        Args:
            url: URL to fetch.
            headers: Request headers.
            auth: Basic auth credentials, or None.
            timeout: Request timeout configuration.
            source_name: Name of the matched source, for logging.
            start_time: Request start time, for timing.
            expected_hash: Optional expected hash for verification.

        Returns:
            FetchResult with content and metadata.
        """
        with httpx.Client(
            timeout=timeout,
            follow_redirects=self.config.follow_redirects,
            max_redirects=self.config.max_redirects,
            verify=self.config.verify_ssl,
        ) as client:
            with client.stream("GET", url, headers=headers, auth=auth) as response:
                # Check for HTTP errors
                response.raise_for_status()

                # Reject early if the declared Content-Length exceeds the max size
                content_length = response.headers.get("content-length")
                if content_length:
                    declared_size = int(content_length)
                    if self.config.max_file_size and declared_size > self.config.max_file_size:
                        raise FileSizeExceededError(
                            f"File size {declared_size} exceeds maximum {self.config.max_file_size}",
                            declared_size,
                            self.config.max_file_size,
                        )

                # Stream to a temp file while computing the hash
                hasher = hashlib.sha256()
                size = 0

                temp_file = tempfile.NamedTemporaryFile(
                    delete=False, prefix="orchard_upstream_"
                )
                temp_path = Path(temp_file.name)

                try:
                    for chunk in response.iter_bytes(chunk_size=65536):
                        temp_file.write(chunk)
                        hasher.update(chunk)
                        size += len(chunk)

                        # Enforce max_file_size while streaming, since
                        # Content-Length may be absent or wrong
                        if self.config.max_file_size and size > self.config.max_file_size:
                            temp_file.close()
                            temp_path.unlink()
                            raise FileSizeExceededError(
                                f"Downloaded size {size} exceeds maximum {self.config.max_file_size}",
                                size,
                                self.config.max_file_size,
                            )

                    temp_file.close()

                    sha256 = hasher.hexdigest()

                    # Verify hash if one was expected
                    if expected_hash and sha256 != expected_hash.lower():
                        temp_path.unlink()
                        raise UpstreamError(
                            f"Hash mismatch: expected {expected_hash}, got {sha256}"
                        )

                    # Capture response headers and content type
                    response_headers = dict(response.headers)
                    content_type = response.headers.get("content-type")

                    elapsed = time.time() - start_time
                    logger.info(
                        f"Fetched {url}: {size} bytes, sha256={sha256[:12]}..., "
                        f"source={source_name}, time={elapsed:.2f}s"
                    )

                    # Return a file handle positioned at the start
                    content = open(temp_path, "rb")

                    return FetchResult(
                        content=content,
                        sha256=sha256,
                        size=size,
                        content_type=content_type,
                        response_headers=response_headers,
                        source_name=source_name,
                        temp_path=temp_path,
                    )

                except Exception:
                    # Clean up the temp file on any error
                    try:
                        temp_file.close()
                    except Exception:
                        pass
                    if temp_path.exists():
                        temp_path.unlink()
                    raise

    def _raise_upstream_error(self, error: Optional[Exception], url: str):
        """
        Convert an httpx exception to the appropriate UpstreamError.

        Args:
            error: The httpx exception, or None if no error was recorded.
            url: The URL that was being fetched.

        Raises:
            The appropriate UpstreamError subclass. This method always raises.
        """
        if error is None:
            raise UpstreamError(f"Unknown error fetching {url}")

        if isinstance(error, httpx.ConnectError):
            raise UpstreamConnectionError(
                f"Failed to connect to upstream: {error}"
            ) from error

        if isinstance(error, (httpx.ConnectTimeout, httpx.ReadTimeout)):
            raise UpstreamTimeoutError(
                f"Request timed out: {error}"
            ) from error

        if isinstance(error, httpx.HTTPStatusError):
            raise UpstreamHTTPError(
                f"HTTP {error.response.status_code}: {error}",
                error.response.status_code,
                dict(error.response.headers),
            ) from error

        # Fall back to string matching for SSL errors in the error chain
        if "ssl" in str(error).lower() or "certificate" in str(error).lower():
            raise UpstreamSSLError(f"SSL/TLS error: {error}") from error

        raise UpstreamError(f"Error fetching {url}: {error}") from error

    def test_connection(self, source: UpstreamSource) -> tuple[bool, Optional[str], Optional[int]]:
        """
        Test connectivity to an upstream source.

        Performs a HEAD request to the source URL to verify connectivity
        and authentication.

        Args:
            source: The upstream source to test.

        Returns:
            Tuple of (success, error_message, status_code).
        """
        headers = {"User-Agent": self.config.user_agent}
        headers.update(self._build_auth_headers(source))
        auth = self._get_basic_auth(source)

        timeout = httpx.Timeout(
            connect=self.config.connect_timeout,
            read=30.0,
            write=30.0,
            pool=10.0,
        )

        try:
            with httpx.Client(
                timeout=timeout,
                verify=self.config.verify_ssl,
            ) as client:
                response = client.head(
                    source.url,
                    headers=headers,
                    auth=auth,
                    follow_redirects=True,
                )
                # Treat 2xx and 3xx as success, plus 405 (Method Not Allowed),
                # since some servers don't support HEAD
                if response.status_code < 400 or response.status_code == 405:
                    return (True, None, response.status_code)
                else:
                    return (
                        False,
                        f"HTTP {response.status_code}",
                        response.status_code,
                    )
        except httpx.ConnectError as e:
            return (False, f"Connection failed: {e}", None)
        except httpx.ConnectTimeout as e:
            return (False, f"Connection timed out: {e}", None)
        except httpx.ReadTimeout as e:
            return (False, f"Read timed out: {e}", None)
        except Exception as e:
            return (False, f"Error: {e}", None)
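
    # Usage sketch (illustrative):
    #
    #   ok, error, status = client.test_connection(source)
    #   if not ok:
    #       logger.warning(f"Source '{source.name}' unreachable: {error} (status={status})")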