package store

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"strings"
	"time"

	"go.uber.org/zap"
	_ "modernc.org/sqlite"
)

// ErrUsernameTaken is returned by CreateUser when the username violates the
// unique constraint on users.username.
var ErrUsernameTaken = errors.New("username already registered")

// Store is the persistence layer backed by a SQLite database. Obtain one with
// Open and release it with Close.
type Store struct {
	db  *sql.DB
	log *zap.Logger
}

// User is a row of the users table.
type User struct {
	ID           int64
	Username     string
	PasswordHash string
	IsAdmin      bool
	CreatedAt    time.Time
}

// Site is a row of the sites table together with its targets and filter
// patterns. The Last* fields are nullable columns written by
// UpdateSiteRunResult and UpdateSiteScanResult.
type Site struct {
	ID            int64
	SiteUUID      string
	SSHUser       string
	Host          string
	Port          int
	CreatedAt     time.Time
	LastRunStatus sql.NullString
	LastRunOutput sql.NullString
	LastRunAt     sql.NullTime
	LastScanAt    sql.NullTime
	LastScanState sql.NullString
	LastScanNotes sql.NullString
	Targets       []SiteTarget
	Filters       []string
}

// SiteTarget is a row of the site_targets table. Mode selects how the target
// is handled (values written in this file: "mysql_dump", "podman_save",
// "podman_export", plus whatever callers pass to CreateSite/UpdateSite). The
// MySQL* fields are only set for MySQL dump targets; LastSizeByte, LastScanAt
// and LastError are written by UpdateSiteScanResult.
type SiteTarget struct {
	ID            int64
	Path          string
	Mode          string
	MySQLHost     sql.NullString
	MySQLUser     sql.NullString
	MySQLDB       sql.NullString
	MySQLPassword sql.NullString
	LastSizeByte  sql.NullInt64
	LastScanAt    sql.NullTime
	LastError     sql.NullString
}

// Job is a row of the jobs table. Status values written in this file are
// "queued", "running" and "failed"; CompleteJob accepts any status string.
type Job struct {
	ID         int64
	SiteID     int64
	Type       string
	Status     string
	Summary    sql.NullString
	CreatedAt  time.Time
	StartedAt  sql.NullTime
	FinishedAt sql.NullTime
}

// JobEvent is a log line attached to a job, as accepted by AddJobEvent.
// OccurredAt is not written by AddJobEvent; the database supplies it.
type JobEvent struct {
	JobID      int64
	Level      string
	Message    string
	OccurredAt time.Time
}

// JobEventRecord is a job event joined with its job's site and type, as
// returned by ListRecentJobEvents.
type JobEventRecord struct {
	JobID      int64
	SiteID     int64
	JobType    string
	Level      string
	Message    string
	OccurredAt time.Time
}

// Open opens the SQLite database at path with foreign keys enabled, WAL
// journaling and a 5s busy timeout, verifies the connection and applies
// schema migrations. On any failure the database handle is closed before the
// error is returned, so callers never leak it.
func Open(path string) (*Store, error) {
	dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)", path)
	db, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, fmt.Errorf("open database: %w", err)
	}
	db.SetMaxOpenConns(8)
	db.SetMaxIdleConns(8)

	if err := db.Ping(); err != nil {
		_ = db.Close() // best effort; the ping error is the one worth reporting
		return nil, fmt.Errorf("ping database: %w", err)
	}

	s := &Store{db: db}
	if err := s.migrate(context.Background()); err != nil {
		_ = db.Close() // best effort; the migration error is the one worth reporting
		return nil, fmt.Errorf("migrate database: %w", err)
	}
	return s, nil
}

// Close closes the underlying database handle.
func (s *Store) Close() error {
	return s.db.Close()
}

// SetLogger installs the logger used for debug-level tracing of writes; a nil
// logger disables tracing. The field is not guarded by a lock, so call this
// before the Store is shared between goroutines.
func (s *Store) SetLogger(logger *zap.Logger) {
	s.log = logger
}

// migrate brings the schema up to date via runMigrations (defined elsewhere
// in this package).
func (s *Store) migrate(ctx context.Context) error {
	return s.runMigrations(ctx)
}

