package public

import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/mattnite/forgejo-tickets/internal/auth"
	"github.com/mattnite/forgejo-tickets/internal/middleware"
	"github.com/rs/zerolog/log"
	"golang.org/x/oauth2"
)

const (
	// oauthStateCookie names the cookie that carries the OAuth "state" value
	// from the login redirect to the provider callback.
	oauthStateCookie = "oauth_state"
	// oauthStateMaxAge is the lifetime of the state cookie in seconds
	// (10 minutes); a login attempt must complete within this window.
	oauthStateMaxAge = 600
)

// OAuthHandler serves the third-party sign-in routes (Google, Microsoft and
// Apple). A provider whose client ID is empty in the configuration is treated
// as unavailable.
type OAuthHandler struct {
	deps Dependencies
}

// getProvider builds the OAuth provider for providerName ("google" or
// "microsoft") from the handler's configuration. It returns nil when the name
// is unknown or the provider has no client ID configured. Apple is handled
// separately by appleLogin and AppleCallback.
func (h *OAuthHandler) getProvider(providerName string) *auth.OAuthProvider {
	cfg := h.deps.Config
	baseURL := cfg.BaseURL
	switch providerName {
	case "google":
		if cfg.GoogleClientID == "" {
			return nil
		}
		return auth.NewGoogleProvider(cfg.GoogleClientID, cfg.GoogleClientSecret, baseURL+"/auth/google/callback")
	case "microsoft":
		if cfg.MicrosoftClientID == "" {
			return nil
		}
		return auth.NewMicrosoftProvider(cfg.MicrosoftClientID, cfg.MicrosoftClientSecret, cfg.MicrosoftTenantID, baseURL+"/auth/microsoft/callback")
	default:
		return nil
	}
}

// Login starts the OAuth authorization-code flow for the provider named by
// the ":provider" route parameter. It stores a fresh random state value in a
// short-lived cookie and redirects the browser to the provider's consent
// page. Apple is delegated to appleLogin because it needs a different cookie
// policy and extra authorization options.
func (h *OAuthHandler) Login(c *gin.Context) {
	providerName := c.Param("provider")

	if providerName == "apple" {
		h.appleLogin(c)
		return
	}

	provider := h.getProvider(providerName)
	if provider == nil {
		c.String(http.StatusBadRequest, "Unknown provider")
		return
	}

	state := generateState()
	// Callback reads code/state from the query string, i.e. the provider
	// returns via a top-level GET navigation, which Lax cookies accompany.
	h.setStateCookie(c, state, oauthStateMaxAge, http.SameSiteLaxMode)

	authURL := provider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline)
	c.Redirect(http.StatusTemporaryRedirect, authURL)
}

