satoru/internal/store/store.go

728 lines
17 KiB
Go

package store
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
_ "modernc.org/sqlite"
)
// Store is the application's persistence layer. It wraps a *sql.DB opened
// with the "sqlite" driver (modernc.org/sqlite) and exposes typed operations
// over the users, sessions, sites, site_targets, jobs and job_events tables.
type Store struct {
	db *sql.DB
}
// User mirrors one row of the users table.
type User struct {
	ID       int64
	Username string
	// PasswordHash is stored exactly as handed to CreateUser /
	// UpdateUserPasswordHash; this package performs no hashing itself.
	PasswordHash string
	// IsAdmin is persisted as INTEGER 0/1 in users.is_admin. CreateUser sets
	// it only for the first account inserted into an empty table.
	IsAdmin   bool
	CreatedAt time.Time
}
// Site mirrors one row of the sites table together with its site_targets
// rows (in Targets). The nullable fields stay invalid until the corresponding
// result has been recorded.
type Site struct {
	ID        int64
	SSHUser   string
	Host      string
	Port      int
	CreatedAt time.Time
	// LastRun* are written by UpdateSiteRunResult.
	LastRunStatus sql.NullString
	LastRunOutput sql.NullString
	LastRunAt     sql.NullTime
	// LastScan* are written by UpdateSiteScanResult.
	LastScanAt    sql.NullTime
	LastScanState sql.NullString
	LastScanNotes sql.NullString
	// Targets holds the site's site_targets rows, ordered by their row id.
	Targets []SiteTarget
}
// SiteTarget mirrors one row of the site_targets table (without its id and
// site_id columns). A target is unique per (site, Path, Mode).
type SiteTarget struct {
	Path string
	// Mode must be 'directory' or 'sqlite_dump' (site_targets CHECK constraint).
	Mode string
	// LastSizeByte maps to the last_size_bytes column.
	LastSizeByte sql.NullInt64
	LastScanAt   sql.NullTime
	LastError    sql.NullString
}
// Job mirrors one row of the jobs table.
type Job struct {
	ID     int64
	SiteID int64
	// Type is free-form text supplied to CreateJob.
	Type string
	// Status is one of 'queued', 'running', 'success', 'warning', 'failed'
	// (jobs CHECK constraint).
	Status string
	// Summary is set by CompleteJob.
	Summary   sql.NullString
	CreatedAt time.Time
	// StartedAt is set when TryStartNextQueuedJob claims the job.
	StartedAt sql.NullTime
	// FinishedAt is set by CompleteJob.
	FinishedAt sql.NullTime
}
// JobEvent is a log line attached to a job (job_events table).
type JobEvent struct {
	JobID   int64
	Level   string
	Message string
	// OccurredAt is ignored by AddJobEvent; the column's DEFAULT
	// CURRENT_TIMESTAMP supplies the stored value instead.
	OccurredAt time.Time
}
// Open opens the SQLite database at path with foreign-key enforcement turned
// on (the foreign_keys pragma, set via the DSN), checks that a connection can
// be established, and applies the schema.
//
// If the ping or the migration fails, the database handle is closed before
// the error is returned so that its connection pool is not leaked.
func Open(path string) (*Store, error) {
	dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
	db, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, err
	}
	if err := db.Ping(); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = db.Close()
		return nil, err
	}
	s := &Store{db: db}
	if err := s.migrate(context.Background()); err != nil {
		_ = db.Close()
		return nil, err
	}
	return s, nil
}
// Close closes the underlying database handle and reports any error from it.
func (s *Store) Close() error {
	err := s.db.Close()
	return err
}
// migrate creates every table and supporting index the store needs. Every
// statement uses IF NOT EXISTS, so re-running it against an existing database
// leaves existing objects untouched. Statements run in order because later
// tables reference earlier ones.
func (s *Store) migrate(ctx context.Context) error {
	const usersSQL = `
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
is_admin INTEGER NOT NULL DEFAULT 0,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
	const sessionsSQL = `
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME NOT NULL
);`
	const sitesSQL = `
