// Source: forgejo-tickets/internal/auth/oauth.go (153 lines, 4.0 KiB, Go)

package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/mattnite/forgejo-tickets/internal/models"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"golang.org/x/oauth2/microsoft"
)
// OAuthProvider bundles the configuration and profile lookup needed to sign a
// user in through one external OAuth2 identity provider.
type OAuthProvider struct {
	// Name identifies the provider; the constructors in this file use
	// "google" and "microsoft".
	Name string
	// Config is the OAuth2 client configuration: client credentials,
	// redirect URL, requested scopes and the provider's endpoint.
	Config *oauth2.Config
	// UserInfo fetches the signed-in user's profile with token and maps it
	// onto an OAuthUserInfo.
	UserInfo func(ctx context.Context, token *oauth2.Token) (*OAuthUserInfo, error)
}
// OAuthUserInfo is the provider-independent subset of an external user
// profile that the auth service works with.
type OAuthUserInfo struct {
	// ProviderUserID is the user's identifier at the provider. Together with
	// the provider name it is the key used to find an already-linked account.
	ProviderUserID string
	// Email is the address reported by the provider. It is used to match an
	// existing local user and as the email of newly created users.
	Email string
	// Name is the display name reported by the provider. It is used as the
	// name of newly created users.
	Name string
}
// NewGoogleProvider returns an OAuthProvider configured for Google sign-in.
//
// The provider requests the "openid", "email" and "profile" scopes. Its
// UserInfo reads the user's profile from Google's OAuth2 v2 userinfo
// endpoint.
//
// UserInfo returns an error when:
//   - the endpoint answers with a non-2xx status,
//   - the profile has no user ID, or
//   - the email is missing or not reported as verified.
//
// The email is later used to link and verify local accounts, so an
// unverified address must not be trusted.
func NewGoogleProvider(clientID, clientSecret, redirectURL string) *OAuthProvider {
	return &OAuthProvider{
		Name: "google",
		Config: &oauth2.Config{
			ClientID:     clientID,
			ClientSecret: clientSecret,
			RedirectURL:  redirectURL,
			Scopes:       []string{"openid", "email", "profile"},
			Endpoint:     google.Endpoint,
		},
		UserInfo: func(ctx context.Context, token *oauth2.Token) (*OAuthUserInfo, error) {
			client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
			resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
			if err != nil {
				return nil, fmt.Errorf("google userinfo request: %w", err)
			}
			defer resp.Body.Close()

			// Cap the read (1 MiB) so a misbehaving endpoint cannot exhaust memory.
			body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
			if err != nil {
				return nil, fmt.Errorf("read google userinfo: %w", err)
			}
			// Without this check an error payload (e.g. 401 for an expired
			// token) would decode into an all-empty profile.
			if resp.StatusCode < 200 || resp.StatusCode > 299 {
				return nil, fmt.Errorf("google userinfo: unexpected status %d: %.200q", resp.StatusCode, body)
			}

			var data struct {
				ID            string `json:"id"`
				Email         string `json:"email"`
				VerifiedEmail bool   `json:"verified_email"`
				Name          string `json:"name"`
			}
			if err := json.Unmarshal(body, &data); err != nil {
				return nil, fmt.Errorf("decode google userinfo: %w", err)
			}
			if data.ID == "" {
				return nil, fmt.Errorf("google userinfo: response has no user id")
			}
			if data.Email == "" || !data.VerifiedEmail {
				return nil, fmt.Errorf("google userinfo: email %q is missing or not verified", data.Email)
			}
			return &OAuthUserInfo{
				ProviderUserID: data.ID,
				Email:          data.Email,
				Name:           data.Name,
			}, nil
		},
	}
}
// NewMicrosoftProvider returns an OAuthProvider configured for Microsoft
// (Azure AD / Entra ID) sign-in against the given tenant.
//
// The provider requests the "openid", "email", "profile" and "User.Read"
// scopes. Its UserInfo reads the profile from Microsoft Graph's /v1.0/me
// endpoint, mapping "id", "mail" and "displayName".
//
// UserInfo returns an error when:
//   - Graph answers with a non-2xx status,
//   - the profile has no id, or
//   - the profile has no mail address.
//
// An empty email must never reach account linking: it would match every
// other local user stored with an empty email.
//
// NOTE(review): Graph's "mail" is a directory attribute that tenant admins can
// set, and Microsoft does not assert that it is verified. Linking existing
// local accounts by it is the pattern behind the "nOAuth" account-takeover
// issue. Consider linking Microsoft logins only through ProviderUserID.
func NewMicrosoftProvider(clientID, clientSecret, tenantID, redirectURL string) *OAuthProvider {
	return &OAuthProvider{
		Name: "microsoft",
		Config: &oauth2.Config{
			ClientID:     clientID,
			ClientSecret: clientSecret,
			RedirectURL:  redirectURL,
			Scopes:       []string{"openid", "email", "profile", "User.Read"},
			Endpoint:     microsoft.AzureADEndpoint(tenantID),
		},
		UserInfo: func(ctx context.Context, token *oauth2.Token) (*OAuthUserInfo, error) {
			client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
			resp, err := client.Get("https://graph.microsoft.com/v1.0/me")
			if err != nil {
				return nil, fmt.Errorf("microsoft graph /me request: %w", err)
			}
			defer resp.Body.Close()

			// Cap the read (1 MiB) so a misbehaving endpoint cannot exhaust memory.
			body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
			if err != nil {
				return nil, fmt.Errorf("read microsoft graph /me: %w", err)
			}
			// Without this check a Graph error payload would decode into an
			// all-empty profile.
			if resp.StatusCode < 200 || resp.StatusCode > 299 {
				return nil, fmt.Errorf("microsoft graph /me: unexpected status %d: %.200q", resp.StatusCode, body)
			}

			var data struct {
				ID   string `json:"id"`
				Mail string `json:"mail"`
				Name string `json:"displayName"`
			}
			if err := json.Unmarshal(body, &data); err != nil {
				return nil, fmt.Errorf("decode microsoft graph /me: %w", err)
			}
			if data.ID == "" {
				return nil, fmt.Errorf("microsoft graph /me: response has no user id")
			}
			// "mail" is null for accounts without a mailbox; JSON null leaves
			// the field as "".
			if data.Mail == "" {
				return nil, fmt.Errorf("microsoft graph /me: account %q has no mail address", data.ID)
			}
			return &OAuthUserInfo{
				ProviderUserID: data.ID,
				Email:          data.Mail,
				Name:           data.Name,
			}, nil
		},
	}
}
// FindOrCreateOAuthUser resolves the local user for an OAuth login.
//
// Resolution order:
//  1. If an OAuthAccount already links (provider, info.ProviderUserID) to a
//     user, that user is returned unchanged.
//  2. Otherwise the user whose email equals info.Email is used. If none
//     exists, a new user is created with EmailVerified set. The user's email is
//     marked verified, and a new OAuthAccount linking the user to this provider
//     identity is created.
//
// It returns an error if info is nil, if ProviderUserID or Email is empty, or
// if any database operation fails. A failed query is not treated as
// "not found", so a transient DB error can no longer create duplicate users or
// links.
//
// NOTE(review): step 2 trusts info.Email as proof of owning an existing
// account. That is only safe for providers that verify email addresses. The
// steps also do not run in a transaction, so a failure after creating the user
// can leave it without a linked OAuth account.
func (s *Service) FindOrCreateOAuthUser(ctx context.Context, provider string, info *OAuthUserInfo) (*models.User, error) {
	if info == nil || info.ProviderUserID == "" {
		return nil, fmt.Errorf("oauth login via %q: missing provider user id", provider)
	}
	// An empty email would match every local user stored with an empty email.
	if info.Email == "" {
		return nil, fmt.Errorf("oauth login via %q: missing email", provider)
	}

	// Limit(1).Find reports "no rows" via RowsAffected instead of an error,
	// so genuine query failures can be told apart from a missing record.
	var oauthAccount models.OAuthAccount
	res := s.db.WithContext(ctx).
		Where("provider = ? AND provider_user_id = ?", provider, info.ProviderUserID).
		Limit(1).
		Find(&oauthAccount)
	if res.Error != nil {
		return nil, fmt.Errorf("find oauth account: %w", res.Error)
	}
	if res.RowsAffected > 0 {
		var user models.User
		if err := s.db.WithContext(ctx).First(&user, "id = ?", oauthAccount.UserID).Error; err != nil {
			return nil, fmt.Errorf("load user for oauth account: %w", err)
		}
		return &user, nil
	}

	var user models.User
	res = s.db.WithContext(ctx).Where("email = ?", info.Email).Limit(1).Find(&user)
	if res.Error != nil {
		return nil, fmt.Errorf("find user by email: %w", res.Error)
	}
	switch {
	case res.RowsAffected == 0:
		user = models.User{
			Email:         info.Email,
			Name:          info.Name,
			EmailVerified: true,
		}
		if err := s.db.WithContext(ctx).Create(&user).Error; err != nil {
			return nil, fmt.Errorf("create user: %w", err)
		}
	case !user.EmailVerified:
		// Verify before linking: if this fails nothing is linked yet, so a
		// retried login takes the same path instead of finding the link and
		// skipping verification.
		if err := s.db.WithContext(ctx).Model(&user).Update("email_verified", true).Error; err != nil {
			return nil, fmt.Errorf("mark email verified: %w", err)
		}
		user.EmailVerified = true
	}

	oauthAccount = models.OAuthAccount{
		UserID:         user.ID,
		Provider:       provider,
		ProviderUserID: info.ProviderUserID,
		Email:          info.Email,
	}
	if err := s.db.WithContext(ctx).Create(&oauthAccount).Error; err != nil {
		return nil, fmt.Errorf("create oauth account: %w", err)
	}
	return &user, nil
}