package auth

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/mattnite/forgejo-tickets/internal/models"
	"github.com/rs/zerolog/log"
	"golang.org/x/crypto/bcrypt"
)

// GenerateVerificationToken issues a new email-verification token for userID
// that expires after 24 hours. It returns the plaintext token; only its
// SHA-256 hash is stored.
func (s *Service) GenerateVerificationToken(ctx context.Context, userID uuid.UUID) (string, error) {
	return s.generateToken(ctx, userID, models.TokenTypeVerifyEmail, 24*time.Hour)
}

// GeneratePasswordResetToken issues a new password-reset token for userID
// that expires after 1 hour. It returns the plaintext token; only its
// SHA-256 hash is stored.
func (s *Service) GeneratePasswordResetToken(ctx context.Context, userID uuid.UUID) (string, error) {
	return s.generateToken(ctx, userID, models.TokenTypeResetPassword, 1*time.Hour)
}

// VerifyEmailToken redeems an email-verification token, sets email_verified
// to true on the owning user, and returns that user as reloaded afterwards.
func (s *Service) VerifyEmailToken(ctx context.Context, plainToken string) (*models.User, error) {
	return s.redeemToken(ctx, plainToken, models.TokenTypeVerifyEmail, func(ctx context.Context, userID uuid.UUID) error {
		return s.db.WithContext(ctx).Model(&models.User{}).Where("id = ?", userID).Update("email_verified", true).Error
	})
}

// RedeemPasswordResetToken redeems a password-reset token, stores a bcrypt
// hash (bcrypt.DefaultCost) of newPassword as the owning user's
// password_hash, and returns that user as reloaded afterwards.
func (s *Service) RedeemPasswordResetToken(ctx context.Context, plainToken, newPassword string) (*models.User, error) {
	return s.redeemToken(ctx, plainToken, models.TokenTypeResetPassword, func(ctx context.Context, userID uuid.UUID) error {
		hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
		if err != nil {
			return err
		}
		return s.db.WithContext(ctx).Model(&models.User{}).Where("id = ?", userID).Update("password_hash", string(hash)).Error
	})
}

// generateToken creates a random 32-byte token of tokenType for userID,
// persists its SHA-256 hash with an expiry of now+ttl, and returns the
// hex-encoded plaintext. The plaintext itself is never written to the
// database.
func (s *Service) generateToken(ctx context.Context, userID uuid.UUID, tokenType models.TokenType, ttl time.Duration) (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("read random bytes: %w", err)
	}
	plainToken := hex.EncodeToString(raw[:])

	emailToken := models.EmailToken{
		UserID:    userID,
		TokenHash: hashToken(plainToken),
		TokenType: tokenType,
		ExpiresAt: time.Now().Add(ttl),
	}
	if err := s.db.WithContext(ctx).Create(&emailToken).Error; err != nil {
		return "", fmt.Errorf("create token: %w", err)
	}
	return plainToken, nil
}

// redeemToken consumes a single-use token and applies action to its owner.
//
// It looks up an unexpired, unused token of tokenType whose hash matches
// plainToken, claims it by setting used_at, runs action for the token's
// user, and returns that user reloaded from the database.
//
// The claim is a conditional UPDATE (... AND used_at IS NULL) whose affected
// row count is checked. Two concurrent redemptions of the same token
// therefore cannot both run action. If action fails, the claim is released
// (best effort) so the token can be retried, as it could be when a failed
// action left the token unmarked.
//
// A missing, expired, already-used, or concurrently-claimed token yields the
// error "invalid or expired token".
func (s *Service) redeemToken(ctx context.Context, plainToken string, tokenType models.TokenType, action func(context.Context, uuid.UUID) error) (*models.User, error) {
	now := time.Now()

	var token models.EmailToken
	err := s.db.WithContext(ctx).
		Where("token_hash = ? AND token_type = ? AND expires_at > ? AND used_at IS NULL", hashToken(plainToken), tokenType, now).
		First(&token).Error
	if err != nil {
		// NOTE(review): this also masks database failures as an invalid
		// token; distinguishing them needs gorm.ErrRecordNotFound.
		return nil, errors.New("invalid or expired token")
	}

	// Claim the token. GORM scopes an update on a loaded model to its primary
	// key; the extra used_at IS NULL condition turns this into a
	// compare-and-set, so only one concurrent caller can win.
	claim := s.db.WithContext(ctx).Model(&token).Where("used_at IS NULL").Update("used_at", &now)
	if claim.Error != nil {
		return nil, fmt.Errorf("claim token: %w", claim.Error)
	}
	if claim.RowsAffected == 0 {
		// Another request redeemed the token between our lookup and claim.
		return nil, errors.New("invalid or expired token")
	}

	if err := action(ctx, token.UserID); err != nil {
		// Release the claim so the user can retry with the same token. If the
		// release itself fails the token simply stays consumed.
		if rerr := s.db.WithContext(ctx).Model(&token).Update("used_at", nil).Error; rerr != nil {
			log.Error().Err(rerr).Msg("release email token claim")
		}
		return nil, err
	}

	var user models.User
	if err := s.db.WithContext(ctx).First(&user, "id = ?", token.UserID).Error; err != nil {
		return nil, fmt.Errorf("load user: %w", err)
	}
	return &user, nil
}

// CleanupExpiredTokens deletes expired and already-used email tokens once
// per interval. It blocks until ctx is done. A failed deletion is logged and
// retried on the next tick.
func (s *Service) CleanupExpiredTokens(ctx context.Context, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			result := s.db.WithContext(ctx).Where("expires_at <= ? OR used_at IS NOT NULL", time.Now()).Delete(&models.EmailToken{})
			// A failure caused by ctx being cancelled mid-delete is expected
			// during shutdown and not worth logging.
			if result.Error != nil && ctx.Err() == nil {
				log.Error().Err(result.Error).Msg("email token cleanup error")
			}
		}
	}
}

// hashToken returns the hex-encoded SHA-256 digest of plainToken. This is
// the form in which tokens are stored and looked up.
func hashToken(plainToken string) string {
	sum := sha256.Sum256([]byte(plainToken))
	return hex.EncodeToString(sum[:])
}