// Callback completes the authorization-code flow for Google and Microsoft.
// It verifies the state parameter against the cookie set by Login, exchanges
// the code for a token, fetches the user's profile, finds or creates the
// local user, starts a session and redirects to /tickets. Accounts still
// awaiting admin approval are sent back to /login with a flash message.
func (h *OAuthHandler) Callback(c *gin.Context) {
	providerName := c.Param("provider")
	provider := h.getProvider(providerName)
	if provider == nil {
		c.String(http.StatusBadRequest, "Unknown provider")
		return
	}

	if !h.consumeState(c, c.Query("state"), http.SameSiteLaxMode) {
		c.String(http.StatusBadRequest, "Invalid state parameter")
		return
	}

	// A denied or failed authorization comes back with an "error" parameter
	// instead of a code (RFC 6749 §4.1.2.1); there is nothing to exchange.
	code := c.Query("code")
	if code == "" {
		log.Warn().Str("provider", providerName).Str("error", c.Query("error")).Msg("oauth callback without code")
		c.String(http.StatusBadRequest, "Authentication failed")
		return
	}

	ctx := c.Request.Context()

	token, err := provider.Config.Exchange(ctx, code)
	if err != nil {
		log.Error().Err(err).Msg("oauth exchange error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	info, err := provider.UserInfo(ctx, token)
	if err != nil {
		log.Error().Err(err).Msg("oauth userinfo error")
		c.String(http.StatusInternalServerError, "Failed to get user info")
		return
	}

	user, err := h.deps.Auth.FindOrCreateOAuthUser(ctx, provider.Name, info)
	if err != nil {
		h.respondUserError(c, err, "find or create oauth user error")
		return
	}

	if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil {
		log.Error().Err(err).Msg("create session error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	c.Redirect(http.StatusSeeOther, "/tickets")
}

// appleLogin starts Sign in with Apple. Apple posts the result back to
// AppleCallback as a cross-site form submission (form_post), so the state
// cookie is SameSite=None; browsers only accept such cookies when they are
// also Secure, i.e. when BaseURL is HTTPS.
func (h *OAuthHandler) appleLogin(c *gin.Context) {
	cfg := h.deps.Config
	if cfg.AppleClientID == "" {
		c.String(http.StatusBadRequest, "Apple Sign In not configured")
		return
	}

	appleProvider, err := auth.NewAppleProvider(
		cfg.AppleClientID,
		cfg.AppleTeamID,
		cfg.AppleKeyID,
		cfg.AppleKeyPath,
		cfg.BaseURL+"/auth/apple/callback",
	)
	if err != nil {
		log.Error().Err(err).Msg("create apple provider error")
		c.String(http.StatusInternalServerError, "Apple Sign In not available")
		return
	}

	state := generateState()
	h.setStateCookie(c, state, oauthStateMaxAge, http.SameSiteNoneMode)

	authURL := appleProvider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, auth.AppleAuthCodeOption())
	c.Redirect(http.StatusTemporaryRedirect, authURL)
}

// AppleCallback completes Sign in with Apple. Apple delivers code, state and
// (optionally) a JSON "user" payload as POST form fields. The flow otherwise
// mirrors Callback: verify state, exchange the code, fetch the profile, find
// or create the local user, start a session and redirect to /tickets.
func (h *OAuthHandler) AppleCallback(c *gin.Context) {
	cfg := h.deps.Config
	if cfg.AppleClientID == "" {
		c.String(http.StatusBadRequest, "Apple Sign In not configured")
		return
	}

	appleProvider, err := auth.NewAppleProvider(
		cfg.AppleClientID,
		cfg.AppleTeamID,
		cfg.AppleKeyID,
		cfg.AppleKeyPath,
		cfg.BaseURL+"/auth/apple/callback",
	)
	if err != nil {
		log.Error().Err(err).Msg("create apple provider error")
		c.String(http.StatusInternalServerError, "Apple Sign In not available")
		return
	}

	if !h.consumeState(c, c.PostForm("state"), http.SameSiteNoneMode) {
		c.String(http.StatusBadRequest, "Invalid state parameter")
		return
	}

	// As with other providers, a cancelled or failed authorization arrives
	// with an "error" field instead of a code.
	code := c.PostForm("code")
	if code == "" {
		log.Warn().Str("provider", "apple").Str("error", c.PostForm("error")).Msg("oauth callback without code")
		c.String(http.StatusBadRequest, "Authentication failed")
		return
	}

	ctx := c.Request.Context()

	token, err := appleProvider.ExchangeCode(ctx, code)
	if err != nil {
		log.Error().Err(err).Msg("apple exchange error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	info, err := appleProvider.UserInfo(ctx, token)
	if err != nil {
		log.Error().Err(err).Msg("apple userinfo error")
		c.String(http.StatusInternalServerError, "Failed to get user info")
		return
	}

	// Apple may send the user's name in the form; it is optional, so a
	// malformed payload is logged but never blocks sign-in.
	if userData := c.PostForm("user"); userData != "" {
		appleUser, err := auth.ParseAppleUserData(userData)
		switch {
		case err != nil:
			log.Warn().Err(err).Msg("apple user data parse error")
		case appleUser != nil && appleUser.Name != nil:
			// Either part may be blank; avoid stray leading/trailing spaces.
			name := strings.TrimSpace(appleUser.Name.FirstName + " " + appleUser.Name.LastName)
			if name != "" {
				info.Name = name
			}
		}
	}

	user, err := h.deps.Auth.FindOrCreateOAuthUser(ctx, "apple", info)
	if err != nil {
		h.respondUserError(c, err, "find or create apple user error")
		return
	}

	if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil {
		log.Error().Err(err).Msg("create session error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	c.Redirect(http.StatusSeeOther, "/tickets")
}

// setStateCookie writes the OAuth state cookie with the given value, max age
// (seconds; negative deletes the cookie) and SameSite policy. The cookie is
// HttpOnly, scoped to "/", and Secure whenever BaseURL starts with "https".
func (h *OAuthHandler) setStateCookie(c *gin.Context, value string, maxAge int, sameSite http.SameSite) {
	http.SetCookie(c.Writer, &http.Cookie{
		Name:     oauthStateCookie,
		Value:    value,
		Path:     "/",
		MaxAge:   maxAge,
		HttpOnly: true,
		Secure:   strings.HasPrefix(h.deps.Config.BaseURL, "https"),
		SameSite: sameSite,
	})
}

// consumeState reports whether got matches the value of the state cookie.
// An empty state never matches. On success the cookie is expired (using the
// same SameSite policy it was set with) so the value cannot be replayed.
func (h *OAuthHandler) consumeState(c *gin.Context, got string, sameSite http.SameSite) bool {
	cookie, err := c.Request.Cookie(oauthStateCookie)
	if err != nil || got == "" {
		return false
	}
	// Constant-time comparison so the check does not leak how much of the
	// secret state value a guess got right.
	if subtle.ConstantTimeCompare([]byte(got), []byte(cookie.Value)) != 1 {
		return false
	}
	h.setStateCookie(c, "", -1, sameSite)
	return true
}

// respondUserError handles a FindOrCreateOAuthUser failure. Accounts awaiting
// admin approval are redirected to /login with the error as an "info" flash
// message; any other error is logged with logMsg and reported as a generic
// authentication failure.
func (h *OAuthHandler) respondUserError(c *gin.Context, err error, logMsg string) {
	// NOTE(review): matching on the message text is fragile; prefer a sentinel
	// error from the auth package checked with errors.Is if one is available.
	if strings.Contains(err.Error(), "pending admin approval") {
		middleware.SetFlash(c, "info", err.Error())
		c.Redirect(http.StatusSeeOther, "/login")
		return
	}
	log.Error().Err(err).Msg(logMsg)
	c.String(http.StatusInternalServerError, "Authentication failed")
}

// generateState returns 16 cryptographically random bytes, hex-encoded, for
// use as an OAuth state value. It panics if the system random source fails:
// continuing with a predictable state would silently disable CSRF protection.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("oauth: reading crypto/rand for state: " + err.Error())
	}
	return hex.EncodeToString(b)
}