Skip to content

Commit

Permalink
feat(auth): add support for providing custom certificate URL (googlea…
Browse files Browse the repository at this point in the history
  • Loading branch information
idhame authored Nov 25, 2024
1 parent 04738d8 commit ebf3657
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 71 deletions.
34 changes: 25 additions & 9 deletions auth/credentials/idtoken/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,19 @@ type jwk struct {

// Validator provides a way to validate Google ID Tokens
type Validator struct {
client *cachingClient
client *cachingClient
rsa256CertsURL string
es256CertsURL string
}

// ValidatorOptions provides a way to configure a [Validator].
type ValidatorOptions struct {
// Client used to make requests to the certs URL. Optional.
Client *http.Client
// Custom certs URL for RSA256 JWK to be used. Optional.
RSA256CertsURL string
// Custom certs URL for ES256 JWK to be used. Optional.
ES256CertsURL string
}

// NewValidator creates a Validator that uses the options provided to configure
Expand All @@ -85,7 +91,17 @@ func NewValidator(opts *ValidatorOptions) (*Validator, error) {
} else {
client = internal.DefaultClient()
}
return &Validator{client: newCachingClient(client)}, nil

rsa256CertsURL := googleSACertsURL
es256CertsURL := googleIAPCertsURL
if opts != nil && opts.RSA256CertsURL != "" {
rsa256CertsURL = opts.RSA256CertsURL
}
if opts != nil && opts.ES256CertsURL != "" {
es256CertsURL = opts.ES256CertsURL
}

return &Validator{client: newCachingClient(client), rsa256CertsURL: rsa256CertsURL, es256CertsURL: es256CertsURL}, nil
}

// Validate is used to validate the provided idToken with a known Google cert
Expand Down Expand Up @@ -137,11 +153,11 @@ func (v *Validator) validate(ctx context.Context, idToken string, audience strin
hashedContent := hashHeaderPayload(idToken)
switch header.Algorithm {
case jwt.HeaderAlgRSA256:
if err := v.validateRS256(ctx, header.KeyID, hashedContent, sig); err != nil {
if err := v.validateRS256(ctx, header.KeyID, hashedContent, sig, v.rsa256CertsURL); err != nil {
return nil, err
}
case "ES256":
if err := v.validateES256(ctx, header.KeyID, hashedContent, sig); err != nil {
case jwt.HeaderAlgES256:
if err := v.validateES256(ctx, header.KeyID, hashedContent, sig, v.es256CertsURL); err != nil {
return nil, err
}
default:
Expand All @@ -151,8 +167,8 @@ func (v *Validator) validate(ctx context.Context, idToken string, audience strin
return payload, nil
}

func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
certResp, err := v.client.getCert(ctx, googleSACertsURL)
func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte, certsURL string) error {
certResp, err := v.client.getCert(ctx, certsURL)
if err != nil {
return err
}
Expand All @@ -176,8 +192,8 @@ func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedConte
return rsa.VerifyPKCS1v15(pk, crypto.SHA256, hashedContent, sig)
}

func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
certResp, err := v.client.getCert(ctx, googleIAPCertsURL)
func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte, certsURL string) error {
certResp, err := v.client.getCert(ctx, certsURL)
if err != nil {
return err
}
Expand Down
164 changes: 102 additions & 62 deletions auth/credentials/idtoken/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,51 +49,70 @@ var (
func TestValidateRS256(t *testing.T) {
idToken, pk := createRS256JWT(t)
tests := []struct {
name string
keyID string
n *big.Int
e int
nowFunc func() time.Time
wantErr bool
name string
keyID string
certsURL string
n *big.Int
e int
nowFunc func() time.Time
wantErr bool
wantCertsURL string
}{
{
name: "works",
keyID: keyID,
n: pk.N,
e: pk.E,
nowFunc: beforeExp,
wantErr: false,
name: "works",
keyID: keyID,
n: pk.N,
e: pk.E,
nowFunc: beforeExp,
wantErr: false,
wantCertsURL: googleSACertsURL,
},
{
name: "no matching key",
keyID: "5678",
n: pk.N,
e: pk.E,
nowFunc: beforeExp,
wantErr: true,
name: "works with custom certs url",
keyID: keyID,
certsURL: "https://www.googleapis.com/service_accounts/v1/jwk/[email protected]",
n: pk.N,
e: pk.E,
nowFunc: beforeExp,
wantErr: false,
wantCertsURL: "https://www.googleapis.com/service_accounts/v1/jwk/[email protected]",
},
{
name: "sig does not match",
keyID: keyID,
n: new(big.Int).SetBytes([]byte("42")),
e: 42,
nowFunc: beforeExp,
wantErr: true,
name: "no matching key",
keyID: "5678",
n: pk.N,
e: pk.E,
nowFunc: beforeExp,
wantErr: true,
wantCertsURL: googleSACertsURL,
},
{
name: "token expired",
keyID: keyID,
n: pk.N,
e: pk.E,
nowFunc: afterExp,
wantErr: true,
name: "sig does not match",
keyID: keyID,
n: new(big.Int).SetBytes([]byte("42")),
e: 42,
nowFunc: beforeExp,
wantErr: true,
wantCertsURL: googleSACertsURL,
},
{
name: "token expired",
keyID: keyID,
n: pk.N,
e: pk.E,
nowFunc: afterExp,
wantErr: true,
wantCertsURL: googleSACertsURL,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{
Transport: RoundTripFn(func(req *http.Request) *http.Response {
if req.URL.String() != tt.wantCertsURL {
t.Fatalf("Invalid request uri, want %v got %v", tt.wantCertsURL, req.URL.String())
}
cr := certResponse{
Keys: []jwk{
{
Expand All @@ -119,7 +138,8 @@ func TestValidateRS256(t *testing.T) {
now = tt.nowFunc

v, err := NewValidator(&ValidatorOptions{
Client: client,
Client: client,
RSA256CertsURL: tt.certsURL,
})
if err != nil {
t.Fatalf("NewValidator(...) = %q, want nil", err)
Expand Down Expand Up @@ -162,50 +182,69 @@ func TestValidateRS256(t *testing.T) {
func TestValidateES256(t *testing.T) {
idToken, pk := createES256JWT(t)
tests := []struct {
name string
keyID string
x *big.Int
y *big.Int
nowFunc func() time.Time
wantErr bool
name string
keyID string
certsURL string
x *big.Int
y *big.Int
nowFunc func() time.Time
wantErr bool
wantCertsURL string
}{
{
name: "works",
keyID: keyID,
x: pk.X,
y: pk.Y,
nowFunc: beforeExp,
wantErr: false,
name: "works",
keyID: keyID,
x: pk.X,
y: pk.Y,
nowFunc: beforeExp,
wantErr: false,
wantCertsURL: googleIAPCertsURL,
},
{
name: "no matching key",
keyID: "5678",
x: pk.X,
y: pk.Y,
nowFunc: beforeExp,
wantErr: true,
name: "works with custom certs url",
keyID: keyID,
certsURL: "http://example.com",
x: pk.X,
y: pk.Y,
nowFunc: beforeExp,
wantErr: false,
wantCertsURL: "http://example.com",
},
{
name: "sig does not match",
keyID: keyID,
x: new(big.Int),
y: new(big.Int),
nowFunc: beforeExp,
wantErr: true,
name: "no matching key",
keyID: "5678",
x: pk.X,
y: pk.Y,
nowFunc: beforeExp,
wantErr: true,
wantCertsURL: googleIAPCertsURL,
},
{
name: "token expired",
keyID: keyID,
x: pk.X,
y: pk.Y,
nowFunc: afterExp,
wantErr: true,
name: "sig does not match",
keyID: keyID,
x: new(big.Int),
y: new(big.Int),
nowFunc: beforeExp,
wantErr: true,
wantCertsURL: googleIAPCertsURL,
},
{
name: "token expired",
keyID: keyID,
x: pk.X,
y: pk.Y,
nowFunc: afterExp,
wantErr: true,
wantCertsURL: googleIAPCertsURL,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{
Transport: RoundTripFn(func(req *http.Request) *http.Response {
if req.URL.String() != tt.wantCertsURL {
t.Fatalf("Invalid request uri, want %v got %v", tt.wantCertsURL, req.URL.String())
}
cr := certResponse{
Keys: []jwk{
{
Expand All @@ -231,7 +270,8 @@ func TestValidateES256(t *testing.T) {
now = tt.nowFunc

v, err := NewValidator(&ValidatorOptions{
Client: client,
Client: client,
ES256CertsURL: tt.certsURL,
})
if err != nil {
t.Fatalf("NewValidator(...) = %q, want nil", err)
Expand Down

0 comments on commit ebf3657

Please sign in to comment.