package store import ( "context" "database/sql" "errors" "fmt" "strings" "time" _ "modernc.org/sqlite" ) type Store struct { db *sql.DB } type User struct { ID int64 Username string PasswordHash string IsAdmin bool CreatedAt time.Time } type Site struct { ID int64 SSHUser string Host string Port int CreatedAt time.Time LastRunStatus sql.NullString LastRunOutput sql.NullString LastRunAt sql.NullTime LastScanAt sql.NullTime LastScanState sql.NullString LastScanNotes sql.NullString Targets []SiteTarget } type SiteTarget struct { Path string Mode string LastSizeByte sql.NullInt64 LastScanAt sql.NullTime LastError sql.NullString } func Open(path string) (*Store, error) { dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path) db, err := sql.Open("sqlite", dsn) if err != nil { return nil, err } if err := db.Ping(); err != nil { return nil, err } s := &Store{db: db} if err := s.migrate(context.Background()); err != nil { return nil, err } return s, nil } func (s *Store) Close() error { return s.db.Close() } func (s *Store) migrate(ctx context.Context) error { const usersSQL = ` CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, password_hash TEXT NOT NULL, is_admin INTEGER NOT NULL DEFAULT 0, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP );` const sessionsSQL = ` CREATE TABLE IF NOT EXISTS sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, token_hash TEXT NOT NULL UNIQUE, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, expires_at DATETIME NOT NULL );` const sitesSQL = ` CREATE TABLE IF NOT EXISTS sites ( id INTEGER PRIMARY KEY AUTOINCREMENT, ssh_user TEXT NOT NULL, host TEXT NOT NULL, port INTEGER NOT NULL DEFAULT 22, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, last_run_status TEXT, last_run_output TEXT, last_run_at DATETIME, last_scan_at DATETIME, last_scan_state TEXT, last_scan_notes TEXT );` const siteTargetsSQL = ` CREATE TABLE IF NOT EXISTS site_targets ( id INTEGER PRIMARY KEY AUTOINCREMENT, site_id INTEGER NOT NULL REFERENCES sites(id) ON DELETE CASCADE, path TEXT NOT NULL, mode TEXT NOT NULL CHECK(mode IN ('directory', 'sqlite_dump')), last_size_bytes INTEGER, last_scan_at DATETIME, last_error TEXT );` if _, err := s.db.ExecContext(ctx, usersSQL); err != nil { return err } if _, err := s.db.ExecContext(ctx, sessionsSQL); err != nil { return err } if _, err := s.db.ExecContext(ctx, sitesSQL); err != nil { return err } if _, err := s.db.ExecContext(ctx, siteTargetsSQL); err != nil { return err } return nil } func (s *Store) CreateUser(ctx context.Context, username, passwordHash string) (User, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return User{}, err } defer tx.Rollback() var count int if err := tx.QueryRowContext(ctx, `SELECT COUNT(1) FROM users`).Scan(&count); err != nil { return User{}, err } isAdmin := count == 0 res, err := tx.ExecContext( ctx, `INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`, username, passwordHash, boolToInt(isAdmin), ) if err != nil { if isUniqueUsernameErr(err) { return User{}, ErrUsernameTaken } return User{}, err } userID, err := res.LastInsertId() if err != nil { return User{}, err } user, err := userByIDTx(ctx, tx, userID) if err != nil { return User{}, err } if err := tx.Commit(); err != nil { return User{}, err } return user, nil } func (s *Store) UserByUsername(ctx context.Context, username string) (User, error) { const q = ` SELECT id, username, password_hash, is_admin, created_at FROM users WHERE username = ?` return scanUser(s.db.QueryRowContext(ctx, q, username)) } func (s *Store) UserBySessionTokenHash(ctx context.Context, tokenHash string) (User, error) { const q = ` SELECT u.id, u.username, u.password_hash, u.is_admin, u.created_at FROM sessions s JOIN users u ON u.id = s.user_id WHERE s.token_hash = ? AND s.expires_at > CURRENT_TIMESTAMP` return scanUser(s.db.QueryRowContext(ctx, q, tokenHash)) } func (s *Store) CreateSession(ctx context.Context, userID int64, tokenHash string, expiresAt time.Time) error { _, err := s.db.ExecContext( ctx, `INSERT INTO sessions (user_id, token_hash, expires_at) VALUES (?, ?, ?)`, userID, tokenHash, expiresAt.UTC().Format(time.RFC3339), ) return err } func (s *Store) DeleteSessionByTokenHash(ctx context.Context, tokenHash string) error { _, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE token_hash = ?`, tokenHash) return err } func (s *Store) TouchSessionByTokenHash(ctx context.Context, tokenHash string, expiresAt time.Time) error { res, err := s.db.ExecContext( ctx, `UPDATE sessions SET expires_at = ? WHERE token_hash = ? AND expires_at > CURRENT_TIMESTAMP`, expiresAt.UTC().Format(time.RFC3339), tokenHash, ) if err != nil { return err } rows, err := res.RowsAffected() if err != nil { return err } if rows == 0 { return sql.ErrNoRows } return nil } func (s *Store) UpdateUserPasswordHash(ctx context.Context, userID int64, passwordHash string) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET password_hash = ? WHERE id = ?`, passwordHash, userID) return err } func (s *Store) CreateSite(ctx context.Context, sshUser, host string, port int, targets []SiteTarget) (Site, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return Site{}, err } defer tx.Rollback() res, err := tx.ExecContext(ctx, `INSERT INTO sites (ssh_user, host, port) VALUES (?, ?, ?)`, sshUser, host, port) if err != nil { return Site{}, err } id, err := res.LastInsertId() if err != nil { return Site{}, err } for _, t := range targets { if _, err := tx.ExecContext( ctx, `INSERT INTO site_targets (site_id, path, mode) VALUES (?, ?, ?)`, id, t.Path, t.Mode, ); err != nil { return Site{}, err } } if err := tx.Commit(); err != nil { return Site{}, err } return s.SiteByID(ctx, id) } func (s *Store) ListSites(ctx context.Context) ([]Site, error) { const q = ` SELECT id, ssh_user, host, port, created_at, last_run_status, last_run_output, last_run_at, last_scan_at, last_scan_state, last_scan_notes FROM sites ORDER BY id DESC` rows, err := s.db.QueryContext(ctx, q) if err != nil { return nil, err } defer rows.Close() var out []Site for rows.Next() { site, err := scanSite(rows) if err != nil { return nil, err } out = append(out, site) } if err := rows.Err(); err != nil { return nil, err } if err := s.populateTargets(ctx, out); err != nil { return nil, err } return out, nil } func (s *Store) SiteByID(ctx context.Context, id int64) (Site, error) { const q = ` SELECT id, ssh_user, host, port, created_at, last_run_status, last_run_output, last_run_at, last_scan_at, last_scan_state, last_scan_notes FROM sites WHERE id = ?` site, err := scanSite(s.db.QueryRowContext(ctx, q, id)) if err != nil { return Site{}, err } targets, err := s.targetsBySiteID(ctx, id) if err != nil { return Site{}, err } site.Targets = targets return site, nil } func (s *Store) UpdateSiteRunResult(ctx context.Context, id int64, status, output string, at time.Time) error { _, err := s.db.ExecContext( ctx, `UPDATE sites SET last_run_status = ?, last_run_output = ?, last_run_at = ? WHERE id = ?`, status, output, at.UTC().Format(time.RFC3339), id, ) return err } func (s *Store) UpdateSiteScanResult(ctx context.Context, siteID int64, state, notes string, scannedAt time.Time, targets []SiteTarget) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() for _, t := range targets { if _, err := tx.ExecContext( ctx, `UPDATE site_targets SET last_size_bytes = ?, last_scan_at = ?, last_error = ? WHERE site_id = ? AND path = ? AND mode = ?`, nullInt64Arg(t.LastSizeByte), timeOrNil(t.LastScanAt), nullStringArg(t.LastError), siteID, t.Path, t.Mode, ); err != nil { return err } } if _, err := tx.ExecContext( ctx, `UPDATE sites SET last_scan_at = ?, last_scan_state = ?, last_scan_notes = ? WHERE id = ?`, scannedAt.UTC().Format(time.RFC3339), state, notes, siteID, ); err != nil { return err } return tx.Commit() } func userByIDTx(ctx context.Context, tx *sql.Tx, id int64) (User, error) { const q = ` SELECT id, username, password_hash, is_admin, created_at FROM users WHERE id = ?` return scanUser(tx.QueryRowContext(ctx, q, id)) } type scanner interface { Scan(dest ...any) error } func scanUser(row scanner) (User, error) { var user User var isAdmin int if err := row.Scan(&user.ID, &user.Username, &user.PasswordHash, &isAdmin, &user.CreatedAt); err != nil { return User{}, err } user.IsAdmin = isAdmin == 1 return user, nil } func scanSite(row scanner) (Site, error) { var site Site if err := row.Scan( &site.ID, &site.SSHUser, &site.Host, &site.Port, &site.CreatedAt, &site.LastRunStatus, &site.LastRunOutput, &site.LastRunAt, &site.LastScanAt, &site.LastScanState, &site.LastScanNotes, ); err != nil { return Site{}, err } return site, nil } func (s *Store) populateTargets(ctx context.Context, sites []Site) error { if len(sites) == 0 { return nil } targetsBySite, err := s.allTargetsBySiteID(ctx) if err != nil { return err } for i := range sites { sites[i].Targets = targetsBySite[sites[i].ID] } return nil } func (s *Store) allTargetsBySiteID(ctx context.Context) (map[int64][]SiteTarget, error) { const q = `SELECT site_id, path, mode, last_size_bytes, last_scan_at, last_error FROM site_targets ORDER BY id ASC` rows, err := s.db.QueryContext(ctx, q) if err != nil { return nil, err } defer rows.Close() out := map[int64][]SiteTarget{} for rows.Next() { var siteID int64 var target SiteTarget if err := rows.Scan(&siteID, &target.Path, &target.Mode, &target.LastSizeByte, &target.LastScanAt, &target.LastError); err != nil { return nil, err } out[siteID] = append(out[siteID], target) } if err := rows.Err(); err != nil { return nil, err } return out, nil } func (s *Store) targetsBySiteID(ctx context.Context, siteID int64) ([]SiteTarget, error) { const q = `SELECT path, mode, last_size_bytes, last_scan_at, last_error FROM site_targets WHERE site_id = ? ORDER BY id ASC` rows, err := s.db.QueryContext(ctx, q, siteID) if err != nil { return nil, err } defer rows.Close() var out []SiteTarget for rows.Next() { var target SiteTarget if err := rows.Scan(&target.Path, &target.Mode, &target.LastSizeByte, &target.LastScanAt, &target.LastError); err != nil { return nil, err } out = append(out, target) } if err := rows.Err(); err != nil { return nil, err } return out, nil } func boolToInt(v bool) int { if v { return 1 } return 0 } var ErrUsernameTaken = errors.New("username already registered") func isUniqueUsernameErr(err error) bool { return strings.Contains(err.Error(), "UNIQUE constraint failed: users.username") } func nullInt64Arg(v sql.NullInt64) any { if v.Valid { return v.Int64 } return nil } func nullStringArg(v sql.NullString) any { if v.Valid { return v.String } return nil } func timeOrNil(v sql.NullTime) any { if v.Valid { return v.Time.UTC().Format(time.RFC3339) } return nil }