154 lines
3.7 KiB
Go
154 lines
3.7 KiB
Go
package auth
|
|
|
|
import (
	"context"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"io"
	"net/url"
	"os"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"golang.org/x/oauth2"
)
|
|
|
|
type AppleProvider struct {
|
|
OAuthProvider
|
|
teamID string
|
|
keyID string
|
|
keyPath string
|
|
}
|
|
|
|
func NewAppleProvider(clientID, teamID, keyID, keyPath, redirectURL string) (*AppleProvider, error) {
|
|
p := &AppleProvider{
|
|
OAuthProvider: OAuthProvider{
|
|
Name: "apple",
|
|
Config: &oauth2.Config{
|
|
ClientID: clientID,
|
|
RedirectURL: redirectURL,
|
|
Scopes: []string{"name", "email"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://appleid.apple.com/auth/authorize",
|
|
TokenURL: "https://appleid.apple.com/auth/token",
|
|
},
|
|
},
|
|
},
|
|
teamID: teamID,
|
|
keyID: keyID,
|
|
keyPath: keyPath,
|
|
}
|
|
|
|
p.OAuthProvider.UserInfo = p.getUserInfo
|
|
return p, nil
|
|
}
|
|
|
|
func (p *AppleProvider) GenerateClientSecret() (string, error) {
|
|
keyData, err := os.ReadFile(p.keyPath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read apple key: %w", err)
|
|
}
|
|
|
|
block, _ := pem.Decode(keyData)
|
|
if block == nil {
|
|
return "", fmt.Errorf("failed to decode PEM block")
|
|
}
|
|
|
|
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return "", fmt.Errorf("parse private key: %w", err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := jwt.MapClaims{
|
|
"iss": p.teamID,
|
|
"iat": now.Unix(),
|
|
"exp": now.Add(5 * time.Minute).Unix(),
|
|
"aud": "https://appleid.apple.com",
|
|
"sub": p.Config.ClientID,
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
|
|
token.Header["kid"] = p.keyID
|
|
|
|
return token.SignedString(key)
|
|
}
|
|
|
|
func (p *AppleProvider) getUserInfo(ctx context.Context, token *oauth2.Token) (*OAuthUserInfo, error) {
|
|
idToken, ok := token.Extra("id_token").(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("missing id_token")
|
|
}
|
|
|
|
// Parse without verification since we already got the token from Apple
|
|
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
|
|
parsed, _, err := parser.ParseUnverified(idToken, jwt.MapClaims{})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse id_token: %w", err)
|
|
}
|
|
|
|
claims, ok := parsed.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid claims")
|
|
}
|
|
|
|
sub, _ := claims["sub"].(string)
|
|
email, _ := claims["email"].(string)
|
|
|
|
return &OAuthUserInfo{
|
|
ProviderUserID: sub,
|
|
Email: email,
|
|
Name: email, // Apple may not provide name in id_token
|
|
}, nil
|
|
}
|
|
|
|
func (p *AppleProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) {
|
|
secret, err := p.GenerateClientSecret()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
p.Config.ClientSecret = secret
|
|
return p.Config.Exchange(ctx, code)
|
|
}
|
|
|
|
// AppleUserData is the user JSON carried by Apple's form_post callback.
// Apple only sends user data on the first authorization. Name is nil when
// the payload has no "name" object.
type AppleUserData struct {
	Name *struct {
		FirstName string `json:"firstName"`
		LastName  string `json:"lastName"`
	} `json:"name"`
	Email string `json:"email"`
}

// ParseAppleUserData decodes the user JSON from Apple's form_post callback.
//
// An empty data string returns (nil, nil) rather than an error, because the
// field is legitimately absent after the first authorization. Callers must
// nil-check the result. Malformed JSON returns a wrapped decoding error.
func ParseAppleUserData(data string) (*AppleUserData, error) {
	if data == "" {
		return nil, nil
	}
	var ud AppleUserData
	if err := json.Unmarshal([]byte(data), &ud); err != nil {
		return nil, fmt.Errorf("parse apple user data: %w", err)
	}
	return &ud, nil
}
|
|
|
|
// Ensure the id_token response is read
|
|
func init() {
|
|
// Register a custom AuthCodeOption to request response_mode=form_post
|
|
_ = oauth2.SetAuthURLParam("response_mode", "form_post")
|
|
}
|
|
|
|
// AppleAuthCodeOption returns the extra auth URL param for form_post mode.
|
|
func AppleAuthCodeOption() oauth2.AuthCodeOption {
|
|
return oauth2.SetAuthURLParam("response_mode", "form_post")
|
|
}
|
|
|
|
// ReadAppleFormPost reads an Apple Sign In form_post callback body
// (application/x-www-form-urlencoded) and returns its "code", "state" and
// "user" fields.
//
// userData is the raw user JSON, suitable for ParseAppleUserData. It is
// empty when the field is absent.
//
// An error is returned in these cases:
//   - the body is nil, unreadable, larger than 1 MiB, or not a valid form
//     encoding;
//   - the callback carries an "error" field (e.g. when the user cancels);
//   - "code" is missing.
func ReadAppleFormPost(body io.Reader) (code, state, userData string, err error) {
	// The body is untrusted input; cap how much of it is buffered.
	const maxBody = 1 << 20

	if body == nil {
		return "", "", "", fmt.Errorf("read apple form_post: nil body")
	}
	raw, err := io.ReadAll(io.LimitReader(body, maxBody+1))
	if err != nil {
		return "", "", "", fmt.Errorf("read apple form_post: %w", err)
	}
	if len(raw) > maxBody {
		return "", "", "", fmt.Errorf("read apple form_post: body exceeds %d bytes", maxBody)
	}

	form, err := url.ParseQuery(string(raw))
	if err != nil {
		return "", "", "", fmt.Errorf("parse apple form_post: %w", err)
	}
	if e := form.Get("error"); e != "" {
		return "", "", "", fmt.Errorf("apple authorization error: %q", e)
	}

	code = form.Get("code")
	if code == "" {
		return "", "", "", fmt.Errorf("apple form_post: missing code")
	}
	return code, form.Get("state"), form.Get("user"), nil
}
|