Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple certs for Encryption #73

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 52 additions & 30 deletions decode_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,52 +98,70 @@ func xmlUnmarshalElement(el *etree.Element, obj interface{}) error {
return nil
}

func (sp *SAMLServiceProvider) getDecryptCert() (*tls.Certificate, error) {
func (sp *SAMLServiceProvider) getDecryptCert() ([]*tls.Certificate, error) {
if sp.SPKeyStore == nil {
return nil, fmt.Errorf("no decryption certs available")
}

//This is the tls.Certificate we'll use to decrypt any encrypted assertions
var decryptCert tls.Certificate
extractCert := func(keyStore dsig.X509KeyStore) (*tls.Certificate, error) {
//This is the tls.Certificate we'll use to decrypt any encrypted assertions
var decryptCert tls.Certificate

switch crt := sp.SPKeyStore.(type) {
case dsig.TLSCertKeyStore:
// Get the tls.Certificate directly if possible
decryptCert = tls.Certificate(crt)
switch crt := keyStore.(type) {
case dsig.TLSCertKeyStore:
// Get the tls.Certificate directly if possible
decryptCert = tls.Certificate(crt)

default:
default:

//Otherwise, construct one from the results of GetKeyPair
pk, cert, err := sp.SPKeyStore.GetKeyPair()
if err != nil {
return nil, fmt.Errorf("error getting keypair: %v", err)
//Otherwise, construct one from the results of GetKeyPair
pk, cert, err := keyStore.GetKeyPair()
if err != nil {
return nil, fmt.Errorf("error getting keypair: %v", err)
}

decryptCert = tls.Certificate{
Certificate: [][]byte{cert},
PrivateKey: pk,
}
}

decryptCert = tls.Certificate{
Certificate: [][]byte{cert},
PrivateKey: pk,
if sp.ValidateEncryptionCert {
// Check Validity period of certificate
if len(decryptCert.Certificate) < 1 || len(decryptCert.Certificate[0]) < 1 {
return nil, fmt.Errorf("empty decryption cert")
} else if cert, err := x509.ParseCertificate(decryptCert.Certificate[0]); err != nil {
return nil, fmt.Errorf("invalid x509 decryption cert: %v", err)
} else {
now := sp.Clock.Now()
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
return nil, fmt.Errorf("decryption cert is not valid at this time")
}
}
}

return &decryptCert, nil
}

if sp.ValidateEncryptionCert {
// Check Validity period of certificate
if len(decryptCert.Certificate) < 1 || len(decryptCert.Certificate[0]) < 1 {
return nil, fmt.Errorf("empty decryption cert")
} else if cert, err := x509.ParseCertificate(decryptCert.Certificate[0]); err != nil {
return nil, fmt.Errorf("invalid x509 decryption cert: %v", err)
} else {
now := sp.Clock.Now()
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
return nil, fmt.Errorf("decryption cert is not valid at this time")
}
var decryptionCerts []*tls.Certificate
availableKeyStores := []dsig.X509KeyStore{sp.SPKeyStore, sp.SPKeyStoreRotate}
for _, keyStore := range availableKeyStores {
if keyStore == nil {
continue
}
decryptionCert, err := extractCert(keyStore)
if err != nil {
return decryptionCerts, err
}
decryptionCerts = append(decryptionCerts, decryptionCert)
}

return &decryptCert, nil
return decryptionCerts, nil
}

func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error {
var decryptCert *tls.Certificate
var decryptCert []*tls.Certificate
var elementFound bool

decryptAssertion := func(ctx etreeutils.NSContext, encryptedElement *etree.Element) error {
if encryptedElement.Parent() != el {
Expand All @@ -161,7 +179,7 @@ func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error {
return fmt.Errorf("unable to unmarshal encrypted assertion: %v", err)
}

if decryptCert == nil {
if len(decryptCert) == 0 {
decryptCert, err = sp.getDecryptCert()
if err != nil {
return fmt.Errorf("unable to get decryption certificate: %v", err)
Expand All @@ -185,12 +203,16 @@ func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error {
}

el.AddChild(doc.Root())
elementFound = true
return nil
}

if err := etreeutils.NSFindIterate(el, SAMLAssertionNamespace, EncryptedAssertionTag, decryptAssertion); err != nil {
return err
} else {
if sp.RequireEncryptedAssertion && !elementFound {
return fmt.Errorf("encrypted assertion required, not found")
}
return nil
}
}
Expand Down Expand Up @@ -441,4 +463,4 @@ func (sp *SAMLServiceProvider) ValidateEncodedLogoutResponsePOST(encodedResponse
}

return decodedResponse, nil
}
}
14 changes: 7 additions & 7 deletions providertests/onelogin_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Copyright 2016 Russell Haering et al.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// https://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down Expand Up @@ -110,13 +110,13 @@ var oneLoginScenarioErrors = map[int]string{
// 38 - signed(Response(encrypted(signed(Assertion)))) - 08 wrong IDP signing cert, correct SP encryption cert
38: "error validating response: Could not verify certificate against trusted certs",
// 97 - Response(encrypted(Assertion)) - 99 wrong SP encryption cert
97: "error validating response: unable to decrypt encrypted assertion: cannot decrypt, error retrieving private key: key decryption attempted with mismatched cert, SP cert(cd:f6:7c:e9), assertion cert(42:99:58:b8)",
97: "error validating response: unable to decrypt encrypted assertion: cannot decrypt, error retrieving private key: matching cert not found, assertion cert(42:99:58:b8)",
// 46 - Response(encrypted(signed(Assertion))) - 06 wrong SP encryption cert, correct IDP signing cert
46: "error validating response: unable to decrypt encrypted assertion: cannot decrypt, error retrieving private key: key decryption attempted with mismatched cert, SP cert(cd:f6:7c:e9), assertion cert(42:99:58:b8)",
46: "error validating response: unable to decrypt encrypted assertion: cannot decrypt, error retrieving private key: matching cert not found, assertion cert(42:99:58:b8)",
// 47 - signed(Response(encrypted(Assertion))) - 07 wrong SP encryption cert, correct IDP signing cert
47: "error validating response: unable to decrypt encrypted assertion: cannot decrypt, error retrieving private key: key decryption attempted with mismatched cert, SP cert(cd:f6:7c:e9), assertion cert(42:99:58:b8)",
47: "error validating response: unable to decrypt encrypted assertion: cannot decrypt, error retrieving private key: matching cert not found, assertion cert(42:99:58:b8)",
// 48 - signed(Response(encrypted(signed(Assertion)))) - 08 wrong SP encryption cert, correct IDP signing cert
48: "error validating response: unable to decrypt encrypted assertion: cannot decrypt, error retrieving private key: key decryption attempted with mismatched cert, SP cert(cd:f6:7c:e9), assertion cert(42:99:58:b8)",
48: "error validating response: unable to decrypt encrypted assertion: cannot decrypt, error retrieving private key: matching cert not found, assertion cert(42:99:58:b8)",
// 85 - Response(Assertion) - 99 empty Response Destination (empty is ok, Destination is optional)
// Note: gosaml2 is correctly checking signature before contents
85: "error validating response: response and/or assertions must be signed",
Expand Down
26 changes: 14 additions & 12 deletions saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,20 @@ type SAMLServiceProvider struct {
// provider use specific authentication mechanisms. Leaving this unset will
// permit the identity provider to choose the auth method. To maximize compatibility
// with identity providers it is recommended to leave this unset.
RequestedAuthnContext *RequestedAuthnContext
AudienceURI string
IDPCertificateStore dsig.X509CertificateStore
SPKeyStore dsig.X509KeyStore // Required encryption key, default signing key
SPSigningKeyStore dsig.X509KeyStore // Optional signing key
NameIdFormat string
ValidateEncryptionCert bool
SkipSignatureValidation bool
AllowMissingAttributes bool
Clock *dsig.Clock
signingContextMu sync.RWMutex
signingContext *dsig.SigningContext
RequestedAuthnContext *RequestedAuthnContext
AudienceURI string
IDPCertificateStore dsig.X509CertificateStore
SPKeyStore dsig.X509KeyStore // Required encryption key, default signing key
SPSigningKeyStore dsig.X509KeyStore // Optional signing key
SPKeyStoreRotate dsig.X509KeyStore // Optional additional encryption key for rotation
NameIdFormat string
ValidateEncryptionCert bool
SkipSignatureValidation bool
AllowMissingAttributes bool
Clock *dsig.Clock
RequireEncryptedAssertion bool
signingContextMu sync.RWMutex
signingContext *dsig.SigningContext
}

// RequestedAuthnContext controls which authentication mechanisms are requested of
Expand Down
4 changes: 2 additions & 2 deletions saml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestDecode(t *testing.T) {

ea := response.EncryptedAssertions[0]

k, err := ea.EncryptedKey.DecryptSymmetricKey(&cert)
k, err := ea.EncryptedKey.DecryptSymmetricKey([]*tls.Certificate{&cert})
if err != nil {
t.Fatalf("could not get symmetric key: %v\n", err)
}
Expand All @@ -73,7 +73,7 @@ func TestDecode(t *testing.T) {
t.Fatalf("no symmetric key")
}

assertion, err := ea.Decrypt(&cert)
assertion, err := ea.Decrypt([]*tls.Certificate{&cert})
if err != nil {
t.Fatalf("error decrypting saml data: %v\n", err)
}
Expand Down
29 changes: 22 additions & 7 deletions types/encrypted_assertion.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Copyright 2016 Russell Haering et al.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// https://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand All @@ -15,7 +15,9 @@ package types

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/tls"
"encoding/base64"
"encoding/xml"
Expand All @@ -30,7 +32,7 @@ type EncryptedAssertion struct {
CipherValue string `xml:"EncryptedData>CipherData>CipherValue"`
}

func (ea *EncryptedAssertion) DecryptBytes(cert *tls.Certificate) ([]byte, error) {
func (ea *EncryptedAssertion) DecryptBytes(certs []*tls.Certificate) ([]byte, error) {
data, err := base64.StdEncoding.DecodeString(ea.CipherValue)
if err != nil {
return nil, err
Expand All @@ -43,13 +45,17 @@ func (ea *EncryptedAssertion) DecryptBytes(cert *tls.Certificate) ([]byte, error
// https://www.w3.org/TR/2002/REC-xmlenc-core-20021210/Overview.html#sec-Extensions-to-KeyInfo
ek = &ea.DetEncryptedKey
}
k, err := ek.DecryptSymmetricKey(cert)
keyBytes, err := ek.DecryptSymmetricKey(certs)
if err != nil {
return nil, fmt.Errorf("cannot decrypt, error retrieving private key: %s", err)
}

switch ea.EncryptionMethod.Algorithm {
case MethodAES128GCM:
k, err := aes.NewCipher(keyBytes)
if err != nil {
return nil, err
}
c, err := cipher.NewGCM(k)
if err != nil {
return nil, fmt.Errorf("cannot create AES-GCM: %s", err)
Expand All @@ -62,6 +68,15 @@ func (ea *EncryptedAssertion) DecryptBytes(cert *tls.Certificate) ([]byte, error
}
return plainText, nil
case MethodAES128CBC, MethodAES256CBC, MethodTripleDESCBC:
var k cipher.Block
if ea.EncryptionMethod.Algorithm == MethodTripleDESCBC {
k, err = des.NewTripleDESCipher(keyBytes)
} else {
k, err = aes.NewCipher(keyBytes)
}
if err != nil {
return nil, err
}
nonce, data := data[:k.BlockSize()], data[k.BlockSize():]
c := cipher.NewCBCDecrypter(k, nonce)
c.CryptBlocks(data, data)
Expand All @@ -79,8 +94,8 @@ func (ea *EncryptedAssertion) DecryptBytes(cert *tls.Certificate) ([]byte, error
}

// Decrypt decrypts and unmarshals the EncryptedAssertion.
func (ea *EncryptedAssertion) Decrypt(cert *tls.Certificate) (*Assertion, error) {
plaintext, err := ea.DecryptBytes(cert)
func (ea *EncryptedAssertion) Decrypt(certs []*tls.Certificate) (*Assertion, error) {
plaintext, err := ea.DecryptBytes(certs)
if err != nil {
return nil, fmt.Errorf("Error decrypting assertion: %v", err)
}
Expand Down
Loading