package public import ( "crypto/rand" "encoding/hex" "net/http" "net/url" "strings" "github.com/gin-gonic/gin" "github.com/mattnite/forgejo-tickets/internal/auth" "github.com/rs/zerolog/log" "golang.org/x/oauth2" ) type OAuthHandler struct { deps Dependencies } func (h *OAuthHandler) getProvider(providerName string) *auth.OAuthProvider { cfg := h.deps.Config baseURL := cfg.BaseURL switch providerName { case "google": if cfg.GoogleClientID == "" { return nil } return auth.NewGoogleProvider(cfg.GoogleClientID, cfg.GoogleClientSecret, baseURL+"/auth/google/callback") case "microsoft": if cfg.MicrosoftClientID == "" { return nil } return auth.NewMicrosoftProvider(cfg.MicrosoftClientID, cfg.MicrosoftClientSecret, cfg.MicrosoftTenantID, baseURL+"/auth/microsoft/callback") default: return nil } } func (h *OAuthHandler) Login(c *gin.Context) { providerName := c.Param("provider") // Handle Apple separately if providerName == "apple" { h.appleLogin(c) return } provider := h.getProvider(providerName) if provider == nil { c.String(http.StatusBadRequest, "Unknown provider") return } state := generateState() session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state") session.Values["state"] = state session.Values["user_id"] = "00000000-0000-0000-0000-000000000000" // placeholder for save if err := session.Save(c.Request, c.Writer); err != nil { log.Error().Err(err).Msg("save oauth state error") } url := provider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline) c.Redirect(http.StatusTemporaryRedirect, url) } func (h *OAuthHandler) Callback(c *gin.Context) { providerName := c.Param("provider") provider := h.getProvider(providerName) if provider == nil { c.String(http.StatusBadRequest, "Unknown provider") return } // Verify state session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state") expectedState, _ := session.Values["state"].(string) if c.Query("state") != expectedState { c.String(http.StatusBadRequest, "Invalid state parameter") return } code := c.Query("code") token, err := provider.Config.Exchange(c.Request.Context(), code) if err != nil { log.Error().Err(err).Msg("oauth exchange error") c.String(http.StatusInternalServerError, "Authentication failed") return } info, err := provider.UserInfo(c.Request.Context(), token) if err != nil { log.Error().Err(err).Msg("oauth userinfo error") c.String(http.StatusInternalServerError, "Failed to get user info") return } user, err := h.deps.Auth.FindOrCreateOAuthUser(c.Request.Context(), provider.Name, info) if err != nil { if strings.Contains(err.Error(), "pending admin approval") { redirectURL := "/login?" + url.Values{ "flash": {err.Error()}, "flash_type": {"info"}, }.Encode() c.Redirect(http.StatusSeeOther, redirectURL) return } log.Error().Err(err).Msg("find or create oauth user error") c.String(http.StatusInternalServerError, "Authentication failed") return } if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil { log.Error().Err(err).Msg("create session error") c.String(http.StatusInternalServerError, "Authentication failed") return } c.Redirect(http.StatusSeeOther, "/tickets") } func (h *OAuthHandler) appleLogin(c *gin.Context) { cfg := h.deps.Config if cfg.AppleClientID == "" { c.String(http.StatusBadRequest, "Apple Sign In not configured") return } appleProvider, err := auth.NewAppleProvider( cfg.AppleClientID, cfg.AppleTeamID, cfg.AppleKeyID, cfg.AppleKeyPath, cfg.BaseURL+"/auth/apple/callback", ) if err != nil { log.Error().Err(err).Msg("create apple provider error") c.String(http.StatusInternalServerError, "Apple Sign In not available") return } state := generateState() session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state") session.Values["state"] = state session.Values["user_id"] = "00000000-0000-0000-0000-000000000000" if err := session.Save(c.Request, c.Writer); err != nil { log.Error().Err(err).Msg("save oauth state error") } url := appleProvider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, auth.AppleAuthCodeOption()) c.Redirect(http.StatusTemporaryRedirect, url) } func (h *OAuthHandler) AppleCallback(c *gin.Context) { cfg := h.deps.Config appleProvider, err := auth.NewAppleProvider( cfg.AppleClientID, cfg.AppleTeamID, cfg.AppleKeyID, cfg.AppleKeyPath, cfg.BaseURL+"/auth/apple/callback", ) if err != nil { log.Error().Err(err).Msg("create apple provider error") c.String(http.StatusInternalServerError, "Apple Sign In not available") return } // Apple uses form_post code := c.PostForm("code") state := c.PostForm("state") session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state") expectedState, _ := session.Values["state"].(string) if state != expectedState { c.String(http.StatusBadRequest, "Invalid state parameter") return } token, err := appleProvider.ExchangeCode(c.Request.Context(), code) if err != nil { log.Error().Err(err).Msg("apple exchange error") c.String(http.StatusInternalServerError, "Authentication failed") return } info, err := appleProvider.UserInfo(c.Request.Context(), token) if err != nil { log.Error().Err(err).Msg("apple userinfo error") c.String(http.StatusInternalServerError, "Failed to get user info") return } // Apple may send user data in the form if userData := c.PostForm("user"); userData != "" { appleUser, err := auth.ParseAppleUserData(userData) if err == nil && appleUser != nil && appleUser.Name != nil { info.Name = appleUser.Name.FirstName + " " + appleUser.Name.LastName } } user, err := h.deps.Auth.FindOrCreateOAuthUser(c.Request.Context(), "apple", info) if err != nil { if strings.Contains(err.Error(), "pending admin approval") { redirectURL := "/login?" + url.Values{ "flash": {err.Error()}, "flash_type": {"info"}, }.Encode() c.Redirect(http.StatusSeeOther, redirectURL) return } log.Error().Err(err).Msg("find or create apple user error") c.String(http.StatusInternalServerError, "Authentication failed") return } if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil { log.Error().Err(err).Msg("create session error") c.String(http.StatusInternalServerError, "Authentication failed") return } c.Redirect(http.StatusSeeOther, "/tickets") } func generateState() string { b := make([]byte, 16) rand.Read(b) return hex.EncodeToString(b) }