243 lines
6.5 KiB
Go
243 lines
6.5 KiB
Go
package public
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/mattnite/forgejo-tickets/internal/auth"
|
|
"github.com/rs/zerolog/log"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
type OAuthHandler struct {
|
|
deps Dependencies
|
|
}
|
|
|
|
func (h *OAuthHandler) getProvider(providerName string) *auth.OAuthProvider {
|
|
cfg := h.deps.Config
|
|
baseURL := cfg.BaseURL
|
|
|
|
switch providerName {
|
|
case "google":
|
|
if cfg.GoogleClientID == "" {
|
|
return nil
|
|
}
|
|
return auth.NewGoogleProvider(cfg.GoogleClientID, cfg.GoogleClientSecret, baseURL+"/auth/google/callback")
|
|
case "microsoft":
|
|
if cfg.MicrosoftClientID == "" {
|
|
return nil
|
|
}
|
|
return auth.NewMicrosoftProvider(cfg.MicrosoftClientID, cfg.MicrosoftClientSecret, cfg.MicrosoftTenantID, baseURL+"/auth/microsoft/callback")
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (h *OAuthHandler) Login(c *gin.Context) {
|
|
providerName := c.Param("provider")
|
|
|
|
// Handle Apple separately
|
|
if providerName == "apple" {
|
|
h.appleLogin(c)
|
|
return
|
|
}
|
|
|
|
provider := h.getProvider(providerName)
|
|
if provider == nil {
|
|
c.String(http.StatusBadRequest, "Unknown provider")
|
|
return
|
|
}
|
|
|
|
state := generateState()
|
|
isSecure := strings.HasPrefix(h.deps.Config.BaseURL, "https")
|
|
http.SetCookie(c.Writer, &http.Cookie{
|
|
Name: "oauth_state",
|
|
Value: state,
|
|
Path: "/",
|
|
MaxAge: 600, // 10 minutes
|
|
HttpOnly: true,
|
|
Secure: isSecure,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
url := provider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
|
c.Redirect(http.StatusTemporaryRedirect, url)
|
|
}
|
|
|
|
func (h *OAuthHandler) Callback(c *gin.Context) {
|
|
providerName := c.Param("provider")
|
|
provider := h.getProvider(providerName)
|
|
if provider == nil {
|
|
c.String(http.StatusBadRequest, "Unknown provider")
|
|
return
|
|
}
|
|
|
|
// Verify state
|
|
stateCookie, err := c.Request.Cookie("oauth_state")
|
|
if err != nil || c.Query("state") != stateCookie.Value {
|
|
c.String(http.StatusBadRequest, "Invalid state parameter")
|
|
return
|
|
}
|
|
// Clear the state cookie
|
|
http.SetCookie(c.Writer, &http.Cookie{
|
|
Name: "oauth_state",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
})
|
|
|
|
code := c.Query("code")
|
|
token, err := provider.Config.Exchange(c.Request.Context(), code)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("oauth exchange error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
info, err := provider.UserInfo(c.Request.Context(), token)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("oauth userinfo error")
|
|
c.String(http.StatusInternalServerError, "Failed to get user info")
|
|
return
|
|
}
|
|
|
|
user, err := h.deps.Auth.FindOrCreateOAuthUser(c.Request.Context(), provider.Name, info)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "pending admin approval") {
|
|
redirectURL := "/login?" + url.Values{
|
|
"flash": {err.Error()},
|
|
"flash_type": {"info"},
|
|
}.Encode()
|
|
c.Redirect(http.StatusSeeOther, redirectURL)
|
|
return
|
|
}
|
|
log.Error().Err(err).Msg("find or create oauth user error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil {
|
|
log.Error().Err(err).Msg("create session error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
c.Redirect(http.StatusSeeOther, "/tickets")
|
|
}
|
|
|
|
func (h *OAuthHandler) appleLogin(c *gin.Context) {
|
|
cfg := h.deps.Config
|
|
if cfg.AppleClientID == "" {
|
|
c.String(http.StatusBadRequest, "Apple Sign In not configured")
|
|
return
|
|
}
|
|
|
|
appleProvider, err := auth.NewAppleProvider(
|
|
cfg.AppleClientID, cfg.AppleTeamID, cfg.AppleKeyID, cfg.AppleKeyPath,
|
|
cfg.BaseURL+"/auth/apple/callback",
|
|
)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("create apple provider error")
|
|
c.String(http.StatusInternalServerError, "Apple Sign In not available")
|
|
return
|
|
}
|
|
|
|
state := generateState()
|
|
isSecure := strings.HasPrefix(h.deps.Config.BaseURL, "https")
|
|
http.SetCookie(c.Writer, &http.Cookie{
|
|
Name: "oauth_state",
|
|
Value: state,
|
|
Path: "/",
|
|
MaxAge: 600, // 10 minutes
|
|
HttpOnly: true,
|
|
Secure: isSecure,
|
|
SameSite: http.SameSiteNoneMode, // Apple uses form_post cross-origin
|
|
})
|
|
|
|
url := appleProvider.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, auth.AppleAuthCodeOption())
|
|
c.Redirect(http.StatusTemporaryRedirect, url)
|
|
}
|
|
|
|
func (h *OAuthHandler) AppleCallback(c *gin.Context) {
|
|
cfg := h.deps.Config
|
|
appleProvider, err := auth.NewAppleProvider(
|
|
cfg.AppleClientID, cfg.AppleTeamID, cfg.AppleKeyID, cfg.AppleKeyPath,
|
|
cfg.BaseURL+"/auth/apple/callback",
|
|
)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("create apple provider error")
|
|
c.String(http.StatusInternalServerError, "Apple Sign In not available")
|
|
return
|
|
}
|
|
|
|
// Apple uses form_post
|
|
code := c.PostForm("code")
|
|
state := c.PostForm("state")
|
|
|
|
stateCookie, err := c.Request.Cookie("oauth_state")
|
|
if err != nil || state != stateCookie.Value {
|
|
c.String(http.StatusBadRequest, "Invalid state parameter")
|
|
return
|
|
}
|
|
// Clear the state cookie
|
|
http.SetCookie(c.Writer, &http.Cookie{
|
|
Name: "oauth_state",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
})
|
|
|
|
token, err := appleProvider.ExchangeCode(c.Request.Context(), code)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("apple exchange error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
info, err := appleProvider.UserInfo(c.Request.Context(), token)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("apple userinfo error")
|
|
c.String(http.StatusInternalServerError, "Failed to get user info")
|
|
return
|
|
}
|
|
|
|
// Apple may send user data in the form
|
|
if userData := c.PostForm("user"); userData != "" {
|
|
appleUser, err := auth.ParseAppleUserData(userData)
|
|
if err == nil && appleUser != nil && appleUser.Name != nil {
|
|
info.Name = appleUser.Name.FirstName + " " + appleUser.Name.LastName
|
|
}
|
|
}
|
|
|
|
user, err := h.deps.Auth.FindOrCreateOAuthUser(c.Request.Context(), "apple", info)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "pending admin approval") {
|
|
redirectURL := "/login?" + url.Values{
|
|
"flash": {err.Error()},
|
|
"flash_type": {"info"},
|
|
}.Encode()
|
|
c.Redirect(http.StatusSeeOther, redirectURL)
|
|
return
|
|
}
|
|
log.Error().Err(err).Msg("find or create apple user error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
if err := h.deps.Auth.CreateSession(c.Request, c.Writer, user.ID); err != nil {
|
|
log.Error().Err(err).Msg("create session error")
|
|
c.String(http.StatusInternalServerError, "Authentication failed")
|
|
return
|
|
}
|
|
|
|
c.Redirect(http.StatusSeeOther, "/tickets")
|
|
}
|
|
|
|
// generateState returns a fresh OAuth "state" value: 16 bytes from
// crypto/rand, hex-encoded into 32 lowercase characters.
//
// It panics if the system random source fails. Ignoring that error would
// return a predictable (zero-filled) state and silently disable the CSRF
// protection the state cookie exists for. Since Go 1.24, crypto/rand.Read
// never returns an error, so this only matters on older toolchains.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("generateState: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
|