forgejo-tickets/internal/auth/store.go

159 lines
3.6 KiB
Go

package auth
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/gob"
"net/http"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/mattnite/forgejo-tickets/internal/models"
"gorm.io/gorm"
)
// Defaults for the session cookie managed by PGStore.
const (
	// sessionCookieName is the name used for the session cookie.
	sessionCookieName = "session"
	// sessionMaxAge is the default session lifetime in seconds: 30 days.
	sessionMaxAge = 30 * 24 * 60 * 60
)
// PGStore is a gorilla/sessions store that keeps session state in the
// database through GORM. The browser cookie carries only the securecookie-
// encoded session token; the session values live in a models.Session row
// looked up by that token.
type PGStore struct {
	// db is the GORM handle used to read, write and delete session rows.
	db *gorm.DB
	// codecs encode and decode the token stored in the cookie.
	codecs []securecookie.Codec
	// options holds the defaults copied into every freshly built session.
	options *sessions.Options
}
// NewPGStore returns a PGStore that persists sessions through db.
//
// keyPairs are handed unchanged to securecookie.CodecsFromPairs, which builds
// one codec per (hash key, block key) pair. Sessions created by the store
// default to Path "/", a MaxAge of sessionMaxAge seconds, HttpOnly, and
// SameSite=Lax. Secure is left false.
func NewPGStore(db *gorm.DB, keyPairs ...[]byte) *PGStore {
	defaults := &sessions.Options{
		Path:     "/",
		MaxAge:   sessionMaxAge,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	codecs := securecookie.CodecsFromPairs(keyPairs...)
	return &PGStore{db: db, codecs: codecs, options: defaults}
}
// Get returns the session called name for r. The lookup goes through the
// request's sessions.Registry, which caches sessions per request, so New is
// only reached the first time a given name is requested.
func (s *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) {
	registry := sessions.GetRegistry(r)
	return registry.Get(s, name)
}
// New builds the session called name for r. It never returns an error.
//
// If the request has no such cookie, the cookie fails securecookie decoding,
// no unexpired row exists for the token, or the stored data cannot be
// gob-decoded, the result is a fresh session: IsNew is true, ID is empty and
// Values is empty. An empty ID makes Save mint a new token instead of
// reusing a stale one that may still name another user's row.
//
// On success, ID is the token, Values holds the decoded data plus
// Values["user_id"] set to the row's UserID string, and IsNew is false.
func (s *PGStore) New(r *http.Request, name string) (*sessions.Session, error) {
	session := sessions.NewSession(s, name)
	// Copy the store defaults so per-session changes (e.g. MaxAge = -1 on
	// logout) never mutate the shared options.
	opts := *s.options
	session.Options = &opts
	session.IsNew = true

	cookie, err := r.Cookie(name)
	if err != nil {
		return session, nil
	}

	// Decode into locals and only commit them to the session once every step
	// succeeds; a half-loaded session must not keep the old token.
	var token string
	if err := securecookie.DecodeMulti(name, cookie.Value, &token, s.codecs...); err != nil {
		return session, nil
	}

	var dbSession models.Session
	if err := s.db.Where("token = ? AND expires_at > ?", token, time.Now()).First(&dbSession).Error; err != nil {
		return session, nil
	}

	values := make(map[interface{}]interface{})
	if err := gob.NewDecoder(bytes.NewReader(dbSession.Data)).Decode(&values); err != nil {
		return session, nil
	}
	// user_id is kept in its own column rather than in the gob blob.
	values["user_id"] = dbSession.UserID.String()

	session.ID = token
	session.Values = values
	session.IsNew = false
	return session, nil
}
// Save persists session and writes its cookie to w.
//
// If session.Options.MaxAge is negative, the session is destroyed. Its row is
// deleted (when it has a token) and an expiring cookie is written; a delete
// failure is returned after the cookie is cleared, since the token would
// otherwise stay valid server-side.
//
// Otherwise the following steps run in order:
//   - A random 32-byte token is minted when session.ID is empty.
//   - Values["user_id"] must be a UUID string; the uuid.Parse error is
//     returned if it is not.
//   - The other values are gob-encoded into the row's Data.
//   - The row for the token is created, or updated with the current user,
//     data and expiry.
//   - The securecookie-encoded token is set as the cookie.
//
// A MaxAge of 0 produces a browser-session cookie. The row then expires
// after sessionMaxAge seconds instead of immediately.
func (s *PGStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
	if session.Options.MaxAge < 0 {
		var delErr error
		if session.ID != "" {
			delErr = s.db.Where("token = ?", session.ID).Delete(&models.Session{}).Error
		}
		// Clear the browser cookie regardless, then report any DB failure.
		http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
		return delErr
	}

	if session.ID == "" {
		token := make([]byte, 32)
		if _, err := rand.Read(token); err != nil {
			return err
		}
		session.ID = base64.URLEncoding.EncodeToString(token)
	}

	userIDStr, _ := session.Values["user_id"].(string)
	userID, err := uuid.Parse(userIDStr)
	if err != nil {
		return err
	}

	// user_id lives in its own column; keep it out of the gob blob.
	valuesToEncode := make(map[interface{}]interface{}, len(session.Values))
	for k, v := range session.Values {
		if k != "user_id" {
			valuesToEncode[k] = v
		}
	}
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(valuesToEncode); err != nil {
		return err
	}
	data := buf.Bytes()

	lifetime := session.Options.MaxAge
	if lifetime == 0 {
		// A browser-session cookie has no Max-Age, but the row still needs a
		// finite lifetime; time.Now()+0 would make it expire immediately.
		lifetime = sessionMaxAge
	}
	expiresAt := time.Now().Add(time.Duration(lifetime) * time.Second)

	dbSession := models.Session{
		Token:     session.ID,
		UserID:    userID,
		Data:      data,
		ExpiresAt: expiresAt,
	}
	// UserID is assigned too: an existing row for this token (e.g. a second
	// login in the same browser) must be updated to the current user rather
	// than keep its previous owner.
	result := s.db.Where("token = ?", session.ID).Assign(models.Session{
		UserID:    userID,
		Data:      data,
		ExpiresAt: expiresAt,
	}).FirstOrCreate(&dbSession)
	if result.Error != nil {
		return result.Error
	}

	encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.codecs...)
	if err != nil {
		return err
	}
	http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
	return nil
}
// Cleanup deletes session rows whose expires_at is at or before the current
// time, once per interval, until ctx is cancelled. It blocks, so callers run
// it in its own goroutine. interval must be positive, because time.NewTicker
// panics otherwise.
func (s *PGStore) Cleanup(ctx context.Context, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Bind the delete to ctx so a slow query cannot hold up shutdown.
			err := s.db.WithContext(ctx).
				Where("expires_at <= ?", time.Now()).
				Delete(&models.Session{}).Error
			// An error caused by cancellation is expected at shutdown; only
			// log failures that happen while still running.
			if err != nil && ctx.Err() == nil {
				log.Error().Err(err).Msg("session cleanup error")
			}
		}
	}
}