CREATE TABLE IF NOT EXISTS sites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ssh_user TEXT NOT NULL,
host TEXT NOT NULL,
port INTEGER NOT NULL DEFAULT 22,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_run_status TEXT,
last_run_output TEXT,
last_run_at DATETIME,
last_scan_at DATETIME,
last_scan_state TEXT,
last_scan_notes TEXT
);`
	const siteTargetsSQL = `
CREATE TABLE IF NOT EXISTS site_targets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
site_id INTEGER NOT NULL REFERENCES sites(id) ON DELETE CASCADE,
path TEXT NOT NULL,
mode TEXT NOT NULL CHECK(mode IN ('directory', 'sqlite_dump')),
last_size_bytes INTEGER,
last_scan_at DATETIME,
last_error TEXT,
UNIQUE(site_id, path, mode)
);`
	const jobsSQL = `
CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
site_id INTEGER NOT NULL REFERENCES sites(id) ON DELETE CASCADE,
type TEXT NOT NULL,
status TEXT NOT NULL CHECK(status IN ('queued', 'running', 'success', 'warning', 'failed')),
summary TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at DATETIME,
finished_at DATETIME
);`
	const jobEventsSQL = `
CREATE TABLE IF NOT EXISTS job_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_id INTEGER NOT NULL REFERENCES jobs(id) ON DELETE CASCADE,
level TEXT NOT NULL,
message TEXT NOT NULL,
occurred_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`

	// SQLite does not index foreign-key child columns automatically. Without
	// these indexes, every ON DELETE CASCADE from a parent row scans the whole
	// child table. For example, deleting a site scans jobs, and then job_events
	// once per deleted job. site_targets needs no extra index because its
	// UNIQUE(site_id, path, mode) constraint already starts with site_id.
	const sessionsUserIdx = `CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);`
	const jobsSiteIdx = `CREATE INDEX IF NOT EXISTS idx_jobs_site_id ON jobs(site_id);`
	const jobEventsJobIdx = `CREATE INDEX IF NOT EXISTS idx_job_events_job_id ON job_events(job_id);`

	stmts := []string{
		usersSQL,
		sessionsSQL,
		sitesSQL,
		siteTargetsSQL,
		jobsSQL,
		jobEventsSQL,
		sessionsUserIdx,
		jobsSiteIdx,
		jobEventsJobIdx,
	}
	for i, stmt := range stmts {
		if _, err := s.db.ExecContext(ctx, stmt); err != nil {
			return fmt.Errorf("migrate statement %d: %w", i, err)
		}
	}
	return nil
}
// CreateUser inserts a user with the given username and password hash and
// returns the stored row. The account is flagged as admin only if the users
// table was empty when the transaction counted it, so the first account ever
// registered becomes the administrator. The count and the insert share one
// transaction. A UNIQUE violation on the username is reported as
// ErrUsernameTaken.
func (s *Store) CreateUser(ctx context.Context, username, passwordHash string) (User, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return User{}, err
	}
	// A Rollback after a successful Commit does nothing.
	defer tx.Rollback()

	var existing int
	countRow := tx.QueryRowContext(ctx, `SELECT COUNT(1) FROM users`)
	if err := countRow.Scan(&existing); err != nil {
		return User{}, err
	}
	firstAccount := existing == 0

	const insert = `INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`
	res, err := tx.ExecContext(ctx, insert, username, passwordHash, boolToInt(firstAccount))
	switch {
	case err != nil && isUniqueUsernameErr(err):
		return User{}, ErrUsernameTaken
	case err != nil:
		return User{}, err
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return User{}, err
	}
	created, err := userByIDTx(ctx, tx, newID)
	if err != nil {
		return User{}, err
	}
	if err := tx.Commit(); err != nil {
		return User{}, err
	}
	return created, nil
}
// UserByUsername fetches the user whose username matches exactly. When there
// is no such user, the error is sql.ErrNoRows, passed through from Scan.
func (s *Store) UserByUsername(ctx context.Context, username string) (User, error) {
	row := s.db.QueryRowContext(ctx, `
