Files
orchard/internal/storage/database.go
Mondo Diaz f255ae1d58 Add auto-migration on database startup
- Embed migrations/001_initial.sql into the binary
- Run migrations automatically when connecting to database
- Uses CREATE TABLE IF NOT EXISTS for idempotent execution

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-05 16:56:13 -06:00

364 lines
10 KiB
Go

package storage
import (
"context"
"database/sql"
_ "embed"
"fmt"
"time"
"gitlab.global.bsf.tools/esv/bsf/bsf-integration/orchard/orchard-mvp/internal/config"
"gitlab.global.bsf.tools/esv/bsf/bsf-integration/orchard/orchard-mvp/internal/models"
"github.com/google/uuid"
_ "github.com/lib/pq"
)
// migrationSQL holds the text of migrations/001_initial.sql, embedded into the
// binary at build time by the go:embed directive below. runMigrations executes
// it each time NewDatabase opens a connection.
//
//go:embed migrations/001_initial.sql
var migrationSQL string
// Database is the storage layer's PostgreSQL-backed store for groves, trees,
// fruits, grafts, harvests, audit logs, access permissions and consumers.
// Obtain one from NewDatabase; the zero value has no connection pool.
type Database struct {
	db *sql.DB // connection pool opened by NewDatabase
}
// NewDatabase opens a PostgreSQL connection pool described by cfg, verifies
// connectivity with a ping, and applies the embedded schema migration.
//
// Every string setting is single-quoted, with backslashes and quotes escaped,
// before it is placed in the key=value connection string. Values that are
// empty or contain spaces or quotes (typically passwords) therefore reach the
// driver intact instead of swallowing or corrupting neighbouring settings.
//
// If the ping or the migration fails, the pool is closed before the error is
// returned, so no connections leak. Callers only need to call Close on a
// successfully returned *Database.
func NewDatabase(cfg *config.DatabaseConfig) (*Database, error) {
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quoteDSNValue(cfg.Host), cfg.Port, quoteDSNValue(cfg.User),
		quoteDSNValue(cfg.Password), quoteDSNValue(cfg.DBName), quoteDSNValue(cfg.SSLMode))

	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	if err := db.Ping(); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	d := &Database{db: db}
	if err := d.runMigrations(); err != nil {
		// Best-effort cleanup; the migration error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("failed to run migrations: %w", err)
	}
	return d, nil
}

// quoteDSNValue wraps v in single quotes for a key=value connection string,
// backslash-escaping any backslashes and single quotes it contains.
// Byte-wise scanning is UTF-8 safe: '\\' and '\'' never occur inside a
// multi-byte sequence.
func quoteDSNValue(v string) string {
	out := make([]byte, 0, len(v)+2)
	out = append(out, '\'')
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\\' || c == '\'' {
			out = append(out, '\\')
		}
		out = append(out, v[i])
	}
	out = append(out, '\'')
	return string(out)
}
// runMigrations applies the embedded schema script (migrationSQL) with a
// single Exec on the pool and returns any error unwrapped; NewDatabase adds
// context to it. The commit that introduced this says the script uses
// CREATE TABLE IF NOT EXISTS so it can run on every startup. Confirm against
// migrations/001_initial.sql before adding non-idempotent statements.
func (d *Database) runMigrations() error {
	if _, err := d.db.Exec(migrationSQL); err != nil {
		return err
	}
	return nil
}
// Close shuts down the underlying *sql.DB connection pool. Per database/sql,
// this prevents new queries from starting. It returns whatever error the pool
// reports while closing.
func (d *Database) Close() error {
	pool := d.db
	return pool.Close()
}
// Grove operations

