// satoru/cmd/satoru/main.go (639 lines, 17 KiB, Go)

package main
import (
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"github.com/a-h/templ"
"github.com/go-chi/chi/v5"
"golang.org/x/crypto/bcrypt"
"satoru/internal/store"
"satoru/internal/webui"
)
// Session and background-scan tuning.
const (
	// sessionCookieName is the cookie that carries the raw session token.
	sessionCookieName = "satoru_session"
	// sessionTTL is how long a session stays valid after its last use (one week).
	sessionTTL = 7 * 24 * time.Hour
	// scanInterval is the minimum age of a site's last scan before it is rescanned.
	scanInterval = 24 * time.Hour
	// scanLoopTick is how often the background loop looks for sites that are due.
	scanLoopTick = time.Hour
)
// app bundles the dependencies shared by the HTTP handlers and the
// background scan loop.
type app struct {
	// store is the persistence layer for users, sessions and sites.
	store *store.Store
}
// main is the process entry point. It delegates to runServer so that the
// deferred cleanup there (stopping the scan loop, closing the store) actually
// runs; calling log.Fatal after registering those defers would skip them.
func main() {
	if err := runServer(); err != nil {
		log.Fatal(err)
	}
}

// runServer opens the SQLite store under ./data, starts the background site
// scan loop and serves the web UI on :8080. It returns only when the HTTP
// server stops (ListenAndServe always returns a non-nil error). Before it
// returns, the scan loop has been cancelled and waited for, and the store has
// been closed.
func runServer() error {
	if err := os.MkdirAll("data", 0o755); err != nil {
		return err
	}
	dbPath := filepath.Join("data", "satoru.db")
	st, err := store.Open(dbPath)
	if err != nil {
		return err
	}
	defer st.Close()

	a := &app{store: st}

	ctx, cancel := context.WithCancel(context.Background())
	scanDone := make(chan struct{})
	go func() {
		defer close(scanDone)
		a.startSiteScanLoop(ctx)
	}()
	// Defers run LIFO, so this runs before st.Close: stop the scan loop and
	// wait for it so it never uses a closed store.
	defer func() {
		cancel()
		<-scanDone
	}()

	r := chi.NewRouter()
	fileServer := http.FileServer(http.Dir("web/static"))
	r.Handle("/static/*", http.StripPrefix("/static/", fileServer))
	r.Get("/", a.handleHome)
	r.Get("/account/password", a.handlePasswordPage)
	r.Post("/account/password", a.handlePasswordSubmit)
	r.Post("/sites", a.handleSiteCreate)
	r.Post("/sites/{id}/run", a.handleSiteRun)
	r.Get("/signup", a.handleSignupPage)
	r.Post("/signup", a.handleSignupSubmit)
	r.Get("/signin", a.handleSigninPage)
	r.Post("/signin", a.handleSigninSubmit)
	r.Post("/signout", a.handleSignoutSubmit)

	srv := &http.Server{
		Addr:    ":8080",
		Handler: r,
		// http.ListenAndServe applies no timeouts at all; bound slow or idle
		// clients so they cannot hold connections open indefinitely.
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		// Generous because handlers block on SSH work: site creation scans
		// synchronously for up to 45s.
		WriteTimeout: 90 * time.Second,
		IdleTimeout:  120 * time.Second,
	}
	log.Printf("satoru listening on http://localhost%s", srv.Addr)
	return srv.ListenAndServe()
}
// handleHome serves GET /. A request with a usable session gets the
// dashboard (sites, local tool checks, flash message from ?msg=, workflow
// overview). Otherwise the public landing page is shown.
func (a *app) handleHome(w http.ResponseWriter, r *http.Request) {
	user, err := a.currentUserWithRollingSession(w, r)
	if err != nil {
		// No valid session (or the lookup failed): show the anonymous home page.
		templ.Handler(webui.Home(time.Now(), store.User{})).ServeHTTP(w, r)
		return
	}

	sites, err := a.store.ListSites(r.Context())
	if err != nil {
		http.Error(w, "failed to load sites", http.StatusInternalServerError)
		return
	}

	var data webui.DashboardData
	data.Now = time.Now()
	data.User = user
	data.Sites = sites
	data.RuntimeChecks = runtimeChecks()
	data.FlashMessage = r.URL.Query().Get("msg")
	data.WorkflowStages = defaultWorkflowStages()
	templ.Handler(webui.Dashboard(data)).ServeHTTP(w, r)
}
// handleSignupPage serves GET /signup. Users who already have a valid
// session are redirected to the dashboard instead.
func (a *app) handleSignupPage(w http.ResponseWriter, r *http.Request) {
	_, err := a.currentUserWithRollingSession(w, r)
	if err == nil {
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}
	page := webui.Signup(webui.AuthPageData{})
	templ.Handler(page).ServeHTTP(w, r)
}
// handleSignupSubmit handles POST /signup.
//
// It normalizes and validates the username and validates the password length
// (8..72 bytes). It then stores a bcrypt hash of the password, issues a
// session for the new user and redirects to the dashboard. Validation problems
// and a taken username re-render the signup form with an error. Storage or
// hashing failures return 500.
func (a *app) handleSignupSubmit(w http.ResponseWriter, r *http.Request) {
	// bcrypt only uses the first 72 bytes of a password. Recent versions of
	// x/crypto/bcrypt reject longer input from GenerateFromPassword, which
	// previously surfaced here as a 500; older versions silently truncate it.
	const maxPasswordBytes = 72

	if err := r.ParseForm(); err != nil {
		http.Error(w, "invalid form", http.StatusBadRequest)
		return
	}
	username := normalizeUsername(r.FormValue("username"))
	password := r.FormValue("password")
	form := webui.AuthPageData{Username: username}
	if !validUsername(username) {
		form.Error = "Username must be 3-32 chars using letters, numbers, ., _, or -."
		templ.Handler(webui.Signup(form)).ServeHTTP(w, r)
		return
	}
	if len(password) < 8 {
		form.Error = "Password must be at least 8 characters."
		templ.Handler(webui.Signup(form)).ServeHTTP(w, r)
		return
	}
	if len(password) > maxPasswordBytes {
		form.Error = "Password must be at most 72 bytes."
		templ.Handler(webui.Signup(form)).ServeHTTP(w, r)
		return
	}
	hashBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		http.Error(w, "failed to create account", http.StatusInternalServerError)
		return
	}
	user, err := a.store.CreateUser(r.Context(), username, string(hashBytes))
	if err != nil {
		if errors.Is(err, store.ErrUsernameTaken) {
			form.Error = "That username is already registered."
			templ.Handler(webui.Signup(form)).ServeHTTP(w, r)
			return
		}
		http.Error(w, "failed to create account", http.StatusInternalServerError)
		return
	}
	if err := a.issueSession(w, r, user.ID); err != nil {
		http.Error(w, "failed to create session", http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, "/", http.StatusSeeOther)
}
// handleSigninPage serves GET /signin. Users who already have a valid
// session are redirected to the dashboard instead.
func (a *app) handleSigninPage(w http.ResponseWriter, r *http.Request) {
	_, err := a.currentUserWithRollingSession(w, r)
	if err == nil {
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}
	page := webui.Signin(webui.AuthPageData{})
	templ.Handler(page).ServeHTTP(w, r)
}
// handleSigninSubmit handles POST /signin. It looks up the normalized
// username, verifies the password against the stored bcrypt hash, issues a
// session and redirects to the dashboard. An unknown user and a wrong password
// produce the same form error. Other lookup or session failures return 500.
func (a *app) handleSigninSubmit(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		http.Error(w, "invalid form", http.StatusBadRequest)
		return
	}
	form := webui.AuthPageData{Username: normalizeUsername(r.FormValue("username"))}
	password := r.FormValue("password")

	// One shared rejection so the response does not reveal which check failed.
	rejectCredentials := func() {
		form.Error = "Invalid username or password."
		templ.Handler(webui.Signin(form)).ServeHTTP(w, r)
	}

	user, err := a.store.UserByUsername(r.Context(), form.Username)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		rejectCredentials()
		return
	case err != nil:
		http.Error(w, "failed to sign in", http.StatusInternalServerError)
		return
	}

	if bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) != nil {
		rejectCredentials()
		return
	}
	if err := a.issueSession(w, r, user.ID); err != nil {
		http.Error(w, "failed to create session", http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, "/", http.StatusSeeOther)
}
// handleSignoutSubmit handles POST /signout. It deletes the server-side
// session for the presented cookie (if any), clears the cookie and redirects
// to the sign-in page.
func (a *app) handleSignoutSubmit(w http.ResponseWriter, r *http.Request) {
	if cookie, err := r.Cookie(sessionCookieName); err == nil && cookie.Value != "" {
		// Best effort: the browser cookie is cleared below whether or not the
		// session row could be deleted, so the error is deliberately ignored.
		_ = a.store.DeleteSessionByTokenHash(r.Context(), hashToken(cookie.Value))
	}
	clearSessionCookie(w)
	http.Redirect(w, r, "/signin", http.StatusSeeOther)
}
// handlePasswordPage serves GET /account/password for signed-in users and
// sends everyone else to the sign-in page.
func (a *app) handlePasswordPage(w http.ResponseWriter, r *http.Request) {
	user, err := a.currentUserWithRollingSession(w, r)
	if err != nil {
		http.Redirect(w, r, "/signin", http.StatusSeeOther)
		return
	}
	page := webui.ChangePassword(webui.PasswordPageData{User: user})
	templ.Handler(page).ServeHTTP(w, r)
}
// handlePasswordSubmit handles POST /account/password for a signed-in user.
//
// It checks the current password, then requires the new password to be
// 8..72 bytes and to match its confirmation. It stores a fresh bcrypt hash and
// redirects to the dashboard with msg=password-updated. Validation failures
// re-render the form with an error. Unauthenticated requests go to /signin.
func (a *app) handlePasswordSubmit(w http.ResponseWriter, r *http.Request) {
	// bcrypt only uses the first 72 bytes of a password. Recent versions of
	// x/crypto/bcrypt reject longer input from GenerateFromPassword, which
	// previously surfaced here as a 500; older versions silently truncate it.
	const maxPasswordBytes = 72

	user, err := a.currentUserWithRollingSession(w, r)
	if err != nil {
		http.Redirect(w, r, "/signin", http.StatusSeeOther)
		return
	}
	if err := r.ParseForm(); err != nil {
		http.Error(w, "invalid form", http.StatusBadRequest)
		return
	}
	current := r.FormValue("current_password")
	next := r.FormValue("new_password")
	confirm := r.FormValue("confirm_password")
	form := webui.PasswordPageData{User: user}
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(current)); err != nil {
		form.Error = "Current password is incorrect."
		templ.Handler(webui.ChangePassword(form)).ServeHTTP(w, r)
		return
	}
	if len(next) < 8 {
		form.Error = "New password must be at least 8 characters."
		templ.Handler(webui.ChangePassword(form)).ServeHTTP(w, r)
		return
	}
	if len(next) > maxPasswordBytes {
		form.Error = "New password must be at most 72 bytes."
		templ.Handler(webui.ChangePassword(form)).ServeHTTP(w, r)
		return
	}
	if next != confirm {
		form.Error = "New password and confirmation do not match."
		templ.Handler(webui.ChangePassword(form)).ServeHTTP(w, r)
		return
	}
	hashBytes, err := bcrypt.GenerateFromPassword([]byte(next), bcrypt.DefaultCost)
	if err != nil {
		http.Error(w, "failed to update password", http.StatusInternalServerError)
		return
	}
	if err := a.store.UpdateUserPasswordHash(r.Context(), user.ID, string(hashBytes)); err != nil {
		http.Error(w, "failed to update password", http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, "/?msg=password-updated", http.StatusSeeOther)
}
// handleSiteCreate handles POST /sites for signed-in users.
//
// It validates the SSH user, host, port and backup target paths, then stores
// the site. It runs an initial size scan bounded to 45 seconds and redirects
// to the dashboard with a status code in ?msg=. Validation failures redirect
// with site-invalid, site-invalid-port or site-invalid-path.
func (a *app) handleSiteCreate(w http.ResponseWriter, r *http.Request) {
	if _, err := a.currentUserWithRollingSession(w, r); err != nil {
		http.Redirect(w, r, "/signin", http.StatusSeeOther)
		return
	}
	if err := r.ParseForm(); err != nil {
		http.Error(w, "invalid form", http.StatusBadRequest)
		return
	}
	sshUser := strings.TrimSpace(r.FormValue("ssh_user"))
	host := strings.TrimSpace(r.FormValue("host"))
	port, err := parsePort(r.FormValue("port"))
	if err != nil {
		http.Redirect(w, r, "/?msg=site-invalid-port", http.StatusSeeOther)
		return
	}
	directoryPaths := parsePathList(r.FormValue("directory_paths"))
	sqlitePaths := parsePathList(r.FormValue("sqlite_paths"))
	targets := buildTargets(directoryPaths, sqlitePaths)
	if sshUser == "" || host == "" || len(targets) == 0 {
		http.Redirect(w, r, "/?msg=site-invalid", http.StatusSeeOther)
		return
	}
	// sshUser and host are later joined into ssh's "user@host" argument. A
	// value such as "-oProxyCommand=..." would otherwise be parsed by ssh as
	// an option and execute arbitrary commands on this server.
	if !validSSHArg(sshUser) || !validSSHArg(host) || strings.Contains(host, "@") {
		http.Redirect(w, r, "/?msg=site-invalid", http.StatusSeeOther)
		return
	}
	if !targetsAreValid(targets) {
		http.Redirect(w, r, "/?msg=site-invalid-path", http.StatusSeeOther)
		return
	}
	site, err := a.store.CreateSite(r.Context(), sshUser, host, port, targets)
	if err != nil {
		http.Error(w, "failed to add site", http.StatusInternalServerError)
		return
	}
	// Detached from the request context; presumably so a client disconnect
	// does not abort the initial scan.
	scanCtx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
	a.scanSiteNow(scanCtx, site.ID)
	cancel()
	http.Redirect(w, r, "/?msg=site-added", http.StatusSeeOther)
}

