// Source: forgejo-tickets/internal/handlers/public/oauth.go
// (File-viewer metadata: 225 lines, 6.4 KiB, Go.)

package public
import (
"crypto/rand"
"encoding/hex"
"net/http"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"github.com/mattnite/forgejo-tickets/internal/auth"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
)
// OAuthHandler groups the HTTP handlers for OAuth sign-in: Login, Callback
// and AppleCallback.
type OAuthHandler struct {
// deps supplies the Config, SessionStore and Auth values the handlers use.
deps Dependencies
}
// getProvider builds the OAuth provider registered under providerName.
//
// Only "google" and "microsoft" are handled here. It returns nil when the
// name is not recognised or when the matching client ID is empty in the
// configuration. The redirect URL passed to the provider is the configured
// BaseURL followed by "/auth/<provider>/callback".
func (h *OAuthHandler) getProvider(providerName string) *auth.OAuthProvider {
	conf := h.deps.Config

	if providerName == "google" && conf.GoogleClientID != "" {
		return auth.NewGoogleProvider(
			conf.GoogleClientID,
			conf.GoogleClientSecret,
			conf.BaseURL+"/auth/google/callback",
		)
	}

	if providerName == "microsoft" && conf.MicrosoftClientID != "" {
		return auth.NewMicrosoftProvider(
			conf.MicrosoftClientID,
			conf.MicrosoftClientSecret,
			conf.MicrosoftTenantID,
			conf.BaseURL+"/auth/microsoft/callback",
		)
	}

	// Unknown provider name, or a known provider without a client ID.
	return nil
}
// Login starts the OAuth authorization-code flow for the provider named in
// the "provider" route parameter.
//
// "apple" is delegated to appleLogin. For any other provider a random state
// value is stored in the "oauth_state" session and the client is redirected
// (307) to the provider's authorization URL.
//
// Responses on failure:
//   - 400 when getProvider returns nil (unknown or unconfigured provider)
//   - 500 when the state cannot be saved to the session
func (h *OAuthHandler) Login(c *gin.Context) {
	providerName := c.Param("provider")

	// Handle Apple separately: it uses its own provider constructor and an
	// extra auth-code option.
	if providerName == "apple" {
		h.appleLogin(c)
		return
	}

	provider := h.getProvider(providerName)
	if provider == nil {
		c.String(http.StatusBadRequest, "Unknown provider")
		return
	}

	state := generateState()

	// The error from Get is intentionally ignored: presumably the store still
	// hands back a usable (fresh) session when the existing cookie cannot be
	// decoded — TODO confirm against the SessionStore implementation.
	session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state")
	session.Values["state"] = state
	session.Values["user_id"] = "00000000-0000-0000-0000-000000000000" // placeholder for save
	if err := session.Save(c.Request, c.Writer); err != nil {
		// Without a persisted state the callback can never pass its state
		// check, so stop here instead of sending the user to the provider.
		log.Error().Err(err).Msg("save oauth state error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	// Named authURL (not url) so the net/url package is not shadowed.
	authURL := provider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline)
	c.Redirect(http.StatusTemporaryRedirect, authURL)
}
// Callback completes the OAuth authorization-code flow for the provider named
// in the "provider" route parameter.
//
// Steps: verify the "state" query parameter against the value stored in the
// "oauth_state" session, exchange the "code" query parameter for a token,
// fetch the user info, find or create the matching user, create a login
// session and redirect (303) to "/tickets".
//
// Responses on failure:
//   - 400 for an unknown/unconfigured provider or a missing/mismatched state
//   - 303 to "/login" with a flash message when the account is pending
//     admin approval
//   - 500 for token exchange, user info, user lookup or session errors
func (h *OAuthHandler) Callback(c *gin.Context) {
	providerName := c.Param("provider")
	provider := h.getProvider(providerName)
	if provider == nil {
		c.String(http.StatusBadRequest, "Unknown provider")
		return
	}

	// Verify state. The error from Get is intentionally ignored; a session
	// without a stored state is rejected just below.
	session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state")
	expectedState, _ := session.Values["state"].(string)
	// An empty expected state means no login was started from this browser.
	// It must be rejected explicitly, otherwise a request carrying no state
	// parameter would compare "" == "" and skip the CSRF check.
	if expectedState == "" || c.Query("state") != expectedState {
		c.String(http.StatusBadRequest, "Invalid state parameter")
		return
	}

	ctx := c.Request.Context()

	code := c.Query("code")
	token, err := provider.Config.Exchange(ctx, code)
	if err != nil {
		log.Error().Err(err).Msg("oauth exchange error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	info, err := provider.UserInfo(ctx, token)
	if err != nil {
		log.Error().Err(err).Msg("oauth userinfo error")
		c.String(http.StatusInternalServerError, "Failed to get user info")
		return
	}

	user, err := h.deps.Auth.FindOrCreateOAuthUser(ctx, provider.Name, info)
	if err != nil {
		// NOTE(review): matching on the error text is fragile; if the auth
		// package exposes a sentinel error for this case, prefer errors.Is.
		if strings.Contains(err.Error(), "pending admin approval") {
			redirectURL := "/login?" + url.Values{
				"flash":      {err.Error()},
				"flash_type": {"info"},
			}.Encode()
			c.Redirect(http.StatusSeeOther, redirectURL)
			return
		}
		log.Error().Err(err).Msg("find or create oauth user error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil {
		log.Error().Err(err).Msg("create session error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	c.Redirect(http.StatusSeeOther, "/tickets")
}
// appleLogin starts the Sign in with Apple flow.
//
// It builds the Apple provider from the configuration, stores a random state
// value in the "oauth_state" session and redirects (307) to Apple's
// authorization URL.
//
// Responses on failure:
//   - 400 when no Apple client ID is configured
//   - 500 when the provider cannot be created or the state cannot be saved
func (h *OAuthHandler) appleLogin(c *gin.Context) {
	cfg := h.deps.Config
	if cfg.AppleClientID == "" {
		c.String(http.StatusBadRequest, "Apple Sign In not configured")
		return
	}

	appleProvider, err := auth.NewAppleProvider(
		cfg.AppleClientID, cfg.AppleTeamID, cfg.AppleKeyID, cfg.AppleKeyPath,
		cfg.BaseURL+"/auth/apple/callback",
	)
	if err != nil {
		log.Error().Err(err).Msg("create apple provider error")
		c.String(http.StatusInternalServerError, "Apple Sign In not available")
		return
	}

	state := generateState()

	// The error from Get is intentionally ignored: presumably the store still
	// hands back a usable (fresh) session when the existing cookie cannot be
	// decoded — TODO confirm against the SessionStore implementation.
	session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state")
	session.Values["state"] = state
	session.Values["user_id"] = "00000000-0000-0000-0000-000000000000" // placeholder for save
	if err := session.Save(c.Request, c.Writer); err != nil {
		// Without a persisted state the callback can never pass its state
		// check, so stop here instead of sending the user to Apple.
		log.Error().Err(err).Msg("save oauth state error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	// Named authURL (not url) so the net/url package is not shadowed.
	authURL := appleProvider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, auth.AppleAuthCodeOption())
	c.Redirect(http.StatusTemporaryRedirect, authURL)
}
// AppleCallback completes the Sign in with Apple flow.
//
// Apple delivers its response as a form POST, so "code", "state" and the
// optional "user" payload are read from the posted form. The handler verifies
// the state against the "oauth_state" session, exchanges the code, loads the
// user info, finds or creates the user, creates a login session and redirects
// (303) to "/tickets".
//
// Responses on failure:
//   - 400 when Apple is not configured or the state is missing/mismatched
//   - 303 to "/login" with a flash message when the account is pending
//     admin approval
//   - 500 for provider creation, code exchange, user info, user lookup or
//     session errors
func (h *OAuthHandler) AppleCallback(c *gin.Context) {
	cfg := h.deps.Config
	// Same guard as appleLogin, so an unconfigured instance answers
	// consistently on both endpoints.
	if cfg.AppleClientID == "" {
		c.String(http.StatusBadRequest, "Apple Sign In not configured")
		return
	}

	appleProvider, err := auth.NewAppleProvider(
		cfg.AppleClientID, cfg.AppleTeamID, cfg.AppleKeyID, cfg.AppleKeyPath,
		cfg.BaseURL+"/auth/apple/callback",
	)
	if err != nil {
		log.Error().Err(err).Msg("create apple provider error")
		c.String(http.StatusInternalServerError, "Apple Sign In not available")
		return
	}

	// Apple uses form_post
	code := c.PostForm("code")
	state := c.PostForm("state")

	// The error from Get is intentionally ignored; a session without a stored
	// state is rejected just below.
	session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state")
	expectedState, _ := session.Values["state"].(string)
	// An empty expected state means no login was started from this browser.
	// It must be rejected explicitly, otherwise a request carrying no state
	// field would compare "" == "" and skip the CSRF check.
	if expectedState == "" || state != expectedState {
		c.String(http.StatusBadRequest, "Invalid state parameter")
		return
	}

	ctx := c.Request.Context()

	token, err := appleProvider.ExchangeCode(ctx, code)
	if err != nil {
		log.Error().Err(err).Msg("apple exchange error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	info, err := appleProvider.UserInfo(ctx, token)
	if err != nil {
		log.Error().Err(err).Msg("apple userinfo error")
		c.String(http.StatusInternalServerError, "Failed to get user info")
		return
	}

	// Apple may send user data in the form. Parsing is best-effort: a parse
	// failure leaves the name from UserInfo untouched.
	if userData := c.PostForm("user"); userData != "" {
		appleUser, parseErr := auth.ParseAppleUserData(userData)
		if parseErr == nil && appleUser != nil && appleUser.Name != nil {
			// Trim so a missing first or last name does not leave stray
			// spaces, and never overwrite the name with an empty value.
			fullName := strings.TrimSpace(appleUser.Name.FirstName + " " + appleUser.Name.LastName)
			if fullName != "" {
				info.Name = fullName
			}
		}
	}

	user, err := h.deps.Auth.FindOrCreateOAuthUser(ctx, "apple", info)
	if err != nil {
		// NOTE(review): matching on the error text is fragile; if the auth
		// package exposes a sentinel error for this case, prefer errors.Is.
		if strings.Contains(err.Error(), "pending admin approval") {
			redirectURL := "/login?" + url.Values{
				"flash":      {err.Error()},
				"flash_type": {"info"},
			}.Encode()
			c.Redirect(http.StatusSeeOther, redirectURL)
			return
		}
		log.Error().Err(err).Msg("find or create apple user error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil {
		log.Error().Err(err).Msg("create session error")
		c.String(http.StatusInternalServerError, "Authentication failed")
		return
	}

	c.Redirect(http.StatusSeeOther, "/tickets")
}
// generateState returns a fresh OAuth "state" value: 16 bytes read from
// crypto/rand, encoded as 32 lowercase hexadecimal characters.
//
// It panics if the random source reports an error. Continuing would return
// the hex encoding of an all-zero buffer, a predictable state that defeats
// the CSRF protection the value exists for.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("generate oauth state: " + err.Error())
	}
	return hex.EncodeToString(b)
}