package auth import ( "bytes" "context" "crypto/rand" "encoding/base64" "encoding/gob" "net/http" "time" "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" "github.com/mattnite/forgejo-tickets/internal/models" "gorm.io/gorm" ) const ( sessionCookieName = "session" sessionMaxAge = 86400 * 30 // 30 days ) type PGStore struct { db *gorm.DB codecs []securecookie.Codec options *sessions.Options } func NewPGStore(db *gorm.DB, keyPairs ...[]byte) *PGStore { return &PGStore{ db: db, codecs: securecookie.CodecsFromPairs(keyPairs...), options: &sessions.Options{ Path: "/", MaxAge: sessionMaxAge, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, } } func (s *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) { return sessions.GetRegistry(r).Get(s, name) } func (s *PGStore) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(s, name) session.Options = &sessions.Options{ Path: s.options.Path, MaxAge: s.options.MaxAge, HttpOnly: s.options.HttpOnly, SameSite: s.options.SameSite, Secure: s.options.Secure, } session.IsNew = true cookie, err := r.Cookie(name) if err != nil { return session, nil } err = securecookie.DecodeMulti(name, cookie.Value, &session.ID, s.codecs...) if err != nil { return session, nil } var dbSession models.Session result := s.db.Where("token = ? AND expires_at > ?", session.ID, time.Now()).First(&dbSession) if result.Error != nil { return session, nil } if err := gob.NewDecoder(bytes.NewReader(dbSession.Data)).Decode(&session.Values); err != nil { return session, nil } session.Values["user_id"] = dbSession.UserID.String() session.IsNew = false return session, nil } func (s *PGStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { if session.Options.MaxAge < 0 { if session.ID != "" { s.db.Where("token = ?", session.ID).Delete(&models.Session{}) } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) return nil } if session.ID == "" { token := make([]byte, 32) if _, err := rand.Read(token); err != nil { return err } session.ID = base64.URLEncoding.EncodeToString(token) } userIDStr, _ := session.Values["user_id"].(string) userID, err := uuid.Parse(userIDStr) if err != nil { return err } valuesToEncode := make(map[interface{}]interface{}) for k, v := range session.Values { if k != "user_id" { valuesToEncode[k] = v } } var buf bytes.Buffer if err := gob.NewEncoder(&buf).Encode(valuesToEncode); err != nil { return err } expiresAt := time.Now().Add(time.Duration(session.Options.MaxAge) * time.Second) dbSession := models.Session{ Token: session.ID, UserID: userID, Data: buf.Bytes(), ExpiresAt: expiresAt, } result := s.db.Where("token = ?", session.ID).Assign(models.Session{ UserID: userID, Data: buf.Bytes(), ExpiresAt: expiresAt, }).FirstOrCreate(&dbSession) if result.Error != nil { return result.Error } encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.codecs...) if err != nil { return err } http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options)) return nil } func (s *PGStore) Cleanup(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: if err := s.db.Where("expires_at <= ?", time.Now()).Delete(&models.Session{}).Error; err != nil { log.Error().Err(err).Msg("session cleanup error") } } } }