SELECT id, username, password_hash, is_admin, created_at
FROM users
WHERE username = ?`, username)
	return scanUser(row)
}
// UserBySessionTokenHash returns the user who owns the session identified by
// tokenHash, but only while that session is still unexpired. When no live
// session matches, the error is sql.ErrNoRows.
//
// Session expiries are written as RFC 3339 UTC strings
// ("2006-01-02T15:04:05Z"). SQLite's CURRENT_TIMESTAMP uses a different
// layout ("2006-01-02 15:04:05"). Because 'T' sorts after ' ', a text
// comparison against CURRENT_TIMESTAMP treats any session that expires earlier
// on the current UTC day as still valid. To keep the comparison chronological,
// "now" is passed in the same RFC 3339 layout.
func (s *Store) UserBySessionTokenHash(ctx context.Context, tokenHash string) (User, error) {
	const q = `
SELECT u.id, u.username, u.password_hash, u.is_admin, u.created_at
FROM sessions s
JOIN users u ON u.id = s.user_id
WHERE s.token_hash = ? AND s.expires_at > ?`
	now := time.Now().UTC().Format(time.RFC3339)
	return scanUser(s.db.QueryRowContext(ctx, q, tokenHash, now))
}
// CreateSession stores a new session for userID, keyed by tokenHash. The
// expiry is converted to UTC and saved as an RFC 3339 string.
func (s *Store) CreateSession(ctx context.Context, userID int64, tokenHash string, expiresAt time.Time) error {
	expiry := expiresAt.UTC().Format(time.RFC3339)
	const insert = `INSERT INTO sessions (user_id, token_hash, expires_at) VALUES (?, ?, ?)`
	if _, err := s.db.ExecContext(ctx, insert, userID, tokenHash, expiry); err != nil {
		return err
	}
	return nil
}
// DeleteSessionByTokenHash deletes the session that has the given token hash.
// If no such session exists, no error is returned.
func (s *Store) DeleteSessionByTokenHash(ctx context.Context, tokenHash string) error {
	const del = `DELETE FROM sessions WHERE token_hash = ?`
	if _, err := s.db.ExecContext(ctx, del, tokenHash); err != nil {
		return err
	}
	return nil
}
// TouchSessionByTokenHash extends a live session. It sets expires_at to
// expiresAt (saved as an RFC 3339 UTC string) on the session identified by
// tokenHash, but only if that session has not already expired. When no live
// session matched, it returns sql.ErrNoRows.
//
// Stored expiries use the RFC 3339 layout ("2006-01-02T15:04:05Z"). That
// layout does not sort correctly against SQLite's CURRENT_TIMESTAMP
// ("2006-01-02 15:04:05"): since 'T' > ' ', any session that expired earlier
// on the current UTC day would look live and could be revived. The cutoff is
// therefore passed in the same RFC 3339 layout, which keeps the text
// comparison chronological.
func (s *Store) TouchSessionByTokenHash(ctx context.Context, tokenHash string, expiresAt time.Time) error {
	now := time.Now().UTC().Format(time.RFC3339)
	res, err := s.db.ExecContext(
		ctx,
		`UPDATE sessions SET expires_at = ? WHERE token_hash = ? AND expires_at > ?`,
		expiresAt.UTC().Format(time.RFC3339),
		tokenHash,
		now,
	)
	if err != nil {
		return err
	}
	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows == 0 {
		return sql.ErrNoRows
	}
	return nil
}
// UpdateUserPasswordHash replaces the stored password hash of userID. If the
// id does not exist, no error is returned.
func (s *Store) UpdateUserPasswordHash(ctx context.Context, userID int64, passwordHash string) error {
	const update = `UPDATE users SET password_hash = ? WHERE id = ?`
	if _, err := s.db.ExecContext(ctx, update, passwordHash, userID); err != nil {
		return err
	}
	return nil
}
// CreateSite inserts a site and all of its targets in a single transaction.
// After committing, it returns the stored site, including its targets, as read
// back by SiteByID. Only Path and Mode are taken from each target.
func (s *Store) CreateSite(ctx context.Context, sshUser, host string, port int, targets []SiteTarget) (Site, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Site{}, err
	}
	// A Rollback after a successful Commit does nothing.
	defer tx.Rollback()

	const insertSite = `INSERT INTO sites (ssh_user, host, port) VALUES (?, ?, ?)`
	res, err := tx.ExecContext(ctx, insertSite, sshUser, host, port)
	if err != nil {
		return Site{}, err
	}
	siteID, err := res.LastInsertId()
	if err != nil {
		return Site{}, err
	}

	const insertTarget = `INSERT INTO site_targets (site_id, path, mode) VALUES (?, ?, ?)`
	for i := range targets {
		if _, err := tx.ExecContext(ctx, insertTarget, siteID, targets[i].Path, targets[i].Mode); err != nil {
			return Site{}, err
		}
	}

	if err := tx.Commit(); err != nil {
		return Site{}, err
	}
	return s.SiteByID(ctx, siteID)
}
// UpdateSite overwrites a site's connection fields and replaces its targets.
// The replacement is wholesale: all existing site_targets rows for the site,
// including their recorded scan results, are deleted, and the given targets
// are inserted fresh (Path and Mode only). Everything happens in one
// transaction. If no site has the given id, it returns sql.ErrNoRows.
// Otherwise it returns the stored site as read back by SiteByID.
func (s *Store) UpdateSite(ctx context.Context, id int64, sshUser, host string, port int, targets []SiteTarget) (Site, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Site{}, err
	}
	// A Rollback after a successful Commit does nothing.
	defer tx.Rollback()

	const updateSite = `UPDATE sites SET ssh_user = ?, host = ?, port = ? WHERE id = ?`
	res, err := tx.ExecContext(ctx, updateSite, sshUser, host, port, id)
	if err != nil {
		return Site{}, err
	}
	matched, err := res.RowsAffected()
	switch {
	case err != nil:
		return Site{}, err
	case matched == 0:
		return Site{}, sql.ErrNoRows
	}

	if _, err := tx.ExecContext(ctx, `DELETE FROM site_targets WHERE site_id = ?`, id); err != nil {
		return Site{}, err
	}
	const insertTarget = `INSERT INTO site_targets (site_id, path, mode) VALUES (?, ?, ?)`
	for i := range targets {
		if _, err := tx.ExecContext(ctx, insertTarget, id, targets[i].Path, targets[i].Mode); err != nil {
			return Site{}, err
		}
	}

	if err := tx.Commit(); err != nil {
		return Site{}, err
	}
	return s.SiteByID(ctx, id)
}
// DeleteSite removes the site with the given id and returns sql.ErrNoRows if
// no such site existed. The ON DELETE CASCADE foreign keys remove dependent
// site_targets and jobs rows, provided foreign-key enforcement is on for the
// connection (Open requests it in the DSN).
func (s *Store) DeleteSite(ctx context.Context, id int64) error {
	res, err := s.db.ExecContext(ctx, `DELETE FROM sites WHERE id = ?`, id)
	if err != nil {
		return err
	}
	removed, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case removed == 0:
		return sql.ErrNoRows
	default:
		return nil
	}
}
// ListSites returns every site, newest (highest id) first, with Targets
// filled in. When there are no sites, the slice is nil.
func (s *Store) ListSites(ctx context.Context) ([]Site, error) {
	rows, err := s.db.QueryContext(ctx, `
