153 lines
4.0 KiB
Go
153 lines
4.0 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/mattnite/forgejo-tickets/internal/models"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google"
|
|
"golang.org/x/oauth2/microsoft"
|
|
)
|
|
|
|
type OAuthProvider struct {
|
|
Name string
|
|
Config *oauth2.Config
|
|
UserInfo func(ctx context.Context, token *oauth2.Token) (*OAuthUserInfo, error)
|
|
}
|
|
|
|
// OAuthUserInfo is the provider-neutral subset of a user profile that an
// OAuthProvider's UserInfo hook returns.
type OAuthUserInfo struct {
	// ProviderUserID is the user's identifier as reported by the provider
	// (Google userinfo "id", Microsoft Graph "id").
	ProviderUserID string
	// Email is the address reported by the provider (Google "email",
	// Microsoft Graph "mail").
	Email string
	// Name is the display name reported by the provider.
	Name string
}
|
|
|
|
func NewGoogleProvider(clientID, clientSecret, redirectURL string) *OAuthProvider {
|
|
return &OAuthProvider{
|
|
Name: "google",
|
|
Config: &oauth2.Config{
|
|
ClientID: clientID,
|
|
ClientSecret: clientSecret,
|
|
RedirectURL: redirectURL,
|
|
Scopes: []string{"openid", "email", "profile"},
|
|
Endpoint: google.Endpoint,
|
|
},
|
|
UserInfo: func(ctx context.Context, token *oauth2.Token) (*OAuthUserInfo, error) {
|
|
client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
|
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var data struct {
|
|
ID string `json:"id"`
|
|
Email string `json:"email"`
|
|
Name string `json:"name"`
|
|
}
|
|
if err := json.Unmarshal(body, &data); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &OAuthUserInfo{
|
|
ProviderUserID: data.ID,
|
|
Email: data.Email,
|
|
Name: data.Name,
|
|
}, nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func NewMicrosoftProvider(clientID, clientSecret, tenantID, redirectURL string) *OAuthProvider {
|
|
return &OAuthProvider{
|
|
Name: "microsoft",
|
|
Config: &oauth2.Config{
|
|
ClientID: clientID,
|
|
ClientSecret: clientSecret,
|
|
RedirectURL: redirectURL,
|
|
Scopes: []string{"openid", "email", "profile", "User.Read"},
|
|
Endpoint: microsoft.AzureADEndpoint(tenantID),
|
|
},
|
|
UserInfo: func(ctx context.Context, token *oauth2.Token) (*OAuthUserInfo, error) {
|
|
client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
|
resp, err := client.Get("https://graph.microsoft.com/v1.0/me")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var data struct {
|
|
ID string `json:"id"`
|
|
Mail string `json:"mail"`
|
|
Name string `json:"displayName"`
|
|
}
|
|
if err := json.Unmarshal(body, &data); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &OAuthUserInfo{
|
|
ProviderUserID: data.ID,
|
|
Email: data.Mail,
|
|
Name: data.Name,
|
|
}, nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *Service) FindOrCreateOAuthUser(ctx context.Context, provider string, info *OAuthUserInfo) (*models.User, error) {
|
|
// Try to find existing OAuth account
|
|
var oauthAccount models.OAuthAccount
|
|
if err := s.db.WithContext(ctx).Where("provider = ? AND provider_user_id = ?", provider, info.ProviderUserID).First(&oauthAccount).Error; err == nil {
|
|
var user models.User
|
|
if err := s.db.WithContext(ctx).First(&user, "id = ?", oauthAccount.UserID).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// Try to find existing user by email
|
|
var user models.User
|
|
if err := s.db.WithContext(ctx).Where("email = ?", info.Email).First(&user).Error; err != nil {
|
|
// Create new user
|
|
user = models.User{
|
|
Email: info.Email,
|
|
Name: info.Name,
|
|
EmailVerified: true,
|
|
}
|
|
if err := s.db.WithContext(ctx).Create(&user).Error; err != nil {
|
|
return nil, fmt.Errorf("create user: %w", err)
|
|
}
|
|
}
|
|
|
|
// Link OAuth account
|
|
oauthAccount = models.OAuthAccount{
|
|
UserID: user.ID,
|
|
Provider: provider,
|
|
ProviderUserID: info.ProviderUserID,
|
|
Email: info.Email,
|
|
}
|
|
if err := s.db.WithContext(ctx).Create(&oauthAccount).Error; err != nil {
|
|
return nil, fmt.Errorf("create oauth account: %w", err)
|
|
}
|
|
|
|
// Mark email as verified for OAuth users
|
|
if !user.EmailVerified {
|
|
s.db.WithContext(ctx).Model(&user).Update("email_verified", true)
|
|
user.EmailVerified = true
|
|
}
|
|
|
|
return &user, nil
|
|
}
|