package store import ( "context" "database/sql" "errors" "fmt" "strings" "time" "go.uber.org/zap" _ "modernc.org/sqlite" ) type Store struct { db *sql.DB log *zap.Logger } type User struct { ID int64 Username string PasswordHash string IsAdmin bool CreatedAt time.Time } type Site struct { ID int64 SSHUser string Host string Port int CreatedAt time.Time LastRunStatus sql.NullString LastRunOutput sql.NullString LastRunAt sql.NullTime LastScanAt sql.NullTime LastScanState sql.NullString LastScanNotes sql.NullString Targets []SiteTarget } type SiteTarget struct { Path string Mode string LastSizeByte sql.NullInt64 LastScanAt sql.NullTime LastError sql.NullString } type Job struct { ID int64 SiteID int64 Type string Status string Summary sql.NullString CreatedAt time.Time StartedAt sql.NullTime FinishedAt sql.NullTime } type JobEvent struct { JobID int64 Level string Message string OccurredAt time.Time } func Open(path string) (*Store, error) { dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path) db, err := sql.Open("sqlite", dsn) if err != nil { return nil, err } if err := db.Ping(); err != nil { return nil, err } s := &Store{db: db} if err := s.migrate(context.Background()); err != nil { return nil, err } return s, nil } func (s *Store) Close() error { return s.db.Close() } func (s *Store) SetLogger(logger *zap.Logger) { s.log = logger } func (s *Store) migrate(ctx context.Context) error { const usersSQL = ` CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, password_hash TEXT NOT NULL, is_admin INTEGER NOT NULL DEFAULT 0, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP );` const sessionsSQL = ` CREATE TABLE IF NOT EXISTS sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, token_hash TEXT NOT NULL UNIQUE, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, expires_at DATETIME NOT NULL );` const sitesSQL = ` CREATE TABLE IF NOT EXISTS sites ( id INTEGER PRIMARY KEY AUTOINCREMENT, ssh_user TEXT NOT NULL, host TEXT NOT NULL, port INTEGER NOT NULL DEFAULT 22, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, last_run_status TEXT, last_run_output TEXT, last_run_at DATETIME, last_scan_at DATETIME, last_scan_state TEXT, last_scan_notes TEXT );` const siteTargetsSQL = ` CREATE TABLE IF NOT EXISTS site_targets ( id INTEGER PRIMARY KEY AUTOINCREMENT, site_id INTEGER NOT NULL REFERENCES sites(id) ON DELETE CASCADE, path TEXT NOT NULL, mode TEXT NOT NULL CHECK(mode IN ('directory', 'sqlite_dump')), last_size_bytes INTEGER, last_scan_at DATETIME, last_error TEXT, UNIQUE(site_id, path, mode) );` const jobsSQL = ` CREATE TABLE IF NOT EXISTS jobs ( id INTEGER PRIMARY KEY AUTOINCREMENT, site_id INTEGER NOT NULL REFERENCES sites(id) ON DELETE CASCADE, type TEXT NOT NULL, status TEXT NOT NULL CHECK(status IN ('queued', 'running', 'success', 'warning', 'failed')), summary TEXT, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, started_at DATETIME, finished_at DATETIME );` const jobEventsSQL = ` CREATE TABLE IF NOT EXISTS job_events ( id INTEGER PRIMARY KEY AUTOINCREMENT, job_id INTEGER NOT NULL REFERENCES jobs(id) ON DELETE CASCADE, level TEXT NOT NULL, message TEXT NOT NULL, occurred_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP );` if _, err := s.db.ExecContext(ctx, usersSQL); err != nil { return err } if _, err := s.db.ExecContext(ctx, sessionsSQL); err != nil { return err } if _, err := s.db.ExecContext(ctx, sitesSQL); err != nil { return err } if _, err := s.db.ExecContext(ctx, siteTargetsSQL); err != nil { return err } if _, err := s.db.ExecContext(ctx, jobsSQL); err != nil { return err } if _, err := s.db.ExecContext(ctx, jobEventsSQL); err != nil { return err } s.debugDB("schema migrated") return nil } func (s *Store) CreateUser(ctx context.Context, username, passwordHash string) (User, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return User{}, err } defer tx.Rollback() var count int if err := tx.QueryRowContext(ctx, `SELECT COUNT(1) FROM users`).Scan(&count); err != nil { return User{}, err } isAdmin := count == 0 res, err := tx.ExecContext( ctx, `INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`, username, passwordHash, boolToInt(isAdmin), ) if err != nil { if isUniqueUsernameErr(err) { return User{}, ErrUsernameTaken } return User{}, err } userID, err := res.LastInsertId() if err != nil { return User{}, err } user, err := userByIDTx(ctx, tx, userID) if err != nil { return User{}, err } if err := tx.Commit(); err != nil { return User{}, err } s.debugDB("user created", zap.Int64("user_id", user.ID), zap.String("username", user.Username), zap.Bool("is_admin", user.IsAdmin)) return user, nil } func (s *Store) UserByUsername(ctx context.Context, username string) (User, error) { const q = ` SELECT id, username, password_hash, is_admin, created_at FROM users WHERE username = ?` return scanUser(s.db.QueryRowContext(ctx, q, username)) } func (s *Store) UserBySessionTokenHash(ctx context.Context, tokenHash string) (User, error) { const q = ` SELECT u.id, u.username, u.password_hash, u.is_admin, u.created_at FROM sessions s JOIN users u ON u.id = s.user_id WHERE s.token_hash = ? AND s.expires_at > CURRENT_TIMESTAMP` return scanUser(s.db.QueryRowContext(ctx, q, tokenHash)) } func (s *Store) CreateSession(ctx context.Context, userID int64, tokenHash string, expiresAt time.Time) error { _, err := s.db.ExecContext( ctx, `INSERT INTO sessions (user_id, token_hash, expires_at) VALUES (?, ?, ?)`, userID, tokenHash, expiresAt.UTC().Format(time.RFC3339), ) if err == nil { s.debugDB("session created", zap.Int64("user_id", userID), zap.Time("expires_at", expiresAt.UTC())) } return err } func (s *Store) DeleteSessionByTokenHash(ctx context.Context, tokenHash string) error { _, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE token_hash = ?`, tokenHash) if err == nil { s.debugDB("session deleted") } return err } func (s *Store) TouchSessionByTokenHash(ctx context.Context, tokenHash string, expiresAt time.Time) error { res, err := s.db.ExecContext( ctx, `UPDATE sessions SET expires_at = ? WHERE token_hash = ? AND expires_at > CURRENT_TIMESTAMP`, expiresAt.UTC().Format(time.RFC3339), tokenHash, ) if err != nil { return err } rows, err := res.RowsAffected() if err != nil { return err } if rows == 0 { return sql.ErrNoRows } s.debugDB("session touched", zap.Time("expires_at", expiresAt.UTC())) return nil } func (s *Store) UpdateUserPasswordHash(ctx context.Context, userID int64, passwordHash string) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET password_hash = ? WHERE id = ?`, passwordHash, userID) if err == nil { s.debugDB("user password updated", zap.Int64("user_id", userID)) } return err } func (s *Store) CreateSite(ctx context.Context, sshUser, host string, port int, targets []SiteTarget) (Site, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return Site{}, err } defer tx.Rollback() res, err := tx.ExecContext(ctx, `INSERT INTO sites (ssh_user, host, port) VALUES (?, ?, ?)`, sshUser, host, port) if err != nil { return Site{}, err } id, err := res.LastInsertId() if err != nil { return Site{}, err } for _, t := range targets { if _, err := tx.ExecContext( ctx, `INSERT INTO site_targets (site_id, path, mode) VALUES (?, ?, ?)`, id, t.Path, t.Mode, ); err != nil { return Site{}, err } } if err := tx.Commit(); err != nil { return Site{}, err } s.debugDB("site created", zap.Int64("site_id", id), zap.String("ssh_user", sshUser), zap.String("host", host), zap.Int("port", port), zap.Int("targets", len(targets))) return s.SiteByID(ctx, id) } func (s *Store) UpdateSite(ctx context.Context, id int64, sshUser, host string, port int, targets []SiteTarget) (Site, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return Site{}, err } defer tx.Rollback() res, err := tx.ExecContext(ctx, `UPDATE sites SET ssh_user = ?, host = ?, port = ? WHERE id = ?`, sshUser, host, port, id) if err != nil { return Site{}, err } affected, err := res.RowsAffected() if err != nil { return Site{}, err } if affected == 0 { return Site{}, sql.ErrNoRows } if _, err := tx.ExecContext(ctx, `DELETE FROM site_targets WHERE site_id = ?`, id); err != nil { return Site{}, err } for _, t := range targets { if _, err := tx.ExecContext( ctx, `INSERT INTO site_targets (site_id, path, mode) VALUES (?, ?, ?)`, id, t.Path, t.Mode, ); err != nil { return Site{}, err } } if err := tx.Commit(); err != nil { return Site{}, err } s.debugDB("site updated", zap.Int64("site_id", id), zap.String("ssh_user", sshUser), zap.String("host", host), zap.Int("port", port), zap.Int("targets", len(targets))) return s.SiteByID(ctx, id) } func (s *Store) DeleteSite(ctx context.Context, id int64) error { res, err := s.db.ExecContext(ctx, `DELETE FROM sites WHERE id = ?`, id) if err != nil { return err } affected, err := res.RowsAffected() if err != nil { return err } if affected == 0 { return sql.ErrNoRows } s.debugDB("site deleted", zap.Int64("site_id", id)) return nil } func (s *Store) ListSites(ctx context.Context) ([]Site, error) { const q = ` SELECT id, ssh_user, host, port, created_at, last_run_status, last_run_output, last_run_at, last_scan_at, last_scan_state, last_scan_notes FROM sites ORDER BY id DESC` rows, err := s.db.QueryContext(ctx, q) if err != nil { return nil, err } defer rows.Close() var out []Site for rows.Next() { site, err := scanSite(rows) if err != nil { return nil, err } out = append(out, site) } if err := rows.Err(); err != nil { return nil, err } if err := s.populateTargets(ctx, out); err != nil { return nil, err } return out, nil } func (s *Store) SiteByID(ctx context.Context, id int64) (Site, error) { const q = ` SELECT id, ssh_user, host, port, created_at, last_run_status, last_run_output, last_run_at, last_scan_at, last_scan_state, last_scan_notes FROM sites WHERE id = ?` site, err := scanSite(s.db.QueryRowContext(ctx, q, id)) if err != nil { return Site{}, err } targets, err := s.targetsBySiteID(ctx, id) if err != nil { return Site{}, err } site.Targets = targets return site, nil } func (s *Store) UpdateSiteRunResult(ctx context.Context, id int64, status, output string, at time.Time) error { _, err := s.db.ExecContext( ctx, `UPDATE sites SET last_run_status = ?, last_run_output = ?, last_run_at = ? WHERE id = ?`, status, output, at.UTC().Format(time.RFC3339), id, ) if err == nil { s.debugDB("site run updated", zap.Int64("site_id", id), zap.String("status", status), zap.Time("at", at.UTC())) } return err } func (s *Store) CreateJob(ctx context.Context, siteID int64, jobType string) (Job, error) { res, err := s.db.ExecContext( ctx, `INSERT INTO jobs (site_id, type, status) VALUES (?, ?, 'queued')`, siteID, jobType, ) if err != nil { return Job{}, err } id, err := res.LastInsertId() if err != nil { return Job{}, err } s.debugDB("job created", zap.Int64("job_id", id), zap.Int64("site_id", siteID), zap.String("job_type", jobType)) return s.JobByID(ctx, id) } func (s *Store) JobByID(ctx context.Context, id int64) (Job, error) { const q = ` SELECT id, site_id, type, status, summary, created_at, started_at, finished_at FROM jobs WHERE id = ?` return scanJob(s.db.QueryRowContext(ctx, q, id)) } func (s *Store) TryStartNextQueuedJob(ctx context.Context) (Job, bool, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return Job{}, false, err } defer tx.Rollback() var id int64 if err := tx.QueryRowContext(ctx, `SELECT id FROM jobs WHERE status = 'queued' ORDER BY id ASC LIMIT 1`).Scan(&id); err != nil { if errors.Is(err, sql.ErrNoRows) { return Job{}, false, nil } return Job{}, false, err } res, err := tx.ExecContext(ctx, `UPDATE jobs SET status = 'running', started_at = ? WHERE id = ? AND status = 'queued'`, time.Now().UTC().Format(time.RFC3339), id) if err != nil { return Job{}, false, err } affected, err := res.RowsAffected() if err != nil { return Job{}, false, err } if affected == 0 { return Job{}, false, nil } job, err := scanJob(tx.QueryRowContext(ctx, ` SELECT id, site_id, type, status, summary, created_at, started_at, finished_at FROM jobs WHERE id = ?`, id)) if err != nil { return Job{}, false, err } if err := tx.Commit(); err != nil { return Job{}, false, err } s.debugDB("job started", zap.Int64("job_id", job.ID), zap.Int64("site_id", job.SiteID), zap.String("job_type", job.Type)) return job, true, nil } func (s *Store) CompleteJob(ctx context.Context, jobID int64, status, summary string) error { _, err := s.db.ExecContext( ctx, `UPDATE jobs SET status = ?, summary = ?, finished_at = ? WHERE id = ?`, status, summary, time.Now().UTC().Format(time.RFC3339), jobID, ) if err == nil { s.debugDB("job completed", zap.Int64("job_id", jobID), zap.String("status", status), zap.String("summary", summary)) } return err } func (s *Store) AddJobEvent(ctx context.Context, event JobEvent) error { _, err := s.db.ExecContext( ctx, `INSERT INTO job_events (job_id, level, message) VALUES (?, ?, ?)`, event.JobID, event.Level, event.Message, ) if err == nil { s.debugDB("job event added", zap.Int64("job_id", event.JobID), zap.String("level", event.Level), zap.String("message", event.Message)) } return err } func (s *Store) ListRecentJobs(ctx context.Context, limit int) ([]Job, error) { if limit <= 0 { limit = 20 } rows, err := s.db.QueryContext(ctx, ` SELECT id, site_id, type, status, summary, created_at, started_at, finished_at FROM jobs ORDER BY id DESC LIMIT ?`, limit) if err != nil { return nil, err } defer rows.Close() var out []Job for rows.Next() { job, err := scanJob(rows) if err != nil { return nil, err } out = append(out, job) } if err := rows.Err(); err != nil { return nil, err } return out, nil } func (s *Store) UpdateSiteScanResult(ctx context.Context, siteID int64, state, notes string, scannedAt time.Time, targets []SiteTarget) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() for _, t := range targets { if _, err := tx.ExecContext( ctx, `UPDATE site_targets SET last_size_bytes = ?, last_scan_at = ?, last_error = ? WHERE site_id = ? AND path = ? AND mode = ?`, nullInt64Arg(t.LastSizeByte), timeOrNil(t.LastScanAt), nullStringArg(t.LastError), siteID, t.Path, t.Mode, ); err != nil { return err } } if _, err := tx.ExecContext( ctx, `UPDATE sites SET last_scan_at = ?, last_scan_state = ?, last_scan_notes = ? WHERE id = ?`, scannedAt.UTC().Format(time.RFC3339), state, notes, siteID, ); err != nil { return err } s.debugDB("site scan updated", zap.Int64("site_id", siteID), zap.String("state", state), zap.Int("targets", len(targets)), zap.Time("scanned_at", scannedAt.UTC())) return tx.Commit() } func userByIDTx(ctx context.Context, tx *sql.Tx, id int64) (User, error) { const q = ` SELECT id, username, password_hash, is_admin, created_at FROM users WHERE id = ?` return scanUser(tx.QueryRowContext(ctx, q, id)) } type scanner interface { Scan(dest ...any) error } func scanJob(row scanner) (Job, error) { var job Job if err := row.Scan( &job.ID, &job.SiteID, &job.Type, &job.Status, &job.Summary, &job.CreatedAt, &job.StartedAt, &job.FinishedAt, ); err != nil { return Job{}, err } return job, nil } func scanUser(row scanner) (User, error) { var user User var isAdmin int if err := row.Scan(&user.ID, &user.Username, &user.PasswordHash, &isAdmin, &user.CreatedAt); err != nil { return User{}, err } user.IsAdmin = isAdmin == 1 return user, nil } func scanSite(row scanner) (Site, error) { var site Site if err := row.Scan( &site.ID, &site.SSHUser, &site.Host, &site.Port, &site.CreatedAt, &site.LastRunStatus, &site.LastRunOutput, &site.LastRunAt, &site.LastScanAt, &site.LastScanState, &site.LastScanNotes, ); err != nil { return Site{}, err } return site, nil } func (s *Store) populateTargets(ctx context.Context, sites []Site) error { if len(sites) == 0 { return nil } targetsBySite, err := s.allTargetsBySiteID(ctx) if err != nil { return err } for i := range sites { sites[i].Targets = targetsBySite[sites[i].ID] } return nil } func (s *Store) allTargetsBySiteID(ctx context.Context) (map[int64][]SiteTarget, error) { const q = `SELECT site_id, path, mode, last_size_bytes, last_scan_at, last_error FROM site_targets ORDER BY id ASC` rows, err := s.db.QueryContext(ctx, q) if err != nil { return nil, err } defer rows.Close() out := map[int64][]SiteTarget{} for rows.Next() { var siteID int64 var target SiteTarget if err := rows.Scan(&siteID, &target.Path, &target.Mode, &target.LastSizeByte, &target.LastScanAt, &target.LastError); err != nil { return nil, err } out[siteID] = append(out[siteID], target) } if err := rows.Err(); err != nil { return nil, err } return out, nil } func (s *Store) targetsBySiteID(ctx context.Context, siteID int64) ([]SiteTarget, error) { const q = `SELECT path, mode, last_size_bytes, last_scan_at, last_error FROM site_targets WHERE site_id = ? ORDER BY id ASC` rows, err := s.db.QueryContext(ctx, q, siteID) if err != nil { return nil, err } defer rows.Close() var out []SiteTarget for rows.Next() { var target SiteTarget if err := rows.Scan(&target.Path, &target.Mode, &target.LastSizeByte, &target.LastScanAt, &target.LastError); err != nil { return nil, err } out = append(out, target) } if err := rows.Err(); err != nil { return nil, err } return out, nil } func boolToInt(v bool) int { if v { return 1 } return 0 } var ErrUsernameTaken = errors.New("username already registered") func isUniqueUsernameErr(err error) bool { return strings.Contains(err.Error(), "UNIQUE constraint failed: users.username") } func nullInt64Arg(v sql.NullInt64) any { if v.Valid { return v.Int64 } return nil } func nullStringArg(v sql.NullString) any { if v.Valid { return v.String } return nil } func timeOrNil(v sql.NullTime) any { if v.Valid { return v.Time.UTC().Format(time.RFC3339) } return nil } func (s *Store) debugDB(msg string, fields ...zap.Field) { if s.log == nil { return } s.log.Debug(msg, fields...) }