889 lines
22 KiB
Go
889 lines
22 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
type Store struct {
|
|
db *sql.DB
|
|
log *zap.Logger
|
|
}
|
|
|
|
type (
	// User is one row of the users table.
	User struct {
		ID           int64
		Username     string
		PasswordHash string
		IsAdmin      bool // stored as an integer column; see scanUser
		CreatedAt    time.Time
	}

	// Site is one row of the sites table. Targets and Filters are not
	// columns: SiteByID and ListSites load them from site_targets and
	// site_filters. The nullable last_* columns use sql.Null* types.
	Site struct {
		ID            int64
		SiteUUID      string
		SSHUser       string
		Host          string
		Port          int
		CreatedAt     time.Time
		LastRunStatus sql.NullString
		LastRunOutput sql.NullString
		LastRunAt     sql.NullTime
		LastScanAt    sql.NullTime
		LastScanState sql.NullString
		LastScanNotes sql.NullString
		Targets       []SiteTarget
		Filters       []string
	}

	// SiteTarget is one row of the site_targets table. Mode is a free-form
	// string; AddMySQLDumpTarget uses "mysql_dump". The MySQL* columns are
	// nullable and presumably only meaningful for that mode (TODO confirm).
	// LastSizeByte maps to the last_size_bytes column.
	SiteTarget struct {
		ID            int64
		Path          string
		Mode          string
		MySQLHost     sql.NullString
		MySQLUser     sql.NullString
		MySQLDB       sql.NullString
		MySQLPassword sql.NullString
		LastSizeByte  sql.NullInt64
		LastScanAt    sql.NullTime
		LastError     sql.NullString
	}

	// Job is one row of the jobs table. Status values written by this file
	// include "queued", "running" and "failed", plus whatever CompleteJob is
	// given by its caller.
	Job struct {
		ID         int64
		SiteID     int64
		Type       string
		Status     string
		Summary    sql.NullString
		CreatedAt  time.Time
		StartedAt  sql.NullTime
		FinishedAt sql.NullTime
	}

	// JobEvent is a log line attached to a job.
	JobEvent struct {
		JobID      int64
		Level      string
		Message    string
		OccurredAt time.Time
	}

	// JobEventRecord is a JobEvent joined with its job's site and type, as
	// returned by ListRecentJobEvents.
	JobEventRecord struct {
		JobID      int64
		SiteID     int64
		JobType    string
		Level      string
		Message    string
		OccurredAt time.Time
	}
)
|
|
|
|
func Open(path string) (*Store, error) {
|
|
dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)", path)
|
|
db, err := sql.Open("sqlite", dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
db.SetMaxOpenConns(8)
|
|
db.SetMaxIdleConns(8)
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s := &Store{db: db}
|
|
if err := s.migrate(context.Background()); err != nil {
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Store) Close() error {
|
|
return s.db.Close()
|
|
}
|
|
|
|
func (s *Store) SetLogger(logger *zap.Logger) {
|
|
s.log = logger
|
|
}
|
|
|
|
func (s *Store) migrate(ctx context.Context) error {
|
|
return s.runMigrations(ctx)
|
|
}
|
|
|
|
func (s *Store) CreateUser(ctx context.Context, username, passwordHash string) (User, error) {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return User{}, err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
var count int
|
|
if err := tx.QueryRowContext(ctx, `SELECT COUNT(1) FROM users`).Scan(&count); err != nil {
|
|
return User{}, err
|
|
}
|
|
isAdmin := count == 0
|
|
|
|
res, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`,
|
|
username,
|
|
passwordHash,
|
|
boolToInt(isAdmin),
|
|
)
|
|
if err != nil {
|
|
if isUniqueUsernameErr(err) {
|
|
return User{}, ErrUsernameTaken
|
|
}
|
|
return User{}, err
|
|
}
|
|
|
|
userID, err := res.LastInsertId()
|
|
if err != nil {
|
|
return User{}, err
|
|
}
|
|
|
|
user, err := userByIDTx(ctx, tx, userID)
|
|
if err != nil {
|
|
return User{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return User{}, err
|
|
}
|
|
s.debugDB("user created", zap.Int64("user_id", user.ID), zap.String("username", user.Username), zap.Bool("is_admin", user.IsAdmin))
|
|
return user, nil
|
|
}
|
|
|
|
func (s *Store) UserByUsername(ctx context.Context, username string) (User, error) {
|
|
const q = `
|
|
SELECT id, username, password_hash, is_admin, created_at
|
|
FROM users
|
|
WHERE username = ?`
|
|
return scanUser(s.db.QueryRowContext(ctx, q, username))
|
|
}
|
|
|
|
func (s *Store) UserBySessionTokenHash(ctx context.Context, tokenHash string) (User, error) {
|
|
const q = `
|
|
SELECT u.id, u.username, u.password_hash, u.is_admin, u.created_at
|
|
FROM sessions s
|
|
JOIN users u ON u.id = s.user_id
|
|
WHERE s.token_hash = ? AND s.expires_at > CURRENT_TIMESTAMP`
|
|
return scanUser(s.db.QueryRowContext(ctx, q, tokenHash))
|
|
}
|
|
|
|
func (s *Store) CreateSession(ctx context.Context, userID int64, tokenHash string, expiresAt time.Time) error {
|
|
_, err := s.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO sessions (user_id, token_hash, expires_at) VALUES (?, ?, ?)`,
|
|
userID,
|
|
tokenHash,
|
|
expiresAt.UTC().Format(time.RFC3339),
|
|
)
|
|
if err == nil {
|
|
s.debugDB("session created", zap.Int64("user_id", userID), zap.Time("expires_at", expiresAt.UTC()))
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Store) DeleteSessionByTokenHash(ctx context.Context, tokenHash string) error {
|
|
_, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE token_hash = ?`, tokenHash)
|
|
if err == nil {
|
|
s.debugDB("session deleted")
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Store) TouchSessionByTokenHash(ctx context.Context, tokenHash string, expiresAt time.Time) error {
|
|
res, err := s.db.ExecContext(
|
|
ctx,
|
|
`UPDATE sessions SET expires_at = ? WHERE token_hash = ? AND expires_at > CURRENT_TIMESTAMP`,
|
|
expiresAt.UTC().Format(time.RFC3339),
|
|
tokenHash,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rows == 0 {
|
|
return sql.ErrNoRows
|
|
}
|
|
s.debugDB("session touched", zap.Time("expires_at", expiresAt.UTC()))
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) UpdateUserPasswordHash(ctx context.Context, userID int64, passwordHash string) error {
|
|
_, err := s.db.ExecContext(ctx, `UPDATE users SET password_hash = ? WHERE id = ?`, passwordHash, userID)
|
|
if err == nil {
|
|
s.debugDB("user password updated", zap.Int64("user_id", userID))
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CreateSite(ctx context.Context, sshUser, host string, port int, targets []SiteTarget, filters []string) (Site, error) {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
siteUUID, err := newSiteUUID()
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
res, err := tx.ExecContext(ctx, `INSERT INTO sites (site_uuid, ssh_user, host, port) VALUES (?, ?, ?, ?)`, siteUUID, sshUser, host, port)
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
|
|
for _, t := range targets {
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO site_targets (site_id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
id,
|
|
t.Path,
|
|
t.Mode,
|
|
nullStringArg(t.MySQLHost),
|
|
nullStringArg(t.MySQLUser),
|
|
nullStringArg(t.MySQLDB),
|
|
nullStringArg(t.MySQLPassword),
|
|
); err != nil {
|
|
return Site{}, err
|
|
}
|
|
}
|
|
for _, f := range filters {
|
|
if _, err := tx.ExecContext(ctx, `INSERT INTO site_filters (site_id, pattern) VALUES (?, ?)`, id, f); err != nil {
|
|
return Site{}, err
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Site{}, err
|
|
}
|
|
s.debugDB("site created", zap.Int64("site_id", id), zap.String("site_uuid", siteUUID), zap.String("ssh_user", sshUser), zap.String("host", host), zap.Int("port", port), zap.Int("targets", len(targets)), zap.Int("filters", len(filters)))
|
|
return s.SiteByID(ctx, id)
|
|
}
|
|
|
|
func (s *Store) UpdateSite(ctx context.Context, id int64, sshUser, host string, port int, targets []SiteTarget, filters []string) (Site, error) {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
res, err := tx.ExecContext(ctx, `UPDATE sites SET ssh_user = ?, host = ?, port = ? WHERE id = ?`, sshUser, host, port, id)
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
if affected == 0 {
|
|
return Site{}, sql.ErrNoRows
|
|
}
|
|
|
|
if _, err := tx.ExecContext(ctx, `DELETE FROM site_targets WHERE site_id = ?`, id); err != nil {
|
|
return Site{}, err
|
|
}
|
|
for _, t := range targets {
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO site_targets (site_id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
id,
|
|
t.Path,
|
|
t.Mode,
|
|
nullStringArg(t.MySQLHost),
|
|
nullStringArg(t.MySQLUser),
|
|
nullStringArg(t.MySQLDB),
|
|
nullStringArg(t.MySQLPassword),
|
|
); err != nil {
|
|
return Site{}, err
|
|
}
|
|
}
|
|
if _, err := tx.ExecContext(ctx, `DELETE FROM site_filters WHERE site_id = ?`, id); err != nil {
|
|
return Site{}, err
|
|
}
|
|
for _, f := range filters {
|
|
if _, err := tx.ExecContext(ctx, `INSERT INTO site_filters (site_id, pattern) VALUES (?, ?)`, id, f); err != nil {
|
|
return Site{}, err
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Site{}, err
|
|
}
|
|
s.debugDB("site updated", zap.Int64("site_id", id), zap.String("ssh_user", sshUser), zap.String("host", host), zap.Int("port", port), zap.Int("targets", len(targets)), zap.Int("filters", len(filters)))
|
|
return s.SiteByID(ctx, id)
|
|
}
|
|
|
|
func (s *Store) DeleteSite(ctx context.Context, id int64) error {
|
|
res, err := s.db.ExecContext(ctx, `DELETE FROM sites WHERE id = ?`, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return sql.ErrNoRows
|
|
}
|
|
s.debugDB("site deleted", zap.Int64("site_id", id))
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) DeleteSiteTarget(ctx context.Context, siteID, targetID int64) error {
|
|
res, err := s.db.ExecContext(ctx, `DELETE FROM site_targets WHERE id = ? AND site_id = ?`, targetID, siteID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return sql.ErrNoRows
|
|
}
|
|
s.debugDB("site target deleted", zap.Int64("site_id", siteID), zap.Int64("target_id", targetID))
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) AddMySQLDumpTarget(ctx context.Context, siteID int64, dbHost, dbUser, dbName, dbPassword string) error {
|
|
_, err := s.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO site_targets (site_id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password) VALUES (?, ?, 'mysql_dump', ?, ?, ?, ?)`,
|
|
siteID,
|
|
dbName,
|
|
dbHost,
|
|
dbUser,
|
|
dbName,
|
|
dbPassword,
|
|
)
|
|
if err == nil {
|
|
s.debugDB("mysql dump target added", zap.Int64("site_id", siteID), zap.String("db_host", dbHost), zap.String("db_user", dbUser), zap.String("db_name", dbName))
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Store) ListSites(ctx context.Context) ([]Site, error) {
|
|
const q = `
|
|
SELECT id, site_uuid, ssh_user, host, port, created_at, last_run_status, last_run_output, last_run_at, last_scan_at, last_scan_state, last_scan_notes
|
|
FROM sites
|
|
ORDER BY id DESC`
|
|
rows, err := s.db.QueryContext(ctx, q)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []Site
|
|
for rows.Next() {
|
|
site, err := scanSite(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, site)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.populateTargets(ctx, out); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.populateFilters(ctx, out); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) SiteByID(ctx context.Context, id int64) (Site, error) {
|
|
const q = `
|
|
SELECT id, site_uuid, ssh_user, host, port, created_at, last_run_status, last_run_output, last_run_at, last_scan_at, last_scan_state, last_scan_notes
|
|
FROM sites
|
|
WHERE id = ?`
|
|
site, err := scanSite(s.db.QueryRowContext(ctx, q, id))
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
targets, err := s.targetsBySiteID(ctx, id)
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
site.Targets = targets
|
|
filters, err := s.filtersBySiteID(ctx, id)
|
|
if err != nil {
|
|
return Site{}, err
|
|
}
|
|
site.Filters = filters
|
|
return site, nil
|
|
}
|
|
|
|
func (s *Store) UpdateSiteRunResult(ctx context.Context, id int64, status, output string, at time.Time) error {
|
|
_, err := s.db.ExecContext(
|
|
ctx,
|
|
`UPDATE sites SET last_run_status = ?, last_run_output = ?, last_run_at = ? WHERE id = ?`,
|
|
status,
|
|
output,
|
|
at.UTC().Format(time.RFC3339),
|
|
id,
|
|
)
|
|
if err == nil {
|
|
s.debugDB("site run updated", zap.Int64("site_id", id), zap.String("status", status), zap.Time("at", at.UTC()))
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CreateJob(ctx context.Context, siteID int64, jobType string) (Job, error) {
|
|
res, err := s.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO jobs (site_id, type, status) VALUES (?, ?, 'queued')`,
|
|
siteID,
|
|
jobType,
|
|
)
|
|
if err != nil {
|
|
return Job{}, err
|
|
}
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return Job{}, err
|
|
}
|
|
s.debugDB("job created", zap.Int64("job_id", id), zap.Int64("site_id", siteID), zap.String("job_type", jobType))
|
|
return s.JobByID(ctx, id)
|
|
}
|
|
|
|
func (s *Store) JobByID(ctx context.Context, id int64) (Job, error) {
|
|
const q = `
|
|
SELECT id, site_id, type, status, summary, created_at, started_at, finished_at
|
|
FROM jobs
|
|
WHERE id = ?`
|
|
return scanJob(s.db.QueryRowContext(ctx, q, id))
|
|
}
|
|
|
|
func (s *Store) TryStartNextQueuedJob(ctx context.Context) (Job, bool, error) {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Job{}, false, err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
var id int64
|
|
if err := tx.QueryRowContext(ctx, `SELECT id FROM jobs WHERE status = 'queued' ORDER BY id ASC LIMIT 1`).Scan(&id); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return Job{}, false, nil
|
|
}
|
|
return Job{}, false, err
|
|
}
|
|
|
|
res, err := tx.ExecContext(ctx, `UPDATE jobs SET status = 'running', started_at = ? WHERE id = ? AND status = 'queued'`, time.Now().UTC().Format(time.RFC3339), id)
|
|
if err != nil {
|
|
return Job{}, false, err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return Job{}, false, err
|
|
}
|
|
if affected == 0 {
|
|
return Job{}, false, nil
|
|
}
|
|
|
|
job, err := scanJob(tx.QueryRowContext(ctx, `
|
|
SELECT id, site_id, type, status, summary, created_at, started_at, finished_at
|
|
FROM jobs
|
|
WHERE id = ?`, id))
|
|
if err != nil {
|
|
return Job{}, false, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Job{}, false, err
|
|
}
|
|
s.debugDB("job started", zap.Int64("job_id", job.ID), zap.Int64("site_id", job.SiteID), zap.String("job_type", job.Type))
|
|
return job, true, nil
|
|
}
|
|
|
|
func (s *Store) CompleteJob(ctx context.Context, jobID int64, status, summary string) error {
|
|
_, err := s.db.ExecContext(
|
|
ctx,
|
|
`UPDATE jobs SET status = ?, summary = ?, finished_at = ? WHERE id = ?`,
|
|
status,
|
|
summary,
|
|
time.Now().UTC().Format(time.RFC3339),
|
|
jobID,
|
|
)
|
|
if err == nil {
|
|
s.debugDB("job completed", zap.Int64("job_id", jobID), zap.String("status", status), zap.String("summary", summary))
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CancelQueuedJob(ctx context.Context, jobID int64, summary string) (bool, error) {
|
|
res, err := s.db.ExecContext(
|
|
ctx,
|
|
`UPDATE jobs
|
|
SET status = 'failed', summary = ?, finished_at = ?
|
|
WHERE id = ? AND status = 'queued'`,
|
|
summary,
|
|
time.Now().UTC().Format(time.RFC3339),
|
|
jobID,
|
|
)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if rows > 0 {
|
|
s.debugDB("job canceled from queue", zap.Int64("job_id", jobID), zap.String("summary", summary))
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (s *Store) AddJobEvent(ctx context.Context, event JobEvent) error {
|
|
_, err := s.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO job_events (job_id, level, message) VALUES (?, ?, ?)`,
|
|
event.JobID,
|
|
event.Level,
|
|
event.Message,
|
|
)
|
|
if err == nil {
|
|
s.debugDB("job event added", zap.Int64("job_id", event.JobID), zap.String("level", event.Level), zap.String("message", event.Message))
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Store) ListRecentJobs(ctx context.Context, limit int) ([]Job, error) {
|
|
if limit <= 0 {
|
|
limit = 20
|
|
}
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT id, site_id, type, status, summary, created_at, started_at, finished_at
|
|
FROM jobs
|
|
ORDER BY id DESC
|
|
LIMIT ?`, limit)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []Job
|
|
for rows.Next() {
|
|
job, err := scanJob(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, job)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) ListRecentJobEvents(ctx context.Context, limit int) ([]JobEventRecord, error) {
|
|
if limit <= 0 {
|
|
limit = 50
|
|
}
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT je.job_id, j.site_id, j.type, je.level, je.message, je.occurred_at
|
|
FROM job_events je
|
|
JOIN jobs j ON j.id = je.job_id
|
|
ORDER BY je.id DESC
|
|
LIMIT ?`, limit)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []JobEventRecord
|
|
for rows.Next() {
|
|
var ev JobEventRecord
|
|
if err := rows.Scan(&ev.JobID, &ev.SiteID, &ev.JobType, &ev.Level, &ev.Message, &ev.OccurredAt); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, ev)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) UpdateSiteScanResult(ctx context.Context, siteID int64, state, notes string, scannedAt time.Time, targets []SiteTarget) error {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
for _, t := range targets {
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`UPDATE site_targets SET last_size_bytes = ?, last_scan_at = ?, last_error = ? WHERE site_id = ? AND path = ? AND mode = ?`,
|
|
nullInt64Arg(t.LastSizeByte),
|
|
timeOrNil(t.LastScanAt),
|
|
nullStringArg(t.LastError),
|
|
siteID,
|
|
t.Path,
|
|
t.Mode,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`UPDATE sites SET last_scan_at = ?, last_scan_state = ?, last_scan_notes = ? WHERE id = ?`,
|
|
scannedAt.UTC().Format(time.RFC3339),
|
|
state,
|
|
notes,
|
|
siteID,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
s.debugDB("site scan updated", zap.Int64("site_id", siteID), zap.String("state", state), zap.Int("targets", len(targets)), zap.Time("scanned_at", scannedAt.UTC()))
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func userByIDTx(ctx context.Context, tx *sql.Tx, id int64) (User, error) {
|
|
const q = `
|
|
SELECT id, username, password_hash, is_admin, created_at
|
|
FROM users
|
|
WHERE id = ?`
|
|
return scanUser(tx.QueryRowContext(ctx, q, id))
|
|
}
|
|
|
|
// scanner is the single method shared by *sql.Row and *sql.Rows. It lets
// the scan helpers below decode a one-row lookup and one row of a result set
// with the same code.
type scanner interface {
	// Scan copies the current row's columns into dest, in column order.
	Scan(dest ...any) error
}
|
|
|
|
func scanJob(row scanner) (Job, error) {
|
|
var job Job
|
|
if err := row.Scan(
|
|
&job.ID,
|
|
&job.SiteID,
|
|
&job.Type,
|
|
&job.Status,
|
|
&job.Summary,
|
|
&job.CreatedAt,
|
|
&job.StartedAt,
|
|
&job.FinishedAt,
|
|
); err != nil {
|
|
return Job{}, err
|
|
}
|
|
return job, nil
|
|
}
|
|
|
|
func scanUser(row scanner) (User, error) {
|
|
var user User
|
|
var isAdmin int
|
|
if err := row.Scan(&user.ID, &user.Username, &user.PasswordHash, &isAdmin, &user.CreatedAt); err != nil {
|
|
return User{}, err
|
|
}
|
|
user.IsAdmin = isAdmin == 1
|
|
return user, nil
|
|
}
|
|
|
|
func scanSite(row scanner) (Site, error) {
|
|
var site Site
|
|
if err := row.Scan(
|
|
&site.ID,
|
|
&site.SiteUUID,
|
|
&site.SSHUser,
|
|
&site.Host,
|
|
&site.Port,
|
|
&site.CreatedAt,
|
|
&site.LastRunStatus,
|
|
&site.LastRunOutput,
|
|
&site.LastRunAt,
|
|
&site.LastScanAt,
|
|
&site.LastScanState,
|
|
&site.LastScanNotes,
|
|
); err != nil {
|
|
return Site{}, err
|
|
}
|
|
return site, nil
|
|
}
|
|
|
|
func (s *Store) populateTargets(ctx context.Context, sites []Site) error {
|
|
if len(sites) == 0 {
|
|
return nil
|
|
}
|
|
targetsBySite, err := s.allTargetsBySiteID(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i := range sites {
|
|
sites[i].Targets = targetsBySite[sites[i].ID]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) populateFilters(ctx context.Context, sites []Site) error {
|
|
if len(sites) == 0 {
|
|
return nil
|
|
}
|
|
filtersBySite, err := s.allFiltersBySiteID(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i := range sites {
|
|
sites[i].Filters = filtersBySite[sites[i].ID]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) allTargetsBySiteID(ctx context.Context) (map[int64][]SiteTarget, error) {
|
|
const q = `SELECT site_id, id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password, last_size_bytes, last_scan_at, last_error FROM site_targets ORDER BY id ASC`
|
|
rows, err := s.db.QueryContext(ctx, q)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
out := map[int64][]SiteTarget{}
|
|
for rows.Next() {
|
|
var siteID int64
|
|
var target SiteTarget
|
|
if err := rows.Scan(&siteID, &target.ID, &target.Path, &target.Mode, &target.MySQLHost, &target.MySQLUser, &target.MySQLDB, &target.MySQLPassword, &target.LastSizeByte, &target.LastScanAt, &target.LastError); err != nil {
|
|
return nil, err
|
|
}
|
|
out[siteID] = append(out[siteID], target)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) targetsBySiteID(ctx context.Context, siteID int64) ([]SiteTarget, error) {
|
|
const q = `SELECT id, path, mode, mysql_host, mysql_user, mysql_db, mysql_password, last_size_bytes, last_scan_at, last_error FROM site_targets WHERE site_id = ? ORDER BY id ASC`
|
|
rows, err := s.db.QueryContext(ctx, q, siteID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []SiteTarget
|
|
for rows.Next() {
|
|
var target SiteTarget
|
|
if err := rows.Scan(&target.ID, &target.Path, &target.Mode, &target.MySQLHost, &target.MySQLUser, &target.MySQLDB, &target.MySQLPassword, &target.LastSizeByte, &target.LastScanAt, &target.LastError); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, target)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) allFiltersBySiteID(ctx context.Context) (map[int64][]string, error) {
|
|
rows, err := s.db.QueryContext(ctx, `SELECT site_id, pattern FROM site_filters ORDER BY id ASC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
out := map[int64][]string{}
|
|
for rows.Next() {
|
|
var siteID int64
|
|
var pattern string
|
|
if err := rows.Scan(&siteID, &pattern); err != nil {
|
|
return nil, err
|
|
}
|
|
out[siteID] = append(out[siteID], pattern)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) filtersBySiteID(ctx context.Context, siteID int64) ([]string, error) {
|
|
rows, err := s.db.QueryContext(ctx, `SELECT pattern FROM site_filters WHERE site_id = ? ORDER BY id ASC`, siteID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []string
|
|
for rows.Next() {
|
|
var pattern string
|
|
if err := rows.Scan(&pattern); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, pattern)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// boolToInt maps true to 1 and false to 0, for integer-typed boolean
// columns such as users.is_admin.
func boolToInt(v bool) int {
	n := 0
	if v {
		n = 1
	}
	return n
}
|
|
|
|
// ErrUsernameTaken is returned by CreateUser when the requested username
// already exists. Compare against it with errors.Is.
var (
	ErrUsernameTaken = errors.New("username already registered")
)
|
|
|
|
// isUniqueUsernameErr reports whether err is SQLite's UNIQUE-constraint
// violation on users.username. It matches on the driver's error text.
// A nil error is never a match; previously a nil error caused a panic.
func isUniqueUsernameErr(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), "UNIQUE constraint failed: users.username")
}
|
|
|
|
// nullInt64Arg converts v into a query argument: its int64 when Valid,
// otherwise an untyped nil so that SQL NULL is bound.
func nullInt64Arg(v sql.NullInt64) any {
	if !v.Valid {
		return nil
	}
	return v.Int64
}
|
|
|
|
// nullStringArg converts v into a query argument: its string when Valid,
// otherwise an untyped nil so that SQL NULL is bound.
func nullStringArg(v sql.NullString) any {
	if !v.Valid {
		return nil
	}
	return v.String
}
|
|
|
|
// timeOrNil converts v into a query argument. When Valid, it yields the time
// converted to UTC and formatted as RFC 3339 text, the format this package
// writes for all timestamps. Otherwise it yields an untyped nil, which
// binds SQL NULL.
func timeOrNil(v sql.NullTime) any {
	if !v.Valid {
		return nil
	}
	return v.Time.UTC().Format(time.RFC3339)
}
|
|
|
|
func (s *Store) debugDB(msg string, fields ...zap.Field) {
|
|
if s.log == nil {
|
|
return
|
|
}
|
|
s.log.Debug(msg, fields...)
|
|
}
|
|
|
|
// newSiteUUID returns 16 bytes from crypto/rand encoded as 32 lowercase hex
// characters. Despite the name, it is not an RFC 4122 UUID: there are no
// dashes and no version/variant bits.
func newSiteUUID() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
|