159 lines
3.6 KiB
Go
159 lines
3.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/gob"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/gorilla/securecookie"
|
|
"github.com/gorilla/sessions"
|
|
"github.com/mattnite/forgejo-tickets/internal/models"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Cookie/session defaults for this store.
const (
	// sessionCookieName is "session"; it is not referenced in this file —
	// presumably the session name callers pass to Get (TODO confirm).
	sessionCookieName = "session"

	// sessionMaxAge is the default lifetime in seconds: thirty days.
	sessionMaxAge = 30 * 24 * 60 * 60
)
|
|
|
|
type PGStore struct {
|
|
db *gorm.DB
|
|
codecs []securecookie.Codec
|
|
options *sessions.Options
|
|
}
|
|
|
|
// NewPGStore builds a PGStore that persists sessions through db.
//
// keyPairs are handed to securecookie.CodecsFromPairs; the resulting codecs
// are tried in order when the session cookie is encoded and decoded. New
// sessions default to Path "/", a Max-Age of sessionMaxAge seconds,
// HttpOnly and SameSite=Lax. Secure is left false.
func NewPGStore(db *gorm.DB, keyPairs ...[]byte) *PGStore {
	defaults := &sessions.Options{
		Path:     "/",
		MaxAge:   sessionMaxAge,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	codecs := securecookie.CodecsFromPairs(keyPairs...)

	store := new(PGStore)
	store.db = db
	store.codecs = codecs
	store.options = defaults
	return store
}
|
|
|
|
// Get returns the session called name for request r.
//
// The lookup is delegated to the request-scoped gorilla/sessions registry,
// which falls back to s.New when it has no session registered under name.
func (s *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) {
	registry := sessions.GetRegistry(r)
	return registry.Get(s, name)
}
|
|
|
|
// New creates the session called name for request r.
//
// The returned session starts as a fresh, empty session with IsNew set and a
// copy of the store's default options. It is populated from the database
// only when r carries a cookie named name that decodes (via the store's
// codecs) to a token matching a row whose expires_at is in the future.
// In that case:
//   - the row's gob-encoded Data becomes session.Values;
//   - session.Values["user_id"] is set to the row's UserID as a string;
//   - IsNew is cleared.
//
// Any failure along the way degrades to an empty session and a nil error.
// This is deliberate: a bad or stale cookie must not break the request.
// In that case session.ID is cleared, so Save issues a fresh token instead
// of reusing one that no longer maps to a live row. Database and decoding
// errors are logged rather than returned.
func (s *PGStore) New(r *http.Request, name string) (*sessions.Session, error) {
	session := sessions.NewSession(s, name)
	opts := *s.options // copy so per-session changes don't touch the store defaults
	session.Options = &opts
	session.IsNew = true

	cookie, err := r.Cookie(name)
	if err != nil {
		return session, nil
	}

	if err := securecookie.DecodeMulti(name, cookie.Value, &session.ID, s.codecs...); err != nil {
		session.ID = ""
		return session, nil
	}

	// Limit(1).Find reports "no row" through RowsAffected instead of an error,
	// so a genuine database failure can be told apart from an unknown token.
	var row models.Session
	res := s.db.WithContext(r.Context()).
		Where("token = ? AND expires_at > ?", session.ID, time.Now()).
		Limit(1).
		Find(&row)
	if res.Error != nil {
		log.Error().Err(res.Error).Msg("session lookup error")
		session.ID = ""
		return session, nil
	}
	if res.RowsAffected == 0 {
		// Unknown or expired token: don't let Save resurrect it.
		session.ID = ""
		return session, nil
	}

	// Decode into a scratch map so a failed decode cannot leave partial
	// values behind in what is reported as a new session.
	values := make(map[interface{}]interface{})
	if err := gob.NewDecoder(bytes.NewReader(row.Data)).Decode(&values); err != nil {
		log.Warn().Err(err).Msg("session data decode error")
		session.ID = ""
		return session, nil
	}

	// user_id lives in its own column (see Save), not in the gob payload.
	values["user_id"] = row.UserID.String()
	session.Values = values
	session.IsNew = false
	return session, nil
}
|
|
|
|
// Save persists session to the database and writes its signed cookie to w.
//
// If session.Options.MaxAge is negative, the session is destroyed instead.
// Its row (if it has an ID) is deleted and an expiring cookie is written.
// The cookie is cleared even if the delete fails, and the delete error is
// returned so the caller learns the server-side row may still exist.
//
// Otherwise:
//   - session.Values["user_id"] must be a string holding a valid UUID; if
//     not, uuid.Parse's error is returned and nothing is written.
//   - Sessions without an ID get a random 32-byte, URL-safe base64 token.
//   - All Values except "user_id" are gob-encoded into the row's Data.
//   - user_id goes into the UserID column.
//   - expires_at is set MaxAge seconds from now, or the store's default
//     lifetime when MaxAge is 0.
//   - The row is inserted, or updated if the token already exists.
func (s *PGStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
	db := s.db.WithContext(r.Context())

	if session.Options.MaxAge < 0 {
		var delErr error
		if session.ID != "" {
			delErr = db.Where("token = ?", session.ID).Delete(&models.Session{}).Error
		}
		http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
		return delErr
	}

	// Validate before touching session.ID so a failed save has no side effects.
	userIDStr, _ := session.Values["user_id"].(string)
	userID, err := uuid.Parse(userIDStr)
	if err != nil {
		return err
	}

	if session.ID == "" {
		token := make([]byte, 32)
		if _, err := rand.Read(token); err != nil {
			return err
		}
		session.ID = base64.URLEncoding.EncodeToString(token)
	}

	// user_id is stored in its own column, so keep it out of the gob payload.
	payload := make(map[interface{}]interface{}, len(session.Values))
	for k, v := range session.Values {
		if k != "user_id" {
			payload[k] = v
		}
	}
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(payload); err != nil {
		return err
	}
	data := buf.Bytes()

	maxAge := session.Options.MaxAge
	if maxAge == 0 {
		// MaxAge 0 means a browser-session cookie with no Max-Age of its own.
		// Using it directly would give expires_at = now, and New would reject
		// the row on the very next request. Use the store default lifetime.
		maxAge = s.options.MaxAge
	}
	expiresAt := time.Now().Add(time.Duration(maxAge) * time.Second)

	row := models.Session{
		Token:     session.ID,
		UserID:    userID,
		Data:      data,
		ExpiresAt: expiresAt,
	}
	// Assign's fields are written whether the row is found or created.
	// UserID must be among them: otherwise an existing token row (e.g. after
	// a user switch on the same session) would keep its previous owner.
	res := db.Where("token = ?", session.ID).
		Assign(models.Session{
			UserID:    userID,
			Data:      data,
			ExpiresAt: expiresAt,
		}).
		FirstOrCreate(&row)
	if res.Error != nil {
		return res.Error
	}

	encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.codecs...)
	if err != nil {
		return err
	}

	http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
	return nil
}
|
|
|
|
// Cleanup deletes session rows whose expires_at has passed, once every
// interval. It blocks until ctx is done.
//
// A non-positive interval would make time.NewTicker panic. It is instead
// logged and Cleanup returns immediately without deleting anything.
// Delete errors are logged and the loop keeps running. Errors caused by ctx
// being cancelled mid-delete are not logged, since they are expected at
// shutdown.
func (s *PGStore) Cleanup(ctx context.Context, interval time.Duration) {
	if interval <= 0 {
		log.Error().Dur("interval", interval).Msg("session cleanup disabled: interval must be positive")
		return
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	// Bind the deletes to ctx so an in-flight query is cancelled on shutdown.
	db := s.db.WithContext(ctx)

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			err := db.Where("expires_at <= ?", time.Now()).Delete(&models.Session{}).Error
			if err != nil && ctx.Err() == nil {
				log.Error().Err(err).Msg("session cleanup error")
			}
		}
	}
}
|