package store import ( "context" "database/sql" "errors" "fmt" "strings" "time" _ "modernc.org/sqlite" ) type Store struct { db *sql.DB } type User struct { ID int64 Username string PasswordHash string IsAdmin bool CreatedAt time.Time } type Site struct { ID int64 SSHUser string Host string RemotePath string CreatedAt time.Time LastRunStatus sql.NullString LastRunOutput sql.NullString LastRunAt sql.NullTime } func Open(path string) (*Store, error) { dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path) db, err := sql.Open("sqlite", dsn) if err != nil { return nil, err } if err := db.Ping(); err != nil { return nil, err } s := &Store{db: db} if err := s.migrate(context.Background()); err != nil { return nil, err } return s, nil } func (s *Store) Close() error { return s.db.Close() } func (s *Store) migrate(ctx context.Context) error { const usersSQL = ` CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, password_hash TEXT NOT NULL, is_admin INTEGER NOT NULL DEFAULT 0, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP );` const sessionsSQL = ` CREATE TABLE IF NOT EXISTS sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, token_hash TEXT NOT NULL UNIQUE, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, expires_at DATETIME NOT NULL );` const sitesSQL = ` CREATE TABLE IF NOT EXISTS sites ( id INTEGER PRIMARY KEY AUTOINCREMENT, ssh_user TEXT NOT NULL, host TEXT NOT NULL, remote_path TEXT NOT NULL, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, last_run_status TEXT, last_run_output TEXT, last_run_at DATETIME );` if _, err := s.db.ExecContext(ctx, usersSQL); err != nil { return err } if err := s.migrateUsersLegacyEmail(ctx); err != nil { return err } if _, err := s.db.ExecContext(ctx, sessionsSQL); err != nil { return err } if _, err := s.db.ExecContext(ctx, sitesSQL); err != nil { return err } return nil } func (s *Store) migrateUsersLegacyEmail(ctx context.Context) error { cols, err := tableColumns(ctx, s.db, "users") if err != nil { return err } if cols["username"] { return nil } if _, err := s.db.ExecContext(ctx, `ALTER TABLE users ADD COLUMN username TEXT`); err != nil { return err } if cols["email"] { if _, err := s.db.ExecContext(ctx, `UPDATE users SET username = lower(trim(email)) WHERE username IS NULL`); err != nil { return err } } if _, err := s.db.ExecContext(ctx, `CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)`); err != nil { return err } return nil } func (s *Store) CreateUser(ctx context.Context, username, passwordHash string) (User, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return User{}, err } defer tx.Rollback() var count int if err := tx.QueryRowContext(ctx, `SELECT COUNT(1) FROM users`).Scan(&count); err != nil { return User{}, err } isAdmin := count == 0 res, err := tx.ExecContext( ctx, `INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`, username, passwordHash, boolToInt(isAdmin), ) if err != nil { if isUniqueUsernameErr(err) { return User{}, ErrUsernameTaken } return User{}, err } userID, err := res.LastInsertId() if err != nil { return User{}, err } user, err := userByIDTx(ctx, tx, userID) if err != nil { return User{}, err } if err := tx.Commit(); err != nil { return User{}, err } return user, nil } func (s *Store) UserByUsername(ctx context.Context, username string) (User, error) { const q = ` SELECT id, username, password_hash, is_admin, created_at FROM users WHERE username = ?` return scanUser(s.db.QueryRowContext(ctx, q, username)) } func (s *Store) UserBySessionTokenHash(ctx context.Context, tokenHash string) (User, error) { const q = ` SELECT u.id, u.username, u.password_hash, u.is_admin, u.created_at FROM sessions s JOIN users u ON u.id = s.user_id WHERE s.token_hash = ? AND s.expires_at > CURRENT_TIMESTAMP` return scanUser(s.db.QueryRowContext(ctx, q, tokenHash)) } func (s *Store) CreateSession(ctx context.Context, userID int64, tokenHash string, expiresAt time.Time) error { _, err := s.db.ExecContext( ctx, `INSERT INTO sessions (user_id, token_hash, expires_at) VALUES (?, ?, ?)`, userID, tokenHash, expiresAt.UTC().Format(time.RFC3339), ) return err } func (s *Store) DeleteSessionByTokenHash(ctx context.Context, tokenHash string) error { _, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE token_hash = ?`, tokenHash) return err } func (s *Store) UpdateUserPasswordHash(ctx context.Context, userID int64, passwordHash string) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET password_hash = ? WHERE id = ?`, passwordHash, userID) return err } func (s *Store) CreateSite(ctx context.Context, sshUser, host, remotePath string) (Site, error) { res, err := s.db.ExecContext( ctx, `INSERT INTO sites (ssh_user, host, remote_path) VALUES (?, ?, ?)`, sshUser, host, remotePath, ) if err != nil { return Site{}, err } id, err := res.LastInsertId() if err != nil { return Site{}, err } return s.SiteByID(ctx, id) } func (s *Store) ListSites(ctx context.Context) ([]Site, error) { const q = ` SELECT id, ssh_user, host, remote_path, created_at, last_run_status, last_run_output, last_run_at FROM sites ORDER BY id DESC` rows, err := s.db.QueryContext(ctx, q) if err != nil { return nil, err } defer rows.Close() var out []Site for rows.Next() { site, err := scanSite(rows) if err != nil { return nil, err } out = append(out, site) } if err := rows.Err(); err != nil { return nil, err } return out, nil } func (s *Store) SiteByID(ctx context.Context, id int64) (Site, error) { const q = ` SELECT id, ssh_user, host, remote_path, created_at, last_run_status, last_run_output, last_run_at FROM sites WHERE id = ?` return scanSite(s.db.QueryRowContext(ctx, q, id)) } func (s *Store) UpdateSiteRunResult(ctx context.Context, id int64, status, output string, at time.Time) error { _, err := s.db.ExecContext( ctx, `UPDATE sites SET last_run_status = ?, last_run_output = ?, last_run_at = ? WHERE id = ?`, status, output, at.UTC().Format(time.RFC3339), id, ) return err } func userByIDTx(ctx context.Context, tx *sql.Tx, id int64) (User, error) { const q = ` SELECT id, username, password_hash, is_admin, created_at FROM users WHERE id = ?` return scanUser(tx.QueryRowContext(ctx, q, id)) } type scanner interface { Scan(dest ...any) error } func scanUser(row scanner) (User, error) { var user User var isAdmin int if err := row.Scan(&user.ID, &user.Username, &user.PasswordHash, &isAdmin, &user.CreatedAt); err != nil { return User{}, err } user.IsAdmin = isAdmin == 1 return user, nil } func scanSite(row scanner) (Site, error) { var site Site if err := row.Scan( &site.ID, &site.SSHUser, &site.Host, &site.RemotePath, &site.CreatedAt, &site.LastRunStatus, &site.LastRunOutput, &site.LastRunAt, ); err != nil { return Site{}, err } return site, nil } func boolToInt(v bool) int { if v { return 1 } return 0 } var ErrUsernameTaken = errors.New("username already registered") func isUniqueUsernameErr(err error) bool { return strings.Contains(err.Error(), "UNIQUE constraint failed: users.username") } func tableColumns(ctx context.Context, db *sql.DB, table string) (map[string]bool, error) { rows, err := db.QueryContext(ctx, fmt.Sprintf("PRAGMA table_info(%s)", table)) if err != nil { return nil, err } defer rows.Close() cols := map[string]bool{} for rows.Next() { var cid int var name string var typ string var notNull int var dflt sql.NullString var pk int if err := rows.Scan(&cid, &name, &typ, ¬Null, &dflt, &pk); err != nil { return nil, err } cols[name] = true } if err := rows.Err(); err != nil { return nil, err } return cols, nil }