// validSSHArg reports whether v can be embedded in the "user@host" argument
// handed to ssh. It must be non-empty, must not start with '-' (ssh would
// treat it as an option) and must not contain ASCII whitespace or control
// characters.
func validSSHArg(v string) bool {
	if v == "" || strings.HasPrefix(v, "-") {
		return false
	}
	return strings.IndexFunc(v, func(c rune) bool {
		return c <= ' ' || c == 0x7f
	}) < 0
}
// handleSiteRun handles POST /sites/{id}/run for signed-in users. It runs the
// SSH connectivity check against the site, records the status and output,
// and redirects to the dashboard with msg=site-ran. An unknown id gives 404.
func (a *app) handleSiteRun(w http.ResponseWriter, r *http.Request) {
	if _, err := a.currentUserWithRollingSession(w, r); err != nil {
		http.Redirect(w, r, "/signin", http.StatusSeeOther)
		return
	}
	siteID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
	if err != nil {
		http.Error(w, "invalid site id", http.StatusBadRequest)
		return
	}

	ctx := r.Context()
	site, err := a.store.SiteByID(ctx, siteID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		http.NotFound(w, r)
		return
	case err != nil:
		http.Error(w, "failed to load site", http.StatusInternalServerError)
		return
	}

	status, output := runSSHHello(ctx, site)
	if err := a.store.UpdateSiteRunResult(ctx, site.ID, status, output, time.Now()); err != nil {
		http.Error(w, "failed to store run result", http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, "/?msg=site-ran", http.StatusSeeOther)
}
// issueSession creates a new session for userID that expires after
// sessionTTL. It persists the SHA-256 hash of a fresh random token and sends
// the raw token to the client in the session cookie.
func (a *app) issueSession(w http.ResponseWriter, r *http.Request, userID int64) error {
	token, err := generateToken()
	if err != nil {
		return err
	}
	expiresAt := time.Now().Add(sessionTTL)
	err = a.store.CreateSession(r.Context(), userID, hashToken(token), expiresAt)
	if err != nil {
		return err
	}
	setSessionCookie(w, r, token, expiresAt)
	return nil
}
// setSessionCookie writes the session token cookie with the given expiry.
// The cookie is HttpOnly and SameSite=Lax. It is marked Secure only when
// the request itself arrived over TLS.
func setSessionCookie(w http.ResponseWriter, r *http.Request, token string, expiresAt time.Time) {
	cookie := &http.Cookie{
		Name:     sessionCookieName,
		Value:    token,
		Path:     "/",
		Expires:  expiresAt,
		HttpOnly: true,
		Secure:   r.TLS != nil,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, cookie)
}
// clearSessionCookie tells the browser to drop the session cookie by
// overwriting it with an empty value, a negative MaxAge and an expiry at the
// Unix epoch.
func clearSessionCookie(w http.ResponseWriter) {
	expired := &http.Cookie{
		Name:     sessionCookieName,
		Value:    "",
		Path:     "/",
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, expired)
}
// currentUser resolves the user behind the request's session cookie. It
// returns the user and the raw token. A missing or empty cookie yields
// http.ErrNoCookie; lookup failures are returned unchanged.
func (a *app) currentUser(ctx context.Context, r *http.Request) (store.User, string, error) {
	cookie, err := r.Cookie(sessionCookieName)
	if err != nil || cookie.Value == "" {
		return store.User{}, "", http.ErrNoCookie
	}
	token := cookie.Value
	user, err := a.store.UserBySessionTokenHash(ctx, hashToken(token))
	if err != nil {
		return store.User{}, "", err
	}
	return user, token, nil
}
// currentUserWithRollingSession authenticates the request like currentUser
// and, on success, pushes the session expiry sessionTTL into the future both
// in the store and in the cookie. If extending the stored session fails, the
// cookie is cleared and that error is returned.
func (a *app) currentUserWithRollingSession(w http.ResponseWriter, r *http.Request) (store.User, error) {
	ctx := r.Context()
	user, token, err := a.currentUser(ctx, r)
	if err != nil {
		return store.User{}, err
	}
	newExpiry := time.Now().Add(sessionTTL)
	if touchErr := a.store.TouchSessionByTokenHash(ctx, hashToken(token), newExpiry); touchErr != nil {
		clearSessionCookie(w)
		return store.User{}, touchErr
	}
	setSessionCookie(w, r, token, newExpiry)
	return user, nil
}
// generateToken returns 32 bytes from crypto/rand encoded as unpadded
// URL-safe base64 (43 characters), suitable for a session cookie value.
func generateToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
// hashToken returns the lowercase hex SHA-256 digest of token. Only this
// digest is handed to the store, never the token itself.
func hashToken(token string) string {
	digest := sha256.Sum256([]byte(token))
	return hex.EncodeToString(digest[:])
}
// usernamePattern is the accepted form of a normalized username: 3 to 32
// characters drawn from lowercase ASCII letters, digits, '.', '_' and '-'.
var usernamePattern = regexp.MustCompile(`^[a-z0-9._-]{3,32}$`)

// normalizeUsername canonicalizes raw form input by trimming surrounding
// whitespace and lowercasing it.
func normalizeUsername(v string) string {
	trimmed := strings.TrimSpace(v)
	return strings.ToLower(trimmed)
}

// validUsername reports whether v (expected to be normalized already)
// matches usernamePattern.
func validUsername(v string) bool { return usernamePattern.MatchString(v) }
// runtimeChecks reports, for each external tool satoru shells out to
// (restic, rsync, ssh), whether it resolves via PATH. Details holds the
// resolved path, or "not found in PATH".
func runtimeChecks() []webui.RuntimeCheck {
	names := []string{"restic", "rsync", "ssh"}
	checks := make([]webui.RuntimeCheck, 0, len(names))
	for _, name := range names {
		check := webui.RuntimeCheck{Name: name, Installed: false, Details: "not found in PATH"}
		if resolved, err := exec.LookPath(name); err == nil {
			check.Installed = true
			check.Details = resolved
		}
		checks = append(checks, check)
	}
	return checks
}
// runSSHHello runs a trivial remote echo on the site over ssh as a
// connectivity check, bounded to 10 seconds. It returns "ok" or "failed"
// together with the trimmed combined stdout/stderr ("(no output)" when empty).
func runSSHHello(ctx context.Context, site store.Site) (string, string) {
	target := fmt.Sprintf("%s@%s", site.SSHUser, site.Host)
	cmdCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	// "--" ends ssh's option parsing, so a stored user/host that begins with
	// '-' (e.g. "-oProxyCommand=...") is treated as the destination instead of
	// as an option that would run commands on this machine.
	cmd := exec.CommandContext(cmdCtx, "ssh", "-p", strconv.Itoa(site.Port), "--", target, "echo hello from satoru")
	out, err := cmd.CombinedOutput()
	output := strings.TrimSpace(string(out))
	if output == "" {
		output = "(no output)"
	}
	if err != nil {
		return "failed", output
	}
	return "ok", output
}
// defaultWorkflowStages returns the static backup-pipeline overview shown on
// the dashboard, in display order.
func defaultWorkflowStages() []webui.WorkflowStage {
	stages := []webui.WorkflowStage{
		{
			Title:       "Pull from Edge over SSH",
			Description: "Satoru connects to Linux edge hosts using local keys and pulls approved paths.",
		},
		{
			Title:       "Stage on Backup Server",
			Description: "Pulled data lands on the backup host first, keeping edge systems isolated from B2.",
		},
		{
			Title:       "Restic to B2",
			Description: "Restic runs centrally on this Satoru instance and uploads snapshots to Backblaze B2.",
		},
		{
			Title:       "Audit and Recover",
			Description: "Each site records run output/status for operational visibility before full job history is added.",
		},
	}
	return stages
}
// parsePathList splits a free-form list of paths on newlines, commas and
// semicolons, trims surrounding whitespace (including the '\r' of CRLF
// textarea input) from each entry, and drops empty entries. The result is
// never nil.
func parsePathList(raw string) []string {
	isSeparator := func(c rune) bool {
		switch c {
		case '\n', ',', ';':
			return true
		}
		return false
	}
	fields := strings.FieldsFunc(raw, isSeparator)
	paths := make([]string, 0, len(fields))
	for _, field := range fields {
		if p := strings.TrimSpace(field); p != "" {
			paths = append(paths, p)
		}
	}
	return paths
}
// buildTargets converts the parsed path lists into site targets: directory
// paths first (mode "directory"), then SQLite paths (mode "sqlite_dump"),
// each group in input order.
func buildTargets(directoryPaths, sqlitePaths []string) []store.SiteTarget {
	targets := make([]store.SiteTarget, 0, len(directoryPaths)+len(sqlitePaths))
	add := func(paths []string, mode string) {
		for _, p := range paths {
			targets = append(targets, store.SiteTarget{Path: p, Mode: mode})
		}
	}
	add(directoryPaths, "directory")
	add(sqlitePaths, "sqlite_dump")
	return targets
}
// targetsAreValid reports whether every target has a non-empty absolute
// (leading '/') path. An empty slice is considered valid.
func targetsAreValid(targets []store.SiteTarget) bool {
	for i := range targets {
		p := targets[i].Path
		if len(p) == 0 || p[0] != '/' {
			return false
		}
	}
	return true
}
// parsePort parses an SSH port from form input. Blank input (after trimming)
// means the default port 22. Anything that is not an integer in 1..65535 is
// rejected with an error and a zero port.
func parsePort(raw string) (int, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return 22, nil
	}
	if n, err := strconv.Atoi(trimmed); err == nil && n >= 1 && n <= 65535 {
		return n, nil
	}
	return 0, errors.New("invalid port")
}
// startSiteScanLoop scans every site once immediately, then every
// scanLoopTick rescans the sites that are due. It blocks until ctx is
// cancelled.
func (a *app) startSiteScanLoop(ctx context.Context) {
	a.scanAllSites(ctx)

	tick := time.NewTicker(scanLoopTick)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			a.scanDueSites(ctx)
		case <-ctx.Done():
			return
		}
	}
}
// scanAllSites scans every configured site once, regardless of when it was
// last scanned. A failure to list sites is logged and ends the pass. The pass
// also stops early once ctx is cancelled.
func (a *app) scanAllSites(ctx context.Context) {
	sites, err := a.store.ListSites(ctx)
	if err != nil {
		log.Printf("scan loop: failed to list sites: %v", err)
		return
	}
	for _, site := range sites {
		// Scans block on SSH; on shutdown, stop between sites rather than
		// attempting (and logging failures for) every remaining one.
		if ctx.Err() != nil {
			return
		}
		a.scanSiteNow(ctx, site.ID)
	}
}
// scanDueSites rescans only the sites that have never been scanned or whose
// last scan is at least scanInterval old, measured against a single "now"
// taken at the start of the pass. A failure to list sites is logged and ends
// the pass. The pass also stops early once ctx is cancelled.
func (a *app) scanDueSites(ctx context.Context) {
	sites, err := a.store.ListSites(ctx)
	if err != nil {
		log.Printf("scan loop: failed to list sites: %v", err)
		return
	}
	now := time.Now()
	for _, site := range sites {
		// Scans block on SSH; on shutdown, stop between sites rather than
		// attempting (and logging failures for) every remaining one.
		if ctx.Err() != nil {
			return
		}
		if site.LastScanAt.Valid && site.LastScanAt.Time.Add(scanInterval).After(now) {
			continue
		}
		a.scanSiteNow(ctx, site.ID)
	}
}
// scanSiteNow measures the remote size of each target of a site, one target
// at a time, and stores the per-target results plus an overall state.
//
// The state is "ok" when all targets succeeded and "partial" when some
// failed. It is "failed" when all failed or the site has no targets.
// Every target gets the same scan timestamp. A failed target has its size
// cleared and its error text recorded. Load and store failures are only
// logged.
func (a *app) scanSiteNow(ctx context.Context, siteID int64) {
	site, err := a.store.SiteByID(ctx, siteID)
	if err != nil {
		log.Printf("scan site %d: load failed: %v", siteID, err)
		return
	}

	total := len(site.Targets)
	scannedAt := time.Now()
	stamp := sql.NullTime{Time: scannedAt, Valid: true}
	var okCount, failCount int
	results := make([]store.SiteTarget, 0, total)
	for _, t := range site.Targets {
		size, scanErr := queryTargetSize(ctx, site, t)
		t.LastScanAt = stamp
		if scanErr == nil {
			okCount++
			t.LastSizeByte = sql.NullInt64{Int64: size, Valid: true}
			t.LastError = sql.NullString{}
		} else {
			failCount++
			t.LastSizeByte = sql.NullInt64{}
			t.LastError = sql.NullString{String: scanErr.Error(), Valid: true}
		}
		results = append(results, t)
	}

	var state string
	switch {
	case total == 0, failCount == total:
		state = "failed"
	case failCount > 0:
		state = "partial"
	default:
		state = "ok"
	}
	notes := fmt.Sprintf("%d/%d targets scanned", okCount, total)
	if err := a.store.UpdateSiteScanResult(ctx, site.ID, state, notes, scannedAt, results); err != nil {
		log.Printf("scan site %d: update failed: %v", siteID, err)
	}
}
// queryTargetSize runs the size command for target on the site over ssh,
// bounded to 20 seconds, and returns the last integer in the command's
// combined output.
//
// If ssh fails, the error text is the trimmed output, or the exec error when
// there was no output. If no integer is found, the error is
// "empty size output".
func queryTargetSize(ctx context.Context, site store.Site, target store.SiteTarget) (int64, error) {
	targetAddr := fmt.Sprintf("%s@%s", site.SSHUser, site.Host)
	cmdCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
	defer cancel()
	remote := remoteSizeCommand(target)
	// "--" ends ssh's option parsing, so a stored user/host that begins with
	// '-' (e.g. "-oProxyCommand=...") is treated as the destination instead of
	// as an option that would run commands on this machine.
	cmd := exec.CommandContext(cmdCtx, "ssh", "-p", strconv.Itoa(site.Port), "--", targetAddr, remote)
	out, err := cmd.CombinedOutput()
	output := strings.TrimSpace(string(out))
	if err != nil {
		if output == "" {
			output = err.Error()
		}
		return 0, errors.New(output)
	}
	size, ok := extractLastInteger(output)
	if !ok {
		return 0, errors.New("empty size output")
	}
	return size, nil
}
// remoteSizeCommand builds the shell command run on the remote host to
// measure a target. The path is single-quoted via shellQuote.
// "sqlite_dump" targets report the file's byte size via stat. Every other
// mode reports du's apparent-size total in bytes.
func remoteSizeCommand(target store.SiteTarget) string {
	quoted := shellQuote(target.Path)
	switch target.Mode {
	case "sqlite_dump":
		return fmt.Sprintf("stat -c%%s -- %s", quoted)
	default:
		return fmt.Sprintf("du -sb -- %s | awk '{print $1}'", quoted)
	}
}
// shellQuote wraps s in single quotes for a POSIX shell. Each embedded single
// quote becomes '\'' (close the quote, emit an escaped quote, reopen).
func shellQuote(s string) string {
	pieces := strings.Split(s, "'")
	return "'" + strings.Join(pieces, `'\''`) + "'"
}
// extractLastInteger returns the last whitespace-separated field of output
// that parses as a base-10 int64 (sign allowed). The boolean is false if no
// field parses.
func extractLastInteger(output string) (int64, bool) {
	fields := strings.Fields(output)
	for i := len(fields); i > 0; i-- {
		if v, err := strconv.ParseInt(fields[i-1], 10, 64); err == nil {
			return v, true
		}
	}
	return 0, false
}