// CreateGrove inserts grove into the groves table. Before the insert it
// overwrites grove.ID with a fresh UUID and sets CreatedAt and UpdatedAt to
// the same current time. Those assignments remain even if the insert fails.
func (d *Database) CreateGrove(ctx context.Context, grove *models.Grove) error {
	grove.ID = uuid.New().String()
	now := time.Now()
	grove.CreatedAt, grove.UpdatedAt = now, now

	const insertGrove = `
INSERT INTO groves (id, name, description, is_public, created_at, updated_at, created_by)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
	if _, err := d.db.ExecContext(ctx, insertGrove,
		grove.ID, grove.Name, grove.Description, grove.IsPublic,
		grove.CreatedAt, grove.UpdatedAt, grove.CreatedBy); err != nil {
		return err
	}
	return nil
}
// GetGrove fetches the grove with the given name.
//
// It returns (nil, nil) when no such grove exists and (nil, err) on any other
// failure. A non-nil *models.Grove is returned only together with a nil error.
func (d *Database) GetGrove(ctx context.Context, name string) (*models.Grove, error) {
	var grove models.Grove
	err := d.db.QueryRowContext(ctx, `
SELECT id, name, description, is_public, created_at, updated_at, created_by
FROM groves WHERE name = $1
`, name).Scan(&grove.ID, &grove.Name, &grove.Description, &grove.IsPublic,
		&grove.CreatedAt, &grove.UpdatedAt, &grove.CreatedBy)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		// Never hand back a partially scanned struct alongside an error.
		return nil, err
	}
	return &grove, nil
}
// ListGroves returns the groves visible to userID, ordered by name. That is
// every public grove, plus every private grove on which userID holds an
// access permission that has not expired. The result is nil (not an error)
// when nothing is visible.
//
// The expiry test lives in the JOIN condition rather than the WHERE clause.
// A lapsed grant therefore just behaves like "no grant", and public groves
// are still listed. This matches the expiry rule CheckAccess enforces.
// rows.Err is checked after iteration, so a failure part way through the
// result set is reported rather than returning a truncated list.
func (d *Database) ListGroves(ctx context.Context, userID string) ([]*models.Grove, error) {
	rows, err := d.db.QueryContext(ctx, `
SELECT g.id, g.name, g.description, g.is_public, g.created_at, g.updated_at, g.created_by
FROM groves g
LEFT JOIN access_permissions ap ON g.id = ap.grove_id AND ap.user_id = $1
AND (ap.expires_at IS NULL OR ap.expires_at > NOW())
WHERE g.is_public = true OR ap.user_id IS NOT NULL
ORDER BY g.name
`, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var groves []*models.Grove
	for rows.Next() {
		var grove models.Grove
		if err := rows.Scan(&grove.ID, &grove.Name, &grove.Description, &grove.IsPublic,
			&grove.CreatedAt, &grove.UpdatedAt, &grove.CreatedBy); err != nil {
			return nil, err
		}
		groves = append(groves, &grove)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return groves, nil
}
// Tree operations

// CreateTree inserts tree into the trees table under tree.GroveID. It first
// stamps tree with a fresh UUID and identical CreatedAt/UpdatedAt times.
// These remain set on tree even if the insert fails.
func (d *Database) CreateTree(ctx context.Context, tree *models.Tree) error {
	tree.ID = uuid.New().String()
	stamp := time.Now()
	tree.CreatedAt = stamp
	tree.UpdatedAt = stamp

	const insertTree = `
INSERT INTO trees (id, grove_id, name, description, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6)
`
	if _, err := d.db.ExecContext(ctx, insertTree,
		tree.ID, tree.GroveID, tree.Name, tree.Description,
		tree.CreatedAt, tree.UpdatedAt); err != nil {
		return err
	}
	return nil
}
// GetTree fetches the tree called name inside the grove identified by groveID.
//
// It returns (nil, nil) when no such tree exists and (nil, err) on any other
// failure. A non-nil *models.Tree is returned only together with a nil error.
func (d *Database) GetTree(ctx context.Context, groveID, name string) (*models.Tree, error) {
	var tree models.Tree
	err := d.db.QueryRowContext(ctx, `
SELECT id, grove_id, name, description, created_at, updated_at
FROM trees WHERE grove_id = $1 AND name = $2
`, groveID, name).Scan(&tree.ID, &tree.GroveID, &tree.Name, &tree.Description,
		&tree.CreatedAt, &tree.UpdatedAt)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		// Never hand back a partially scanned struct alongside an error.
		return nil, err
	}
	return &tree, nil
}
// ListTrees returns every tree in the grove identified by groveID, ordered by
// name. The result is nil (not an error) when the grove has no trees.
//
// rows.Err is checked after iteration, so a failure part way through the
// result set is reported rather than returning a truncated list.
func (d *Database) ListTrees(ctx context.Context, groveID string) ([]*models.Tree, error) {
	rows, err := d.db.QueryContext(ctx, `
SELECT id, grove_id, name, description, created_at, updated_at
FROM trees WHERE grove_id = $1
ORDER BY name
`, groveID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var trees []*models.Tree
	for rows.Next() {
		var tree models.Tree
		if err := rows.Scan(&tree.ID, &tree.GroveID, &tree.Name, &tree.Description,
			&tree.CreatedAt, &tree.UpdatedAt); err != nil {
			return nil, err
		}
		trees = append(trees, &tree)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return trees, nil
}
// Fruit operations

// CreateFruit inserts fruit into the fruits table. If a fruit with the same
// ID already exists, it increments that row's ref_count instead.
//
// fruit.CreatedAt and fruit.RefCount are set (to now and 1) before the upsert.
// On success, RETURNING overwrites them with the values actually stored. After
// a duplicate, fruit therefore reflects the existing row's incremented
// ref_count and original created_at, rather than values that were never
// written. Other fields of fruit are not refreshed from an existing row.
func (d *Database) CreateFruit(ctx context.Context, fruit *models.Fruit) error {
	fruit.CreatedAt = time.Now()
	fruit.RefCount = 1
	return d.db.QueryRowContext(ctx, `
INSERT INTO fruits (id, size, content_type, original_name, created_at, created_by, ref_count, s3_key)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (id) DO UPDATE SET ref_count = fruits.ref_count + 1
RETURNING ref_count, created_at
`, fruit.ID, fruit.Size, fruit.ContentType, fruit.OriginalName,
		fruit.CreatedAt, fruit.CreatedBy, fruit.RefCount, fruit.S3Key).Scan(&fruit.RefCount, &fruit.CreatedAt)
}
// GetFruit fetches the fruit with the given id.
//
// It returns (nil, nil) when no such fruit exists and (nil, err) on any other
// failure. A non-nil *models.Fruit is returned only together with a nil error.
func (d *Database) GetFruit(ctx context.Context, id string) (*models.Fruit, error) {
	var fruit models.Fruit
	err := d.db.QueryRowContext(ctx, `
SELECT id, size, content_type, original_name, created_at, created_by, ref_count, s3_key
FROM fruits WHERE id = $1
`, id).Scan(&fruit.ID, &fruit.Size, &fruit.ContentType, &fruit.OriginalName,
		&fruit.CreatedAt, &fruit.CreatedBy, &fruit.RefCount, &fruit.S3Key)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		// Never hand back a partially scanned struct alongside an error.
		return nil, err
	}
	return &fruit, nil
}
// Graft operations

// CreateGraft points the graft named graft.Name on tree graft.TreeID at
// graft.FruitID. It inserts a new row, or overwrites fruit_id, created_at and
// created_by if that (tree_id, name) pair already exists.
//
// A fresh UUID is generated for graft.ID, but on conflict the existing row
// keeps its original id. RETURNING id writes the stored id back into
// graft.ID, so callers never hold an ID that matches no row.
func (d *Database) CreateGraft(ctx context.Context, graft *models.Graft) error {
	graft.ID = uuid.New().String()
	graft.CreatedAt = time.Now()
	return d.db.QueryRowContext(ctx, `
INSERT INTO grafts (id, tree_id, name, fruit_id, created_at, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (tree_id, name) DO UPDATE SET fruit_id = $4, created_at = $5, created_by = $6
RETURNING id
`, graft.ID, graft.TreeID, graft.Name, graft.FruitID, graft.CreatedAt, graft.CreatedBy).Scan(&graft.ID)
}
// GetGraft fetches the graft called name on the tree identified by treeID.
//
// It returns (nil, nil) when no such graft exists and (nil, err) on any other
// failure. A non-nil *models.Graft is returned only together with a nil error.
func (d *Database) GetGraft(ctx context.Context, treeID, name string) (*models.Graft, error) {
	var graft models.Graft
	err := d.db.QueryRowContext(ctx, `
SELECT id, tree_id, name, fruit_id, created_at, created_by
FROM grafts WHERE tree_id = $1 AND name = $2
`, treeID, name).Scan(&graft.ID, &graft.TreeID, &graft.Name, &graft.FruitID,
		&graft.CreatedAt, &graft.CreatedBy)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		// Never hand back a partially scanned struct alongside an error.
		return nil, err
	}
	return &graft, nil
}
// ListGrafts returns every graft on the tree identified by treeID, ordered by
// name. The result is nil (not an error) when the tree has no grafts.
//
// rows.Err is checked after iteration, so a failure part way through the
// result set is reported rather than returning a truncated list.
func (d *Database) ListGrafts(ctx context.Context, treeID string) ([]*models.Graft, error) {
	rows, err := d.db.QueryContext(ctx, `
SELECT id, tree_id, name, fruit_id, created_at, created_by
FROM grafts WHERE tree_id = $1
ORDER BY name
`, treeID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var grafts []*models.Graft
	for rows.Next() {
		var graft models.Graft
		if err := rows.Scan(&graft.ID, &graft.TreeID, &graft.Name, &graft.FruitID,
			&graft.CreatedAt, &graft.CreatedBy); err != nil {
			return nil, err
		}
		grafts = append(grafts, &graft)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return grafts, nil
}
// Harvest operations

// CreateHarvest records a harvest of harvest.FruitID from tree harvest.TreeID
// in the harvests table. It first assigns harvest a fresh UUID and the current
// time as HarvestedAt.
func (d *Database) CreateHarvest(ctx context.Context, harvest *models.Harvest) error {
	harvest.ID = uuid.New().String()
	harvest.HarvestedAt = time.Now()

	const insertHarvest = `
INSERT INTO harvests (id, fruit_id, tree_id, original_name, harvested_at, harvested_by, source_ip)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
	_, err := d.db.ExecContext(ctx, insertHarvest,
		harvest.ID,
		harvest.FruitID,
		harvest.TreeID,
		harvest.OriginalName,
		harvest.HarvestedAt,
		harvest.HarvestedBy,
		harvest.SourceIP,
	)
	return err
}
// Audit operations

// CreateAuditLog appends an entry to the audit_logs table. It first assigns
// log a fresh UUID and the current time as its Timestamp.
func (d *Database) CreateAuditLog(ctx context.Context, log *models.AuditLog) error {
	log.ID = uuid.New().String()
	log.Timestamp = time.Now()

	const insertAudit = `
INSERT INTO audit_logs (id, action, resource, user_id, details, timestamp, source_ip)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
	if _, err := d.db.ExecContext(ctx, insertAudit,
		log.ID, log.Action, log.Resource, log.UserID,
		log.Details, log.Timestamp, log.SourceIP,
	); err != nil {
		return err
	}
	return nil
}
// Access control operations

// CheckAccess reports whether userID may act on the grove groveID at
// requiredLevel ("read", "write" or "admin").
//
// A public grove satisfies "read" for everyone. Otherwise the user's unexpired
// row in access_permissions decides, using the hierarchy admin > write > read.
// Any unexpired grant satisfies "read". Unrecognised requiredLevel values are
// denied.
//
// If groveID does not exist, the sql.ErrNoRows from the grove lookup is
// returned as the error. A missing or expired grant yields (false, nil).
func (d *Database) CheckAccess(ctx context.Context, groveID, userID, requiredLevel string) (bool, error) {
	var public bool
	if err := d.db.QueryRowContext(ctx, `SELECT is_public FROM groves WHERE id = $1`, groveID).Scan(&public); err != nil {
		return false, err
	}
	// Public groves grant read access without needing a permission row.
	if requiredLevel == "read" && public {
		return true, nil
	}

	var granted string
	err := d.db.QueryRowContext(ctx, `
SELECT level FROM access_permissions
WHERE grove_id = $1 AND user_id = $2 AND (expires_at IS NULL OR expires_at > NOW())
`, groveID, userID).Scan(&granted)
	switch {
	case err == sql.ErrNoRows:
		return false, nil
	case err != nil:
		return false, err
	}

	isAdmin := granted == "admin"
	switch requiredLevel {
	case "admin":
		return isAdmin, nil
	case "write":
		return isAdmin || granted == "write", nil
	case "read":
		return true, nil
	default:
		return false, nil
	}
}
// GrantAccess gives perm.UserID the access level perm.Level on grove
// perm.GroveID, expiring at perm.ExpiresAt. If the user already has a
// permission row for that grove, its level and expires_at are overwritten.
//
// A fresh UUID and CreatedAt are assigned to perm before the upsert, but on
// conflict the existing row keeps its original id and created_at. RETURNING
// writes the stored values back into perm, so after an update perm describes
// the real row rather than values that were never saved.
func (d *Database) GrantAccess(ctx context.Context, perm *models.AccessPermission) error {
	perm.ID = uuid.New().String()
	perm.CreatedAt = time.Now()
	return d.db.QueryRowContext(ctx, `
INSERT INTO access_permissions (id, grove_id, user_id, level, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (grove_id, user_id) DO UPDATE SET level = $4, expires_at = $6
RETURNING id, created_at
`, perm.ID, perm.GroveID, perm.UserID, perm.Level, perm.CreatedAt, perm.ExpiresAt).Scan(&perm.ID, &perm.CreatedAt)
}
// Consumer tracking

// TrackConsumer records that projectURL accessed the tree treeID. The first
// access inserts a consumers row (new UUID, last_access and created_at set to
// the database's NOW()). Later accesses only refresh last_access.
func (d *Database) TrackConsumer(ctx context.Context, treeID, projectURL string) error {
	consumerID := uuid.New().String()

	const upsertConsumer = `
INSERT INTO consumers (id, tree_id, project_url, last_access, created_at)
VALUES ($1, $2, $3, NOW(), NOW())
ON CONFLICT (tree_id, project_url) DO UPDATE SET last_access = NOW()
`
	if _, err := d.db.ExecContext(ctx, upsertConsumer, consumerID, treeID, projectURL); err != nil {
		return err
	}
	return nil
}
// GetConsumers returns every consumer recorded for the tree treeID, most
// recently active first. The result is nil (not an error) when none exist.
//
// rows.Err is checked after iteration, so a failure part way through the
// result set is reported rather than returning a truncated list.
func (d *Database) GetConsumers(ctx context.Context, treeID string) ([]*models.Consumer, error) {
	rows, err := d.db.QueryContext(ctx, `
SELECT id, tree_id, project_url, last_access, created_at
FROM consumers WHERE tree_id = $1
ORDER BY last_access DESC
`, treeID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var consumers []*models.Consumer
	for rows.Next() {
		var consumer models.Consumer
		if err := rows.Scan(&consumer.ID, &consumer.TreeID, &consumer.ProjectURL,
			&consumer.LastAccess, &consumer.CreatedAt); err != nil {
			return nil, err
		}
		consumers = append(consumers, &consumer)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return consumers, nil
}