// CreateUser inserts a user with an already-hashed password. The user is made
// an admin only when the users table is empty at insert time (the first
// account). The count and the insert share one transaction. A duplicate
// username yields ErrUsernameTaken.
func (s *Store) CreateUser(ctx context.Context, username, passwordHash string) (User, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return User{}, err
	}
	defer tx.Rollback() // harmless after a successful Commit

	var count int
	if err := tx.QueryRowContext(ctx, `SELECT COUNT(1) FROM users`).Scan(&count); err != nil {
		return User{}, err
	}
	isAdmin := count == 0

	res, err := tx.ExecContext(
		ctx,
		`INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`,
		username, passwordHash, boolToInt(isAdmin),
	)
	if err != nil {
		if isUniqueUsernameErr(err) {
			return User{}, ErrUsernameTaken
		}
		return User{}, err
	}
	userID, err := res.LastInsertId()
	if err != nil {
		return User{}, err
	}
	user, err := userByIDTx(ctx, tx, userID)
	if err != nil {
		return User{}, err
	}
	if err := tx.Commit(); err != nil {
		return User{}, err
	}
	s.debugDB("user created", zap.Int64("user_id", user.ID), zap.String("username", user.Username), zap.Bool("is_admin", user.IsAdmin))
	return user, nil
}

// UserByUsername returns the user with the given username, or sql.ErrNoRows.
func (s *Store) UserByUsername(ctx context.Context, username string) (User, error) {
	const q = `
SELECT id, username, password_hash, is_admin, created_at
FROM users
WHERE username = ?`
	return scanUser(s.db.QueryRowContext(ctx, q, username))
}

// UserBySessionTokenHash returns the owner of the unexpired session with the
// given token hash, or sql.ErrNoRows if no such session exists.
func (s *Store) UserBySessionTokenHash(ctx context.Context, tokenHash string) (User, error) {
	// expires_at is stored as RFC 3339 text ("2006-01-02T15:04:05Z", see
	// CreateSession) while CURRENT_TIMESTAMP is "2006-01-02 15:04:05".
	// Comparing those strings directly sorts 'T' after ' ', which kept a
	// session valid until the end of its expiry day. datetime() normalizes
	// expires_at to CURRENT_TIMESTAMP's format so the comparison is exact.
	const q = `
SELECT u.id, u.username, u.password_hash, u.is_admin, u.created_at
FROM sessions s
JOIN users u ON u.id = s.user_id
WHERE s.token_hash = ?
  AND datetime(s.expires_at) > CURRENT_TIMESTAMP`
	return scanUser(s.db.QueryRowContext(ctx, q, tokenHash))
}

// CreateSession stores a session for userID. expiresAt is persisted in UTC
// as RFC 3339 text with second precision.
func (s *Store) CreateSession(ctx context.Context, userID int64, tokenHash string, expiresAt time.Time) error {
	_, err := s.db.ExecContext(
		ctx,
		`INSERT INTO sessions (user_id, token_hash, expires_at) VALUES (?, ?, ?)`,
		userID, tokenHash, expiresAt.UTC().Format(time.RFC3339),
	)
	if err == nil {
		s.debugDB("session created", zap.Int64("user_id", userID), zap.Time("expires_at", expiresAt.UTC()))
	}
	return err
}

