225 lines
6.4 KiB
Go
225 lines
6.4 KiB
Go
package public
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/mattnite/forgejo-tickets/internal/auth"
|
|
"github.com/rs/zerolog/log"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
type OAuthHandler struct {
|
|
deps Dependencies
|
|
}
|
|
|
|
func (h *OAuthHandler) getProvider(providerName string) *auth.OAuthProvider {
|
|
cfg := h.deps.Config
|
|
baseURL := cfg.BaseURL
|
|
|
|
switch providerName {
|
|
case "google":
|
|
if cfg.GoogleClientID == "" {
|
|
return nil
|
|
}
|
|
return auth.NewGoogleProvider(cfg.GoogleClientID, cfg.GoogleClientSecret, baseURL+"/auth/google/callback")
|
|
case "microsoft":
|
|
if cfg.MicrosoftClientID == "" {
|
|
return nil
|
|
}
|
|
return auth.NewMicrosoftProvider(cfg.MicrosoftClientID, cfg.MicrosoftClientSecret, cfg.MicrosoftTenantID, baseURL+"/auth/microsoft/callback")
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (h *OAuthHandler) Login(c *gin.Context) {
|
|
providerName := c.Param("provider")
|
|
|
|
// Handle Apple separately
|
|
if providerName == "apple" {
|
|
h.appleLogin(c)
|
|
return
|
|
}
|
|
|
|
provider := h.getProvider(providerName)
|
|
if provider == nil {
|
|
c.String(http.StatusBadRequest, "Unknown provider")
|
|
return
|
|
}
|
|
|
|
state := generateState()
|
|
session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state")
|
|
session.Values["state"] = state
|
|
session.Values["user_id"] = "00000000-0000-0000-0000-000000000000" // placeholder for save
|
|
if err := session.Save(c.Request, c.Writer); err != nil {
|
|
log.Error().Err(err).Msg("save oauth state error")
|
|
}
|
|
|
|
url := provider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
|
c.Redirect(http.StatusTemporaryRedirect, url)
|
|
}
|
|
|
|
func (h *OAuthHandler) Callback(c *gin.Context) {
|
|
providerName := c.Param("provider")
|
|
provider := h.getProvider(providerName)
|
|
if provider == nil {
|
|
c.String(http.StatusBadRequest, "Unknown provider")
|
|
return
|
|
}
|
|
|
|
// Verify state
|
|
session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state")
|
|
expectedState, _ := session.Values["state"].(string)
|
|
if c.Query("state") != expectedState {
|
|
c.String(http.StatusBadRequest, "Invalid state parameter")
|
|
return
|
|
}
|
|
|
|
code := c.Query("code")
|
|
token, err := provider.Config.Exchange(c.Request.Context(), code)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("oauth exchange error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
info, err := provider.UserInfo(c.Request.Context(), token)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("oauth userinfo error")
|
|
c.String(http.StatusInternalServerError, "Failed to get user info")
|
|
return
|
|
}
|
|
|
|
user, err := h.deps.Auth.FindOrCreateOAuthUser(c.Request.Context(), provider.Name, info)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "pending admin approval") {
|
|
redirectURL := "/login?" + url.Values{
|
|
"flash": {err.Error()},
|
|
"flash_type": {"info"},
|
|
}.Encode()
|
|
c.Redirect(http.StatusSeeOther, redirectURL)
|
|
return
|
|
}
|
|
log.Error().Err(err).Msg("find or create oauth user error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil {
|
|
log.Error().Err(err).Msg("create session error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
c.Redirect(http.StatusSeeOther, "/tickets")
|
|
}
|
|
|
|
func (h *OAuthHandler) appleLogin(c *gin.Context) {
|
|
cfg := h.deps.Config
|
|
if cfg.AppleClientID == "" {
|
|
c.String(http.StatusBadRequest, "Apple Sign In not configured")
|
|
return
|
|
}
|
|
|
|
appleProvider, err := auth.NewAppleProvider(
|
|
cfg.AppleClientID, cfg.AppleTeamID, cfg.AppleKeyID, cfg.AppleKeyPath,
|
|
cfg.BaseURL+"/auth/apple/callback",
|
|
)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("create apple provider error")
|
|
c.String(http.StatusInternalServerError, "Apple Sign In not available")
|
|
return
|
|
}
|
|
|
|
state := generateState()
|
|
session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state")
|
|
session.Values["state"] = state
|
|
session.Values["user_id"] = "00000000-0000-0000-0000-000000000000"
|
|
if err := session.Save(c.Request, c.Writer); err != nil {
|
|
log.Error().Err(err).Msg("save oauth state error")
|
|
}
|
|
|
|
url := appleProvider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, auth.AppleAuthCodeOption())
|
|
c.Redirect(http.StatusTemporaryRedirect, url)
|
|
}
|
|
|
|
func (h *OAuthHandler) AppleCallback(c *gin.Context) {
|
|
cfg := h.deps.Config
|
|
appleProvider, err := auth.NewAppleProvider(
|
|
cfg.AppleClientID, cfg.AppleTeamID, cfg.AppleKeyID, cfg.AppleKeyPath,
|
|
cfg.BaseURL+"/auth/apple/callback",
|
|
)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("create apple provider error")
|
|
c.String(http.StatusInternalServerError, "Apple Sign In not available")
|
|
return
|
|
}
|
|
|
|
// Apple uses form_post
|
|
code := c.PostForm("code")
|
|
state := c.PostForm("state")
|
|
|
|
session, _ := h.deps.SessionStore.Get(c.Request, "oauth_state")
|
|
expectedState, _ := session.Values["state"].(string)
|
|
if state != expectedState {
|
|
c.String(http.StatusBadRequest, "Invalid state parameter")
|
|
return
|
|
}
|
|
|
|
token, err := appleProvider.ExchangeCode(c.Request.Context(), code)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("apple exchange error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
info, err := appleProvider.UserInfo(c.Request.Context(), token)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("apple userinfo error")
|
|
c.String(http.StatusInternalServerError, "Failed to get user info")
|
|
return
|
|
}
|
|
|
|
// Apple may send user data in the form
|
|
if userData := c.PostForm("user"); userData != "" {
|
|
appleUser, err := auth.ParseAppleUserData(userData)
|
|
if err == nil && appleUser != nil && appleUser.Name != nil {
|
|
info.Name = appleUser.Name.FirstName + " " + appleUser.Name.LastName
|
|
}
|
|
}
|
|
|
|
user, err := h.deps.Auth.FindOrCreateOAuthUser(c.Request.Context(), "apple", info)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "pending admin approval") {
|
|
redirectURL := "/login?" + url.Values{
|
|
"flash": {err.Error()},
|
|
"flash_type": {"info"},
|
|
}.Encode()
|
|
c.Redirect(http.StatusSeeOther, redirectURL)
|
|
return
|
|
}
|
|
log.Error().Err(err).Msg("find or create apple user error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil {
|
|
log.Error().Err(err).Msg("create session error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
c.Redirect(http.StatusSeeOther, "/tickets")
|
|
}
|
|
|
|
// generateState returns a fresh OAuth "state" value: 16 bytes from the
// cryptographic random source, hex-encoded into a 32-character lowercase
// string.
//
// It panics if the random source fails. Returning the zero-filled buffer
// instead would produce a constant, guessable state and defeat the CSRF
// protection the value exists for.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("oauth: generating state: " + err.Error())
	}
	return hex.EncodeToString(b)
}
|