208 lines
4.9 KiB
Go
208 lines
4.9 KiB
Go
package auth
|
|
|
|
import (
	"context"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"net/url"
	"os"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"golang.org/x/oauth2"
)
|
|
|
|
type AppleProvider struct {
|
|
OAuthProvider
|
|
teamID string
|
|
keyID string
|
|
keyPath string
|
|
}
|
|
|
|
func NewAppleProvider(clientID, teamID, keyID, keyPath, redirectURL string) (*AppleProvider, error) {
|
|
p := &AppleProvider{
|
|
OAuthProvider: OAuthProvider{
|
|
Name: "apple",
|
|
Config: &oauth2.Config{
|
|
ClientID: clientID,
|
|
RedirectURL: redirectURL,
|
|
Scopes: []string{"name", "email"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://appleid.apple.com/auth/authorize",
|
|
TokenURL: "https://appleid.apple.com/auth/token",
|
|
},
|
|
},
|
|
},
|
|
teamID: teamID,
|
|
keyID: keyID,
|
|
keyPath: keyPath,
|
|
}
|
|
|
|
p.OAuthProvider.UserInfo = p.getUserInfo
|
|
return p, nil
|
|
}
|
|
|
|
func (p *AppleProvider) GenerateClientSecret() (string, error) {
|
|
keyData, err := os.ReadFile(p.keyPath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read apple key: %w", err)
|
|
}
|
|
|
|
block, _ := pem.Decode(keyData)
|
|
if block == nil {
|
|
return "", fmt.Errorf("failed to decode PEM block")
|
|
}
|
|
|
|
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return "", fmt.Errorf("parse private key: %w", err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := jwt.MapClaims{
|
|
"iss": p.teamID,
|
|
"iat": now.Unix(),
|
|
"exp": now.Add(5 * time.Minute).Unix(),
|
|
"aud": "https://appleid.apple.com",
|
|
"sub": p.Config.ClientID,
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
|
|
token.Header["kid"] = p.keyID
|
|
|
|
return token.SignedString(key)
|
|
}
|
|
|
|
func (p *AppleProvider) getUserInfo(ctx context.Context, token *oauth2.Token) (*OAuthUserInfo, error) {
|
|
idToken, ok := token.Extra("id_token").(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("missing id_token")
|
|
}
|
|
|
|
// Fetch Apple's JWKS
|
|
resp, err := http.Get("https://appleid.apple.com/auth/keys")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("fetch apple JWKS: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var jwks struct {
|
|
Keys []struct {
|
|
Kty string `json:"kty"`
|
|
Kid string `json:"kid"`
|
|
Use string `json:"use"`
|
|
Alg string `json:"alg"`
|
|
N string `json:"n"`
|
|
E string `json:"e"`
|
|
} `json:"keys"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
|
return nil, fmt.Errorf("decode apple JWKS: %w", err)
|
|
}
|
|
|
|
// Parse and verify the token
|
|
parsed, err := jwt.Parse(idToken, func(t *jwt.Token) (interface{}, error) {
|
|
kid, ok := t.Header["kid"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("missing kid header")
|
|
}
|
|
|
|
for _, key := range jwks.Keys {
|
|
if key.Kid == kid {
|
|
// Decode RSA public key from JWK
|
|
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode key N: %w", err)
|
|
}
|
|
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode key E: %w", err)
|
|
}
|
|
|
|
e := 0
|
|
for _, b := range eBytes {
|
|
e = e*256 + int(b)
|
|
}
|
|
|
|
return &rsa.PublicKey{
|
|
N: new(big.Int).SetBytes(nBytes),
|
|
E: e,
|
|
}, nil
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("key %s not found in JWKS", kid)
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("verify id_token: %w", err)
|
|
}
|
|
|
|
claims, ok := parsed.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid claims")
|
|
}
|
|
|
|
sub, _ := claims["sub"].(string)
|
|
email, _ := claims["email"].(string)
|
|
|
|
return &OAuthUserInfo{
|
|
ProviderUserID: sub,
|
|
Email: email,
|
|
Name: email,
|
|
}, nil
|
|
}
|
|
|
|
func (p *AppleProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) {
|
|
secret, err := p.GenerateClientSecret()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
p.Config.ClientSecret = secret
|
|
return p.Config.Exchange(ctx, code)
|
|
}
|
|
|
|
// AppleUserData is the user JSON Apple includes in its form_post callback.
// Apple only sends it on the user's first authorization, so it should be
// persisted then. Name is nil when the JSON carries no "name" object.
type AppleUserData struct {
	Name *struct {
		FirstName string `json:"firstName"`
		LastName  string `json:"lastName"`
	} `json:"name"`
	Email string `json:"email"`
}

// ParseAppleUserData decodes the user JSON from Apple's form_post callback.
//
// It returns (nil, nil) for an empty string, which is the normal case for
// every authorization after the first. Malformed JSON yields a wrapped
// decoding error.
func ParseAppleUserData(data string) (*AppleUserData, error) {
	if data == "" {
		return nil, nil
	}

	var ud AppleUserData
	if err := json.Unmarshal([]byte(data), &ud); err != nil {
		return nil, fmt.Errorf("parse apple user data: %w", err)
	}
	return &ud, nil
}
|
|
|
|
// Ensure the id_token response is read
|
|
func init() {
|
|
// Register a custom AuthCodeOption to request response_mode=form_post
|
|
_ = oauth2.SetAuthURLParam("response_mode", "form_post")
|
|
}
|
|
|
|
// AppleAuthCodeOption returns the extra auth URL param for form_post mode.
|
|
func AppleAuthCodeOption() oauth2.AuthCodeOption {
|
|
return oauth2.SetAuthURLParam("response_mode", "form_post")
|
|
}
|
|
|
|
// ReadAppleFormPost parses the application/x-www-form-urlencoded body of an
// Apple Sign In form_post callback (see AppleAuthCodeOption).
//
// It returns:
//   - code: the authorization code;
//   - state: the state value echoed back by Apple;
//   - userData: the raw "user" JSON, which is empty after a user's first
//     authorization; decode it with ParseAppleUserData.
//
// If the callback carries an "error" field (e.g. the user cancelled), that
// is returned as an error and the other results are empty. A nil body, a
// read or parse failure, or a body over 1 MiB is also an error. Handlers
// that already hold an *http.Request may use r.FormValue instead.
func ReadAppleFormPost(body io.Reader) (code, state, userData string, err error) {
	if body == nil {
		return "", "", "", fmt.Errorf("read apple form_post: nil body")
	}

	const maxFormBytes = 1 << 20
	// Read one byte past the cap so an oversized body is detected rather
	// than silently truncated.
	raw, err := io.ReadAll(io.LimitReader(body, maxFormBytes+1))
	if err != nil {
		return "", "", "", fmt.Errorf("read apple form_post: %w", err)
	}
	if len(raw) > maxFormBytes {
		return "", "", "", fmt.Errorf("read apple form_post: body exceeds %d bytes", maxFormBytes)
	}

	form, err := url.ParseQuery(string(raw))
	if err != nil {
		return "", "", "", fmt.Errorf("parse apple form_post: %w", err)
	}
	if apple := form.Get("error"); apple != "" {
		return "", "", "", fmt.Errorf("apple authorization failed: %q", apple)
	}
	return form.Get("code"), form.Get("state"), form.Get("user"), nil
}
|