cairn/internal/database/migrate.go

84 lines
1.9 KiB
Go

package database
import (
"context"
"embed"
"fmt"
"io/fs"
"log"
"sort"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
)
// migrationsFS holds every .sql file from the migrations directory, embedded
// into the binary at build time. Migrate reads its entries from here.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS
// Migrate applies every embedded SQL migration under migrations/ that has not
// yet been recorded in the schema_migrations table, creating that table first
// if it does not exist.
//
// Migrations run in filename order. The version recorded for each one is its
// filename without the ".sql" suffix. A migration's SQL and its
// schema_migrations row are written in a single transaction, so a failed
// migration leaves neither its changes nor its version behind. Migrate stops at
// the first failure and returns an error naming the offending version.
//
// NOTE(review): the "already applied?" check runs outside the migration's
// transaction and no lock is taken. Two concurrent callers may therefore both
// attempt the same migration. The version PRIMARY KEY makes the loser fail
// rather than double-record, but confirm Migrate is only run from one process
// at a time.
func Migrate(ctx context.Context, pool *pgxpool.Pool) error {
	_, err := pool.Exec(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
`)
	if err != nil {
		return fmt.Errorf("creating migrations table: %w", err)
	}

	entries, err := fs.ReadDir(migrationsFS, "migrations")
	if err != nil {
		return fmt.Errorf("reading migrations directory: %w", err)
	}
	// fs.ReadDir already returns entries sorted by filename. Sorting again
	// keeps the ordering guarantee explicit where correctness depends on it.
	sort.Slice(entries, func(i, j int) bool {
		return entries[i].Name() < entries[j].Name()
	})

	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") {
			continue
		}
		version := strings.TrimSuffix(entry.Name(), ".sql")

		var exists bool
		if err := pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = $1)", version).Scan(&exists); err != nil {
			return fmt.Errorf("checking migration %s: %w", version, err)
		}
		if exists {
			continue
		}

		script, err := migrationsFS.ReadFile("migrations/" + entry.Name())
		if err != nil {
			return fmt.Errorf("reading migration %s: %w", version, err)
		}
		if err := applyMigration(ctx, pool, version, script); err != nil {
			return err
		}
		log.Printf("Applied migration: %s", version)
	}
	return nil
}

// applyMigration executes script and records version in schema_migrations
// within a single transaction. Any failure before Commit rolls the whole
// transaction back. Returned errors name the version and the failed step.
func applyMigration(ctx context.Context, pool *pgxpool.Pool, version string, script []byte) error {
	tx, err := pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("beginning transaction for %s: %w", version, err)
	}
	// Rollback is safe to call after a successful Commit (it is then a no-op).
	// Deferring it therefore closes the transaction on every early return and
	// on panic. The helper exists so this defer fires once per migration
	// rather than at the end of Migrate's loop.
	defer tx.Rollback(ctx)

	if _, err := tx.Exec(ctx, string(script)); err != nil {
		return fmt.Errorf("executing migration %s: %w", version, err)
	}
	if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations (version) VALUES ($1)", version); err != nil {
		return fmt.Errorf("recording migration %s: %w", version, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("committing migration %s: %w", version, err)
	}
	return nil
}