SELECT id, ssh_user, host, port, created_at, last_run_status, last_run_output, last_run_at, last_scan_at, last_scan_state, last_scan_notes
FROM sites
ORDER BY id DESC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var sites []Site
	for rows.Next() {
		site, scanErr := scanSite(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		sites = append(sites, site)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	if err = s.populateTargets(ctx, sites); err != nil {
		return nil, err
	}
	return sites, nil
}
// SiteByID loads one site together with its targets. If the site does not
// exist, the error is sql.ErrNoRows, passed through from Scan.
func (s *Store) SiteByID(ctx context.Context, id int64) (Site, error) {
	row := s.db.QueryRowContext(ctx, `
SELECT id, ssh_user, host, port, created_at, last_run_status, last_run_output, last_run_at, last_scan_at, last_scan_state, last_scan_notes
FROM sites
WHERE id = ?`, id)
	site, err := scanSite(row)
	if err != nil {
		return Site{}, err
	}
	if site.Targets, err = s.targetsBySiteID(ctx, id); err != nil {
		return Site{}, err
	}
	return site, nil
}
// UpdateSiteRunResult saves the outcome of a site run. The timestamp is saved
// as an RFC 3339 UTC string. If the site id is unknown, no error is returned.
func (s *Store) UpdateSiteRunResult(ctx context.Context, id int64, status, output string, at time.Time) error {
	ranAt := at.UTC().Format(time.RFC3339)
	const update = `UPDATE sites SET last_run_status = ?, last_run_output = ?, last_run_at = ? WHERE id = ?`
	_, err := s.db.ExecContext(ctx, update, status, output, ranAt, id)
	return err
}
// CreateJob adds a job of jobType for siteID with status 'queued' and returns
// the stored row as read back by JobByID.
func (s *Store) CreateJob(ctx context.Context, siteID int64, jobType string) (Job, error) {
	const insert = `INSERT INTO jobs (site_id, type, status) VALUES (?, ?, 'queued')`
	res, err := s.db.ExecContext(ctx, insert, siteID, jobType)
	if err != nil {
		return Job{}, err
	}
	jobID, err := res.LastInsertId()
	if err != nil {
		return Job{}, err
	}
	return s.JobByID(ctx, jobID)
}
// JobByID loads one job. If the job does not exist, the error is
// sql.ErrNoRows, passed through from Scan.
func (s *Store) JobByID(ctx context.Context, id int64) (Job, error) {
	row := s.db.QueryRowContext(ctx, `
SELECT id, site_id, type, status, summary, created_at, started_at, finished_at
FROM jobs
WHERE id = ?`, id)
	return scanJob(row)
}
// TryStartNextQueuedJob claims the oldest queued job (lowest id), marks it
// 'running', stamps started_at with the current UTC time (RFC 3339), and
// returns the updated row with true. It returns false and a nil error when no
// job is queued, or when the selected row was no longer queued by the time the
// conditional UPDATE ran. The select, update and re-read all share one
// transaction.
func (s *Store) TryStartNextQueuedJob(ctx context.Context) (Job, bool, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Job{}, false, err
	}
	// A Rollback after a successful Commit does nothing.
	defer tx.Rollback()

	var jobID int64
	err = tx.QueryRowContext(ctx, `SELECT id FROM jobs WHERE status = 'queued' ORDER BY id ASC LIMIT 1`).Scan(&jobID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return Job{}, false, nil
	case err != nil:
		return Job{}, false, err
	}

	startedAt := time.Now().UTC().Format(time.RFC3339)
	res, err := tx.ExecContext(ctx, `UPDATE jobs SET status = 'running', started_at = ? WHERE id = ? AND status = 'queued'`, startedAt, jobID)
	if err != nil {
		return Job{}, false, err
	}
	claimed, err := res.RowsAffected()
	if err != nil {
		return Job{}, false, err
	}
	if claimed == 0 {
		return Job{}, false, nil
	}

	row := tx.QueryRowContext(ctx, `
SELECT id, site_id, type, status, summary, created_at, started_at, finished_at
FROM jobs
WHERE id = ?`, jobID)
	job, err := scanJob(row)
	if err != nil {
		return Job{}, false, err
	}
	if err := tx.Commit(); err != nil {
		return Job{}, false, err
	}
	return job, true, nil
}
// CompleteJob records a job's final status and summary and stamps finished_at
// with the current UTC time (RFC 3339). A status outside the jobs.status CHECK
// constraint makes the update fail. An unknown jobID is not reported as an
// error.
func (s *Store) CompleteJob(ctx context.Context, jobID int64, status, summary string) error {
	finishedAt := time.Now().UTC().Format(time.RFC3339)
	const update = `UPDATE jobs SET status = ?, summary = ?, finished_at = ? WHERE id = ?`
	_, err := s.db.ExecContext(ctx, update, status, summary, finishedAt, jobID)
	return err
}
// AddJobEvent adds a log entry to a job. Only JobID, Level and Message are
// written. event.OccurredAt is ignored; the column default supplies the stored
// timestamp instead.
func (s *Store) AddJobEvent(ctx context.Context, event JobEvent) error {
	const insert = `INSERT INTO job_events (job_id, level, message) VALUES (?, ?, ?)`
	if _, err := s.db.ExecContext(ctx, insert, event.JobID, event.Level, event.Message); err != nil {
		return err
	}
	return nil
}
// ListRecentJobs returns at most limit jobs, newest (highest id) first. A
// non-positive limit is treated as 20. When there are no jobs, the slice is
// nil.
func (s *Store) ListRecentJobs(ctx context.Context, limit int) ([]Job, error) {
	const defaultLimit = 20
	if limit <= 0 {
		limit = defaultLimit
	}
	rows, err := s.db.QueryContext(ctx, `
SELECT id, site_id, type, status, summary, created_at, started_at, finished_at
FROM jobs
ORDER BY id DESC
LIMIT ?`, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var jobs []Job
	for rows.Next() {
		j, scanErr := scanJob(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		jobs = append(jobs, j)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return jobs, nil
}
// UpdateSiteScanResult saves a scan outcome in one transaction. It first
// writes each target's LastSizeByte, LastScanAt and LastError onto the
// site_targets row matched by (siteID, Path, Mode); a target with no matching
// row is skipped silently. It then writes the site-level scan timestamp
// (RFC 3339 UTC), state and notes.
func (s *Store) UpdateSiteScanResult(ctx context.Context, siteID int64, state, notes string, scannedAt time.Time, targets []SiteTarget) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// A Rollback after a successful Commit does nothing.
	defer tx.Rollback()

	const updateTarget = `UPDATE site_targets SET last_size_bytes = ?, last_scan_at = ?, last_error = ? WHERE site_id = ? AND path = ? AND mode = ?`
	for i := range targets {
		tg := &targets[i]
		_, err := tx.ExecContext(ctx, updateTarget,
			nullInt64Arg(tg.LastSizeByte),
			timeOrNil(tg.LastScanAt),
			nullStringArg(tg.LastError),
			siteID,
			tg.Path,
			tg.Mode,
		)
		if err != nil {
			return err
		}
	}

	const updateSite = `UPDATE sites SET last_scan_at = ?, last_scan_state = ?, last_scan_notes = ? WHERE id = ?`
	scanned := scannedAt.UTC().Format(time.RFC3339)
	if _, err := tx.ExecContext(ctx, updateSite, scanned, state, notes, siteID); err != nil {
		return err
	}
	return tx.Commit()
}
// userByIDTx reads a user by id through tx, so rows inserted earlier in the
// same transaction are visible.
func userByIDTx(ctx context.Context, tx *sql.Tx, id int64) (User, error) {
	row := tx.QueryRowContext(ctx, `
SELECT id, username, password_hash, is_admin, created_at
FROM users
WHERE id = ?`, id)
	return scanUser(row)
}
// scanner is the Scan method that *sql.Row and *sql.Rows have in common. It
// lets the scan helpers below serve both single-row (QueryRowContext) and
// multi-row (QueryContext) queries.
type scanner interface {
	Scan(dest ...any) error
}
// scanJob decodes one row into a Job. Columns must be in this order: id,
// site_id, type, status, summary, created_at, started_at, finished_at. On
// error, it returns the zero Job.
func scanJob(row scanner) (Job, error) {
	var j Job
	dest := []any{
		&j.ID, &j.SiteID, &j.Type, &j.Status,
		&j.Summary, &j.CreatedAt, &j.StartedAt, &j.FinishedAt,
	}
	if err := row.Scan(dest...); err != nil {
		return Job{}, err
	}
	return j, nil
}
// scanUser decodes one row into a User. Columns must be in this order: id,
// username, password_hash, is_admin, created_at. The integer is_admin column
// counts as true only when it equals 1. On error, it returns the zero User.
func scanUser(row scanner) (User, error) {
	var (
		u         User
		adminFlag int
	)
	if err := row.Scan(&u.ID, &u.Username, &u.PasswordHash, &adminFlag, &u.CreatedAt); err != nil {
		return User{}, err
	}
	u.IsAdmin = adminFlag == 1
	return u, nil
}
// scanSite decodes one sites row into a Site, leaving Targets empty. Columns
// must be in this order: id, ssh_user, host, port, created_at,
// last_run_status, last_run_output, last_run_at, last_scan_at,
// last_scan_state, last_scan_notes. On error, it returns the zero Site.
func scanSite(row scanner) (Site, error) {
	var st Site
	dest := []any{
		&st.ID, &st.SSHUser, &st.Host, &st.Port, &st.CreatedAt,
		&st.LastRunStatus, &st.LastRunOutput, &st.LastRunAt,
		&st.LastScanAt, &st.LastScanState, &st.LastScanNotes,
	}
	if err := row.Scan(dest...); err != nil {
		return Site{}, err
	}
	return st, nil
}
// populateTargets fills Targets on every element of sites, in place, with a
// single query over all site_targets rows. An empty sites slice skips the
// query.
func (s *Store) populateTargets(ctx context.Context, sites []Site) error {
	if len(sites) == 0 {
		return nil
	}
	bySite, err := s.allTargetsBySiteID(ctx)
	if err != nil {
		return err
	}
	for i := range sites {
		site := &sites[i]
		site.Targets = bySite[site.ID]
	}
	return nil
}
// allTargetsBySiteID reads every site_targets row and groups the rows by
// site_id. Within each site, targets stay in ascending row-id order.
func (s *Store) allTargetsBySiteID(ctx context.Context) (map[int64][]SiteTarget, error) {
	rows, err := s.db.QueryContext(ctx, `SELECT site_id, path, mode, last_size_bytes, last_scan_at, last_error FROM site_targets ORDER BY id ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	grouped := make(map[int64][]SiteTarget)
	for rows.Next() {
		var (
			owner int64
			t     SiteTarget
		)
		if err := rows.Scan(&owner, &t.Path, &t.Mode, &t.LastSizeByte, &t.LastScanAt, &t.LastError); err != nil {
			return nil, err
		}
		grouped[owner] = append(grouped[owner], t)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return grouped, nil
}
// targetsBySiteID reads the targets of one site in ascending row-id order.
// When the site has no targets, the slice is nil.
func (s *Store) targetsBySiteID(ctx context.Context, siteID int64) ([]SiteTarget, error) {
	rows, err := s.db.QueryContext(ctx, `SELECT path, mode, last_size_bytes, last_scan_at, last_error FROM site_targets WHERE site_id = ? ORDER BY id ASC`, siteID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var targets []SiteTarget
	for rows.Next() {
		var t SiteTarget
		if err := rows.Scan(&t.Path, &t.Mode, &t.LastSizeByte, &t.LastScanAt, &t.LastError); err != nil {
			return nil, err
		}
		targets = append(targets, t)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return targets, nil
}
// boolToInt converts v to SQLite's integer boolean encoding: 1 for true,
// 0 for false.
func boolToInt(v bool) int {
	n := 0
	if v {
		n = 1
	}
	return n
}
// ErrUsernameTaken is returned by CreateUser when the insert hits the UNIQUE
// constraint on users.username. Compare against it with errors.Is.
var ErrUsernameTaken = errors.New("username already registered")
// isUniqueUsernameErr reports whether err is SQLite's UNIQUE-constraint
// violation on users.username. It matches on the error text, so it depends on
// the exact wording the SQLite driver produces. err must be non-nil.
func isUniqueUsernameErr(err error) bool {
	const marker = "UNIQUE constraint failed: users.username"
	msg := err.Error()
	return strings.Contains(msg, marker)
}
// nullInt64Arg converts v into a query argument: the plain int64 when v is
// valid, otherwise an untyped nil, which binds as SQL NULL.
func nullInt64Arg(v sql.NullInt64) any {
	if !v.Valid {
		return nil
	}
	return v.Int64
}
// nullStringArg converts v into a query argument: the plain string when v is
// valid, otherwise an untyped nil, which binds as SQL NULL.
func nullStringArg(v sql.NullString) any {
	if !v.Valid {
		return nil
	}
	return v.String
}
// timeOrNil converts v into a query argument. A valid time becomes its UTC
// instant formatted as RFC 3339 (the layout this file writes for every
// timestamp); an invalid one becomes an untyped nil, which binds as SQL NULL.
func timeOrNil(v sql.NullTime) any {
	if !v.Valid {
		return nil
	}
	utc := v.Time.UTC()
	return utc.Format(time.RFC3339)
}