351 lines
7.9 KiB
Go
351 lines
7.9 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// Store is the SQLite-backed persistence layer for users, sessions and
// sites. Obtain one with Open and release it with Close; the zero value has
// no database handle and is not usable.
type Store struct {
	// db is the connection pool opened by Open.
	db *sql.DB
}
|
|
|
|
// User is one row of the users table.
type User struct {
	// ID is the AUTOINCREMENT primary key.
	ID int64
	// Username is unique across all users.
	Username string
	// PasswordHash is stored and returned verbatim; the hashing scheme is
	// chosen by callers, not by this package.
	PasswordHash string
	// IsAdmin mirrors the integer is_admin column (1 means true). CreateUser
	// sets it only for the first user inserted into an empty table.
	IsAdmin bool
	// CreatedAt is filled by the column default CURRENT_TIMESTAMP on insert.
	CreatedAt time.Time
}
|
|
|
|
// Site is one row of the sites table: a remote target addressed by SSH user,
// host and path, plus the outcome of its most recent run.
type Site struct {
	// ID is the AUTOINCREMENT primary key.
	ID int64
	// SSHUser, Host and RemotePath are stored exactly as given to CreateSite.
	SSHUser    string
	Host       string
	RemotePath string
	// CreatedAt is filled by the column default CURRENT_TIMESTAMP on insert.
	CreatedAt time.Time

	// The last_run_* columns are nullable and not set by CreateSite, so
	// these stay invalid (NULL) until UpdateSiteRunResult records a run.
	LastRunStatus sql.NullString
	LastRunOutput sql.NullString
	LastRunAt     sql.NullTime
}
|
|
|
|
func Open(path string) (*Store, error) {
|
|
dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
|
|
db, err := sql.Open("sqlite", dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s := &Store{db: db}
|
|
if err := s.migrate(context.Background()); err != nil {
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Store) Close() error {
|
|
return s.db.Close()
|
|
}
|
|
|
|
func (s *Store) migrate(ctx context.Context) error {
|
|
const usersSQL = `
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
username TEXT NOT NULL UNIQUE,
|
|
password_hash TEXT NOT NULL,
|
|
is_admin INTEGER NOT NULL DEFAULT 0,
|
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
);`
|
|
|
|
const sessionsSQL = `
|
|
CREATE TABLE IF NOT EXISTS sessions (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
|
token_hash TEXT NOT NULL UNIQUE,
|
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
expires_at DATETIME NOT NULL
|
|
);`
|
|
|
|
const sitesSQL = `
|
|
CREATE TABLE IF NOT EXISTS sites (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
ssh_user TEXT NOT NULL,
|
|
host TEXT NOT NULL,
|
|
remote_path TEXT NOT NULL,
|
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
last_run_status TEXT,
|
|
last_run_output TEXT,
|
|
last_run_at DATETIME
|
|
);`
|
|
|
|
if _, err := s.db.ExecContext(ctx, usersSQL); err != nil {
|
|
return err
|
|
}
|
|
if err := s.migrateUsersLegacyEmail(ctx); err != nil {
|
|
return err
|
|
}
|
|
if _, err := s.db.ExecContext(ctx, sessionsSQL); err != nil {
|
|
return err
|
|
}
|
|
if _, err := s.db.ExecContext(ctx, sitesSQL); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) migrateUsersLegacyEmail(ctx context.Context) error {
|
|
cols, err := tableColumns(ctx, s.db, "users")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if cols["username"] {
|
|
return nil
|
|
}
|
|
|
|
if _, err := s.db.ExecContext(ctx, `ALTER TABLE users ADD COLUMN username TEXT`); err != nil {
|
|
return err
|
|
}
|
|
|
|
if cols["email"] {
|
|
if _, err := s.db.ExecContext(ctx, `UPDATE users SET username = lower(trim(email)) WHERE username IS NULL`); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if _, err := s.db.ExecContext(ctx, `CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)`); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) CreateUser(ctx context.Context, username, passwordHash string) (User, error) {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return User{}, err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
var count int
|
|
if err := tx.QueryRowContext(ctx, `SELECT COUNT(1) FROM users`).Scan(&count); err != nil {
|
|
return User{}, err
|
|
}
|
|
isAdmin := count == 0
|
|
|
|
res, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`,
|
|
username,
|
|
passwordHash,
|
|
boolToInt(isAdmin),
|
|
)
|
|
if err != nil {
|
|
if isUniqueUsernameErr(err) {
|
|
return User{}, ErrUsernameTaken
|
|
}
|
|
return User{}, err
|
|
}
|
|
|
|
userID, err := res.LastInsertId()
|
|
if err != nil {
|
|
return User{}, err
|
|
}
|
|
|
|
user, err := userByIDTx(ctx, tx, userID)
|
|
if err != nil {
|
|
return User{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return User{}, err
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (s *Store) UserByUsername(ctx context.Context, username string) (User, error) {
|
|
const q = `
|
|
SELECT id, username, password_hash, is_admin, created_at
|
|
FROM users
|
|
WHERE username = ?`
|
|
return scanUser(s.db.QueryRowContext(ctx, q, username))
|
|
}
|
|
|
|
func (s *Store) UserBySessionTokenHash(ctx context.Context, tokenHash string) (User, error) {
|
|
const q = `
|
|
SELECT u.id, u.username, u.password_hash, u.is_admin, u.created_at
|
|
FROM sessions s
|
|
JOIN users u ON u.id = s.user_id
|
|
WHERE s.token_hash = ? AND s.expires_at > CURRENT_TIMESTAMP`
|
|
return scanUser(s.db.QueryRowContext(ctx, q, tokenHash))
|
|
}
|
|
|
|
func (s *Store) CreateSession(ctx context.Context, userID int64, tokenHash string, expiresAt time.Time) error {
|
|
_, err := s.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO sessions (user_id, token_hash, expires_at) VALUES (?, ?, ?)`,
|
|
userID,
|
|
tokenHash,
|
|
expiresAt.UTC().Format(time.RFC3339),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) DeleteSessionByTokenHash(ctx context.Context, tokenHash string) error {
|
|
_, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE token_hash = ?`, tokenHash)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) UpdateUserPasswordHash(ctx context.Context, userID int64, passwordHash string) error {
|
|
_, err := s.db.ExecContext(ctx, `UPDATE users SET password_hash = ? WHERE id = ?`, passwordHash, userID)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CreateSite(ctx context.Context, sshUser, host, remotePath string) (Site, error) {
|
|
res, err := s.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO sites (ssh_user, host, remote_path) VALUES (?, ?, ?)`,
|
|
sshUser,
|
|
host,
|
|
remotePath,
|
|
)
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
return s.SiteByID(ctx, id)
|
|
}
|
|
|
|
func (s *Store) ListSites(ctx context.Context) ([]Site, error) {
|
|
const q = `
|
|
SELECT id, ssh_user, host, remote_path, created_at, last_run_status, last_run_output, last_run_at
|
|
FROM sites
|
|
ORDER BY id DESC`
|
|
rows, err := s.db.QueryContext(ctx, q)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []Site
|
|
for rows.Next() {
|
|
site, err := scanSite(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, site)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) SiteByID(ctx context.Context, id int64) (Site, error) {
|
|
const q = `
|
|
SELECT id, ssh_user, host, remote_path, created_at, last_run_status, last_run_output, last_run_at
|
|
FROM sites
|
|
WHERE id = ?`
|
|
return scanSite(s.db.QueryRowContext(ctx, q, id))
|
|
}
|
|
|
|
func (s *Store) UpdateSiteRunResult(ctx context.Context, id int64, status, output string, at time.Time) error {
|
|
_, err := s.db.ExecContext(
|
|
ctx,
|
|
`UPDATE sites SET last_run_status = ?, last_run_output = ?, last_run_at = ? WHERE id = ?`,
|
|
status,
|
|
output,
|
|
at.UTC().Format(time.RFC3339),
|
|
id,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func userByIDTx(ctx context.Context, tx *sql.Tx, id int64) (User, error) {
|
|
const q = `
|
|
SELECT id, username, password_hash, is_admin, created_at
|
|
FROM users
|
|
WHERE id = ?`
|
|
return scanUser(tx.QueryRowContext(ctx, q, id))
|
|
}
|
|
|
|
// scanner is the one method shared by *sql.Row and *sql.Rows. It lets
// scanUser and scanSite decode either a single-row lookup or the current
// row of an iteration.
type scanner interface {
	Scan(targets ...any) error
}
|
|
|
|
func scanUser(row scanner) (User, error) {
|
|
var user User
|
|
var isAdmin int
|
|
if err := row.Scan(&user.ID, &user.Username, &user.PasswordHash, &isAdmin, &user.CreatedAt); err != nil {
|
|
return User{}, err
|
|
}
|
|
user.IsAdmin = isAdmin == 1
|
|
return user, nil
|
|
}
|
|
|
|
func scanSite(row scanner) (Site, error) {
|
|
var site Site
|
|
if err := row.Scan(
|
|
&site.ID,
|
|
&site.SSHUser,
|
|
&site.Host,
|
|
&site.RemotePath,
|
|
&site.CreatedAt,
|
|
&site.LastRunStatus,
|
|
&site.LastRunOutput,
|
|
&site.LastRunAt,
|
|
); err != nil {
|
|
return Site{}, err
|
|
}
|
|
return site, nil
|
|
}
|
|
|
|
// boolToInt encodes v as 1 or 0, the form used for the INTEGER is_admin
// column.
func boolToInt(v bool) int {
	n := 0
	if v {
		n = 1
	}
	return n
}
|
|
|
|
// Sentinel errors returned by Store methods; compare with errors.Is.
var (
	// ErrUsernameTaken is returned by CreateUser when the requested username
	// already belongs to another user.
	ErrUsernameTaken = errors.New("username already registered")
)
|
|
|
|
// isUniqueUsernameErr reports whether err is SQLite's unique-constraint
// violation on users.username, detected by its message text. A nil error
// is never a violation; the original version would have panicked on it.
func isUniqueUsernameErr(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), "UNIQUE constraint failed: users.username")
}
|
|
|
|
// tableColumns returns the set of column names in table, as reported by
// SQLite's table_info pragma. The pragma is queried through its
// pragma_table_info table-valued form, so the table name is a bound
// parameter rather than text spliced into the SQL, and only the name column
// is read.
func tableColumns(ctx context.Context, db *sql.DB, table string) (map[string]bool, error) {
	rows, err := db.QueryContext(ctx, `SELECT name FROM pragma_table_info(?)`, table)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	cols := make(map[string]bool)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		cols[name] = true
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return cols, nil
}
|