// DeleteSessionByTokenHash removes the session with the given token hash.
// Deleting a non-existent session is not an error.
func (s *Store) DeleteSessionByTokenHash(ctx context.Context, tokenHash string) error {
	_, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE token_hash = ?`, tokenHash)
	if err == nil {
		s.debugDB("session deleted")
	}
	return err
}

// TouchSessionByTokenHash moves the expiry of a still-valid session to
// expiresAt. It returns sql.ErrNoRows if the session does not exist or has
// already expired.
func (s *Store) TouchSessionByTokenHash(ctx context.Context, tokenHash string, expiresAt time.Time) error {
	// datetime() makes the RFC 3339 value comparable with CURRENT_TIMESTAMP;
	// see UserBySessionTokenHash.
	res, err := s.db.ExecContext(
		ctx,
		`UPDATE sessions SET expires_at = ? WHERE token_hash = ? AND datetime(expires_at) > CURRENT_TIMESTAMP`,
		expiresAt.UTC().Format(time.RFC3339), tokenHash,
	)
	if err != nil {
		return err
	}
	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows == 0 {
		return sql.ErrNoRows
	}
	s.debugDB("session touched", zap.Time("expires_at", expiresAt.UTC()))
	return nil
}

// UpdateUserPasswordHash replaces the stored password hash of userID. It does
// not report whether the user exists.
func (s *Store) UpdateUserPasswordHash(ctx context.Context, userID int64, passwordHash string) error {
	_, err := s.db.ExecContext(ctx, `UPDATE users SET password_hash = ? WHERE id = ?`, passwordHash, userID)
	if err == nil {
		s.debugDB("user password updated", zap.Int64("user_id", userID))
	}
	return err
}

// CreateSite inserts a site with a freshly generated site UUID plus its
// targets and filter patterns in one transaction, then returns the site as
// re-read by SiteByID.
func (s *Store) CreateSite(ctx context.Context, sshUser, host string, port int, targets []SiteTarget, filters []string) (Site, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Site{}, err
	}
	defer tx.Rollback() // harmless after a successful Commit

	siteUUID, err := newSiteUUID()
	if err != nil {
		return Site{}, err
	}
	res, err := tx.ExecContext(ctx,
		`INSERT INTO sites (site_uuid, ssh_user, host, port) VALUES (?, ?, ?, ?)`,
		siteUUID, sshUser, host, port)
	if err != nil {
		return Site{}, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return Site{}, err
	}
	if err := insertSiteTargetsTx(ctx, tx, id, targets); err != nil {
		return Site{}, err
	}
	if err := insertSiteFiltersTx(ctx, tx, id, filters); err != nil {
		return Site{}, err
	}
	if err := tx.Commit(); err != nil {
		return Site{}, err
	}
	s.debugDB("site created", zap.Int64("site_id", id), zap.String("site_uuid", siteUUID), zap.String("ssh_user", sshUser), zap.String("host", host), zap.Int("port", port), zap.Int("targets", len(targets)), zap.Int("filters", len(filters)))
	return s.SiteByID(ctx, id)
}

// UpdateSite changes a site's connection settings and replaces its targets
// and filters wholesale in one transaction. It returns sql.ErrNoRows if the
// site does not exist. Targets are deleted and re-inserted, so scan results
// previously recorded on them are not carried over.
func (s *Store) UpdateSite(ctx context.Context, id int64, sshUser, host string, port int, targets []SiteTarget, filters []string) (Site, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Site{}, err
	}
	defer tx.Rollback() // harmless after a successful Commit

	res, err := tx.ExecContext(ctx,
		`UPDATE sites SET ssh_user = ?, host = ?, port = ? WHERE id = ?`,
		sshUser, host, port, id)
	if err != nil {
		return Site{}, err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return Site{}, err
	}
	if affected == 0 {
		return Site{}, sql.ErrNoRows
	}

	if _, err := tx.ExecContext(ctx, `DELETE FROM site_targets WHERE site_id = ?`, id); err != nil {
		return Site{}, err
	}
	if err := insertSiteTargetsTx(ctx, tx, id, targets); err != nil {
		return Site{}, err
	}
	if _, err := tx.ExecContext(ctx, `DELETE FROM site_filters WHERE site_id = ?`, id); err != nil {
		return Site{}, err
	}
	if err := insertSiteFiltersTx(ctx, tx, id, filters); err != nil {
		return Site{}, err
	}
	if err := tx.Commit(); err != nil {
		return Site{}, err
	}
	s.debugDB("site updated", zap.Int64("site_id", id), zap.String("ssh_user", sshUser), zap.String("host", host), zap.Int("port", port), zap.Int("targets", len(targets)), zap.Int("filters", len(filters)))
	return s.SiteByID(ctx, id)
}

// insertSiteTargetsTx inserts the definition columns of each target for
// siteID within tx; scan-result columns are not written here.
func insertSiteTargetsTx(ctx context.Context, tx *sql.Tx, siteID int64, targets []SiteTarget) error {
	for _, t := range targets {
		if _, err := tx.ExecContext(
			ctx,
			`INSERT INTO site_targets (site_id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password) VALUES (?, ?, ?, ?, ?, ?, ?)`,
			siteID, t.Path, t.Mode,
			nullStringArg(t.MySQLHost), nullStringArg(t.MySQLUser), nullStringArg(t.MySQLDB), nullStringArg(t.MySQLPassword),
		); err != nil {
			return err
		}
	}
	return nil
}

// insertSiteFiltersTx inserts one site_filters row per pattern for siteID
// within tx, preserving slice order.
func insertSiteFiltersTx(ctx context.Context, tx *sql.Tx, siteID int64, filters []string) error {
	for _, f := range filters {
		if _, err := tx.ExecContext(ctx, `INSERT INTO site_filters (site_id, pattern) VALUES (?, ?)`, siteID, f); err != nil {
			return err
		}
	}
	return nil
}

// DeleteSite removes the site with the given id, returning sql.ErrNoRows if
// it does not exist. Whether dependent rows are removed too depends on the
// schema's foreign keys (enabled in Open) — NOTE(review): confirm ON DELETE
// behavior in the migrations.
func (s *Store) DeleteSite(ctx context.Context, id int64) error {
	res, err := s.db.ExecContext(ctx, `DELETE FROM sites WHERE id = ?`, id)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return sql.ErrNoRows
	}
	s.debugDB("site deleted", zap.Int64("site_id", id))
	return nil
}

// DeleteSiteTarget removes one target, scoped to siteID so a target of a
// different site cannot be deleted. Returns sql.ErrNoRows if nothing matched.
func (s *Store) DeleteSiteTarget(ctx context.Context, siteID, targetID int64) error {
	res, err := s.db.ExecContext(ctx, `DELETE FROM site_targets WHERE id = ? AND site_id = ?`, targetID, siteID)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return sql.ErrNoRows
	}
	s.debugDB("site target deleted", zap.Int64("site_id", siteID), zap.Int64("target_id", targetID))
	return nil
}

// AddMySQLDumpTarget appends a "mysql_dump" target to siteID. The database
// name doubles as the target path. The password is stored but never logged.
func (s *Store) AddMySQLDumpTarget(ctx context.Context, siteID int64, dbHost, dbUser, dbName, dbPassword string) error {
	_, err := s.db.ExecContext(
		ctx,
		`INSERT INTO site_targets (site_id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password) VALUES (?, ?, 'mysql_dump', ?, ?, ?, ?)`,
		siteID, dbName, dbHost, dbUser, dbName, dbPassword,
	)
	if err == nil {
		s.debugDB("mysql dump target added", zap.Int64("site_id", siteID), zap.String("db_host", dbHost), zap.String("db_user", dbUser), zap.String("db_name", dbName))
	}
	return err
}

// AddPodmanSaveTarget appends a "podman_save" target to siteID whose path is
// the image name.
func (s *Store) AddPodmanSaveTarget(ctx context.Context, siteID int64, imageName string) error {
	_, err := s.db.ExecContext(
		ctx,
		`INSERT INTO site_targets (site_id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password) VALUES (?, ?, 'podman_save', NULL, NULL, NULL, NULL)`,
		siteID, imageName,
	)
	if err == nil {
		s.debugDB("podman save target added", zap.Int64("site_id", siteID), zap.String("image_name", imageName))
	}
	return err
}

// AddPodmanExportTarget appends a "podman_export" target to siteID whose path
// is the container name.
func (s *Store) AddPodmanExportTarget(ctx context.Context, siteID int64, containerName string) error {
	_, err := s.db.ExecContext(
		ctx,
		`INSERT INTO site_targets (site_id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password) VALUES (?, ?, 'podman_export', NULL, NULL, NULL, NULL)`,
		siteID, containerName,
	)
	if err == nil {
		s.debugDB("podman export target added", zap.Int64("site_id", siteID), zap.String("container_name", containerName))
	}
	return err
}

// ListSites returns every site, newest first, with targets and filters
// populated. Targets and filters are each fetched with a single query rather
// than one per site.
func (s *Store) ListSites(ctx context.Context) ([]Site, error) {
	const q = `
SELECT id, site_uuid, ssh_user, host, port, created_at,
       last_run_status, last_run_output, last_run_at,
       last_scan_at, last_scan_state, last_scan_notes
FROM sites
ORDER BY id DESC`
	rows, err := s.db.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []Site
	for rows.Next() {
		site, err := scanSite(rows)
		if err != nil {
			return nil, err
		}
		out = append(out, site)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if err := s.populateTargets(ctx, out); err != nil {
		return nil, err
	}
	if err := s.populateFilters(ctx, out); err != nil {
		return nil, err
	}
	return out, nil
}

// SiteByID returns one site with its targets and filters, or sql.ErrNoRows.
func (s *Store) SiteByID(ctx context.Context, id int64) (Site, error) {
	const q = `
SELECT id, site_uuid, ssh_user, host, port, created_at,
       last_run_status, last_run_output, last_run_at,
       last_scan_at, last_scan_state, last_scan_notes
FROM sites
WHERE id = ?`
	site, err := scanSite(s.db.QueryRowContext(ctx, q, id))
	if err != nil {
		return Site{}, err
	}
	targets, err := s.targetsBySiteID(ctx, id)
	if err != nil {
		return Site{}, err
	}
	site.Targets = targets
	filters, err := s.filtersBySiteID(ctx, id)
	if err != nil {
		return Site{}, err
	}
	site.Filters = filters
	return site, nil
}

// UpdateSiteRunResult records the outcome of the latest run on a site.
func (s *Store) UpdateSiteRunResult(ctx context.Context, id int64, status, output string, at time.Time) error {
	_, err := s.db.ExecContext(
		ctx,
		`UPDATE sites SET last_run_status = ?, last_run_output = ?, last_run_at = ? WHERE id = ?`,
		status, output, at.UTC().Format(time.RFC3339), id,
	)
	if err == nil {
		s.debugDB("site run updated", zap.Int64("site_id", id), zap.String("status", status), zap.Time("at", at.UTC()))
	}
	return err
}

// CreateJob enqueues a job of jobType for siteID in the "queued" state and
// returns the stored row.
func (s *Store) CreateJob(ctx context.Context, siteID int64, jobType string) (Job, error) {
	res, err := s.db.ExecContext(
		ctx,
		`INSERT INTO jobs (site_id, type, status) VALUES (?, ?, 'queued')`,
		siteID, jobType,
	)
	if err != nil {
		return Job{}, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return Job{}, err
	}
	s.debugDB("job created", zap.Int64("job_id", id), zap.Int64("site_id", siteID), zap.String("job_type", jobType))
	return s.JobByID(ctx, id)
}

// JobByID returns the job with the given id, or sql.ErrNoRows.
func (s *Store) JobByID(ctx context.Context, id int64) (Job, error) {
	const q = `
SELECT id, site_id, type, status, summary, created_at, started_at, finished_at
FROM jobs
WHERE id = ?`
	return scanJob(s.db.QueryRowContext(ctx, q, id))
}

// TryStartNextQueuedJob claims the queued job with the lowest id by moving it
// to "running" and stamping started_at. It returns (job, true, nil) on
// success and (Job{}, false, nil) when nothing is queued or another worker
// claimed the job between the SELECT and the UPDATE.
func (s *Store) TryStartNextQueuedJob(ctx context.Context) (Job, bool, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Job{}, false, err
	}
	defer tx.Rollback() // harmless after a successful Commit

	var id int64
	if err := tx.QueryRowContext(ctx, `SELECT id FROM jobs WHERE status = 'queued' ORDER BY id ASC LIMIT 1`).Scan(&id); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return Job{}, false, nil
		}
		return Job{}, false, err
	}

	// The status guard makes the claim conditional: if the job is no longer
	// queued, zero rows change and we report "nothing started".
	res, err := tx.ExecContext(ctx,
		`UPDATE jobs SET status = 'running', started_at = ? WHERE id = ? AND status = 'queued'`,
		time.Now().UTC().Format(time.RFC3339), id)
	if err != nil {
		return Job{}, false, err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return Job{}, false, err
	}
	if affected == 0 {
		return Job{}, false, nil
	}

	job, err := scanJob(tx.QueryRowContext(ctx, `
SELECT id, site_id, type, status, summary, created_at, started_at, finished_at
FROM jobs
WHERE id = ?`, id))
	if err != nil {
		return Job{}, false, err
	}
	if err := tx.Commit(); err != nil {
		return Job{}, false, err
	}
	s.debugDB("job started", zap.Int64("job_id", job.ID), zap.Int64("site_id", job.SiteID), zap.String("job_type", job.Type))
	return job, true, nil
}

// CompleteJob sets a job's final status and summary and stamps finished_at.
// It does not report whether the job exists.
func (s *Store) CompleteJob(ctx context.Context, jobID int64, status, summary string) error {
	_, err := s.db.ExecContext(
		ctx,
		`UPDATE jobs SET status = ?, summary = ?, finished_at = ? WHERE id = ?`,
		status, summary, time.Now().UTC().Format(time.RFC3339), jobID,
	)
	if err == nil {
		s.debugDB("job completed", zap.Int64("job_id", jobID), zap.String("status", status), zap.String("summary", summary))
	}
	return err
}

// CancelQueuedJob marks a job that is still "queued" as "failed" with the
// given summary. It reports false when the job was not queued (or does not
// exist), leaving it untouched.
func (s *Store) CancelQueuedJob(ctx context.Context, jobID int64, summary string) (bool, error) {
	res, err := s.db.ExecContext(
		ctx,
		`UPDATE jobs SET status = 'failed', summary = ?, finished_at = ? WHERE id = ? AND status = 'queued'`,
		summary, time.Now().UTC().Format(time.RFC3339), jobID,
	)
	if err != nil {
		return false, err
	}
	rows, err := res.RowsAffected()
	if err != nil {
		return false, err
	}
	if rows > 0 {
		s.debugDB("job canceled from queue", zap.Int64("job_id", jobID), zap.String("summary", summary))
		return true, nil
	}
	return false, nil
}

// AddJobEvent appends a log line to a job. event.OccurredAt is ignored; the
// timestamp column is left for the database to fill.
func (s *Store) AddJobEvent(ctx context.Context, event JobEvent) error {
	_, err := s.db.ExecContext(
		ctx,
		`INSERT INTO job_events (job_id, level, message) VALUES (?, ?, ?)`,
		event.JobID, event.Level, event.Message,
	)
	if err == nil {
		s.debugDB("job event added", zap.Int64("job_id", event.JobID), zap.String("level", event.Level), zap.String("message", event.Message))
	}
	return err
}

// ListRecentJobs returns up to limit jobs, newest first. A non-positive limit
// means 20.
func (s *Store) ListRecentJobs(ctx context.Context, limit int) ([]Job, error) {
	if limit <= 0 {
		limit = 20
	}
	rows, err := s.db.QueryContext(ctx, `
SELECT id, site_id, type, status, summary, created_at, started_at, finished_at
FROM jobs
ORDER BY id DESC
LIMIT ?`, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []Job
	for rows.Next() {
		job, err := scanJob(rows)
		if err != nil {
			return nil, err
		}
		out = append(out, job)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// ListRecentJobEvents returns up to limit job events across all jobs, newest
// first, each joined with its job's site and type. A non-positive limit means
// 50.
func (s *Store) ListRecentJobEvents(ctx context.Context, limit int) ([]JobEventRecord, error) {
	if limit <= 0 {
		limit = 50
	}
	rows, err := s.db.QueryContext(ctx, `
SELECT je.job_id, j.site_id, j.type, je.level, je.message, je.occurred_at
FROM job_events je
JOIN jobs j ON j.id = je.job_id
ORDER BY je.id DESC
LIMIT ?`, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []JobEventRecord
	for rows.Next() {
		var ev JobEventRecord
		if err := rows.Scan(&ev.JobID, &ev.SiteID, &ev.JobType, &ev.Level, &ev.Message, &ev.OccurredAt); err != nil {
			return nil, err
		}
		out = append(out, ev)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// UpdateSiteScanResult records a scan of siteID in one transaction: each
// target's size, scan time and error (matched by path and mode; unmatched
// targets are silently skipped) and the site-level scan state and notes.
func (s *Store) UpdateSiteScanResult(ctx context.Context, siteID int64, state, notes string, scannedAt time.Time, targets []SiteTarget) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback() // harmless after a successful Commit

	for _, t := range targets {
		if _, err := tx.ExecContext(
			ctx,
			`UPDATE site_targets SET last_size_bytes = ?, last_scan_at = ?, last_error = ? WHERE site_id = ? AND path = ? AND mode = ?`,
			nullInt64Arg(t.LastSizeByte), timeOrNil(t.LastScanAt), nullStringArg(t.LastError),
			siteID, t.Path, t.Mode,
		); err != nil {
			return err
		}
	}
	if _, err := tx.ExecContext(
		ctx,
		`UPDATE sites SET last_scan_at = ?, last_scan_state = ?, last_scan_notes = ? WHERE id = ?`,
		scannedAt.UTC().Format(time.RFC3339), state, notes, siteID,
	); err != nil {
		return err
	}
	// Log only after the commit succeeds so a failed commit is not traced
	// as a successful update.
	if err := tx.Commit(); err != nil {
		return err
	}
	s.debugDB("site scan updated", zap.Int64("site_id", siteID), zap.String("state", state), zap.Int("targets", len(targets)), zap.Time("scanned_at", scannedAt.UTC()))
	return nil
}

// userByIDTx reads a user inside an open transaction, so a row inserted by
// the same transaction is visible.
func userByIDTx(ctx context.Context, tx *sql.Tx, id int64) (User, error) {
	const q = `
SELECT id, username, password_hash, is_admin, created_at
FROM users
WHERE id = ?`
	return scanUser(tx.QueryRowContext(ctx, q, id))
}

// scanner is implemented by both *sql.Row and *sql.Rows, letting the scan
// helpers serve single-row and multi-row queries alike.
type scanner interface {
	Scan(dest ...any) error
}

// scanJob reads the job columns in the order used by the job queries.
func scanJob(row scanner) (Job, error) {
	var job Job
	if err := row.Scan(
		&job.ID, &job.SiteID, &job.Type, &job.Status, &job.Summary,
		&job.CreatedAt, &job.StartedAt, &job.FinishedAt,
	); err != nil {
		return Job{}, err
	}
	return job, nil
}

// scanUser reads the user columns; is_admin is stored as an integer where 1
// means admin.
func scanUser(row scanner) (User, error) {
	var user User
	var isAdmin int
	if err := row.Scan(&user.ID, &user.Username, &user.PasswordHash, &isAdmin, &user.CreatedAt); err != nil {
		return User{}, err
	}
	user.IsAdmin = isAdmin == 1
	return user, nil
}

// scanSite reads the site columns (without targets or filters) in the order
// used by ListSites and SiteByID.
func scanSite(row scanner) (Site, error) {
	var site Site
	if err := row.Scan(
		&site.ID, &site.SiteUUID, &site.SSHUser, &site.Host, &site.Port, &site.CreatedAt,
		&site.LastRunStatus, &site.LastRunOutput, &site.LastRunAt,
		&site.LastScanAt, &site.LastScanState, &site.LastScanNotes,
	); err != nil {
		return Site{}, err
	}
	return site, nil
}

// populateTargets fills Targets on each site using one query over all
// targets; it does nothing for an empty slice.
func (s *Store) populateTargets(ctx context.Context, sites []Site) error {
	if len(sites) == 0 {
		return nil
	}
	targetsBySite, err := s.allTargetsBySiteID(ctx)
	if err != nil {
		return err
	}
	for i := range sites {
		sites[i].Targets = targetsBySite[sites[i].ID]
	}
	return nil
}

// populateFilters fills Filters on each site using one query over all
// filters; it does nothing for an empty slice.
func (s *Store) populateFilters(ctx context.Context, sites []Site) error {
	if len(sites) == 0 {
		return nil
	}
	filtersBySite, err := s.allFiltersBySiteID(ctx)
	if err != nil {
		return err
	}
	for i := range sites {
		sites[i].Filters = filtersBySite[sites[i].ID]
	}
	return nil
}

// allTargetsBySiteID loads every target, grouped by site id, each group in
// ascending target id order.
func (s *Store) allTargetsBySiteID(ctx context.Context) (map[int64][]SiteTarget, error) {
	const q = `
SELECT site_id, id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password,
       last_size_bytes, last_scan_at, last_error
FROM site_targets
ORDER BY id ASC`
	rows, err := s.db.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	out := map[int64][]SiteTarget{}
	for rows.Next() {
		var siteID int64
		var target SiteTarget
		if err := rows.Scan(&siteID, &target.ID, &target.Path, &target.Mode,
			&target.MySQLHost, &target.MySQLUser, &target.MySQLDB, &target.MySQLPassword,
			&target.LastSizeByte, &target.LastScanAt, &target.LastError); err != nil {
			return nil, err
		}
		out[siteID] = append(out[siteID], target)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// targetsBySiteID loads the targets of one site in ascending id order.
func (s *Store) targetsBySiteID(ctx context.Context, siteID int64) ([]SiteTarget, error) {
	const q = `
SELECT id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password,
       last_size_bytes, last_scan_at, last_error
FROM site_targets
WHERE site_id = ?
ORDER BY id ASC`
	rows, err := s.db.QueryContext(ctx, q, siteID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []SiteTarget
	for rows.Next() {
		var target SiteTarget
		if err := rows.Scan(&target.ID, &target.Path, &target.Mode,
			&target.MySQLHost, &target.MySQLUser, &target.MySQLDB, &target.MySQLPassword,
			&target.LastSizeByte, &target.LastScanAt, &target.LastError); err != nil {
			return nil, err
		}
		out = append(out, target)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// allFiltersBySiteID loads every filter pattern, grouped by site id, each
// group in ascending filter id order.
func (s *Store) allFiltersBySiteID(ctx context.Context) (map[int64][]string, error) {
	rows, err := s.db.QueryContext(ctx, `SELECT site_id, pattern FROM site_filters ORDER BY id ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	out := map[int64][]string{}
	for rows.Next() {
		var siteID int64
		var pattern string
		if err := rows.Scan(&siteID, &pattern); err != nil {
			return nil, err
		}
		out[siteID] = append(out[siteID], pattern)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// filtersBySiteID loads the filter patterns of one site in ascending id
// order.
func (s *Store) filtersBySiteID(ctx context.Context, siteID int64) ([]string, error) {
	rows, err := s.db.QueryContext(ctx, `SELECT pattern FROM site_filters WHERE site_id = ? ORDER BY id ASC`, siteID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []string
	for rows.Next() {
		var pattern string
		if err := rows.Scan(&pattern); err != nil {
			return nil, err
		}
		out = append(out, pattern)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// boolToInt maps true to 1 and false to 0 for integer-typed boolean columns.
func boolToInt(v bool) int {
	if v {
		return 1
	}
	return 0
}

// isUniqueUsernameErr reports whether err is SQLite's unique-constraint
// violation on users.username. It matches on the driver's message text; err
// must be non-nil.
func isUniqueUsernameErr(err error) bool {
	return strings.Contains(err.Error(), "UNIQUE constraint failed: users.username")
}

// nullInt64Arg converts v to a query argument: its value when valid, else nil
// (SQL NULL).
func nullInt64Arg(v sql.NullInt64) any {
	if v.Valid {
		return v.Int64
	}
	return nil
}

// nullStringArg converts v to a query argument: its value when valid, else
// nil (SQL NULL).
func nullStringArg(v sql.NullString) any {
	if v.Valid {
		return v.String
	}
	return nil
}

// timeOrNil converts v to a query argument: UTC RFC 3339 text when valid,
// else nil (SQL NULL).
func timeOrNil(v sql.NullTime) any {
	if v.Valid {
		return v.Time.UTC().Format(time.RFC3339)
	}
	return nil
}

// debugDB emits a debug log entry when a logger has been set via SetLogger.
func (s *Store) debugDB(msg string, fields ...zap.Field) {
	if s.log == nil {
		return
	}
	s.log.Debug(msg, fields...)
}

// newSiteUUID returns 16 cryptographically random bytes hex-encoded as 32
// lowercase characters. It is not formatted as an RFC 4122 UUID (no hyphens,
// no version bits).
func newSiteUUID() (string, error) {
	buf := make([]byte, 16)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return hex.EncodeToString(buf), nil
}