Skip to content

Commit

Permalink
Merge pull request #1835 from flant/kubernetes-client-keys-conflicts-fix
Browse files Browse the repository at this point in the history
fix: Handle Kubernetes API conflicts properly for signing keys
  • Loading branch information
bonifaido authored Oct 13, 2020
2 parents 9c02610 + 4801b2c commit 28b2350
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 2 deletions.
26 changes: 24 additions & 2 deletions storage/kubernetes/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -512,6 +513,7 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
}
firstUpdate = true
}

var oldKeys storage.Keys
if !firstUpdate {
oldKeys = toStorageKeys(keys)
Expand All @@ -521,12 +523,32 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
if err != nil {
return err
}

newKeys := cli.fromStorageKeys(updated)
if firstUpdate {
return cli.post(resourceKeys, newKeys)
err = cli.post(resourceKeys, newKeys)
if err != nil && errors.Is(err, storage.ErrAlreadyExists) {
// We need to tolerate conflicts here in case of HA mode.
cli.logger.Debugf("Keys creation failed: %v. It is possible that keys have already been created by another dex instance.", err)
return errors.New("keys already created by another server instance")
}

return err
}

newKeys.ObjectMeta = keys.ObjectMeta
return cli.put(resourceKeys, keysName, newKeys)

err = cli.put(resourceKeys, keysName, newKeys)
if httpErr, ok := err.(httpError); ok {
// We need to tolerate conflicts here in case of HA mode.
// Dex instances run keys rotation at the same time because they use SigningKey.nextRotation CR field as a trigger.
if httpErr.StatusCode() == http.StatusConflict {
cli.logger.Debugf("Keys rotation failed: %v. It is possible that keys have already been rotated by another dex instance.", err)
return errors.New("keys already rotated by another server instance")
}
}

return err
}

func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
Expand Down
122 changes: 122 additions & 0 deletions storage/kubernetes/storage_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package kubernetes

import (
"crypto/tls"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -150,3 +155,120 @@ func TestURLFor(t *testing.T) {
}
}
}

func TestUpdateKeys(t *testing.T) {
fakeUpdater := func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, nil }

tests := []struct {
name string
updater func(old storage.Keys) (storage.Keys, error)
getResponseCode int
actionResponseCode int
wantErr bool
exactErr error
}{
{
"Create OK test",
fakeUpdater,
404,
201,
false,
nil,
},
{
"Update should be OK",
fakeUpdater,
200,
200,
false,
nil,
},
{
"Create conflict should be OK",
fakeUpdater,
404,
409,
true,
errors.New("keys already created by another server instance"),
},
{
"Update conflict should be OK",
fakeUpdater,
200,
409,
true,
errors.New("keys already rotated by another server instance"),
},
{
"Client error is error",
fakeUpdater,
404,
500,
true,
nil,
},
{
"Client error during update is error",
fakeUpdater,
200,
500,
true,
nil,
},
{
"Get error is error",
fakeUpdater,
500,
200,
true,
nil,
},
{
"Updater error is error",
func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, fmt.Errorf("test") },
200,
201,
true,
nil,
},
}

for _, test := range tests {
client := newStatusCodesResponseTestClient(test.getResponseCode, test.actionResponseCode)

err := client.UpdateKeys(test.updater)
if err != nil {
if !test.wantErr {
t.Fatalf("Test %q: %v", test.name, err)
}

if test.exactErr != nil && test.exactErr.Error() != err.Error() {
t.Fatalf("Test %q: %v, wanted: %v", test.name, err, test.exactErr)
}
}
}
}

func newStatusCodesResponseTestClient(getResponseCode, actionResponseCode int) *client {
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
w.WriteHeader(getResponseCode)
} else {
w.WriteHeader(actionResponseCode)
}
w.Write([]byte(`{}`)) // Empty json is enough, we will test only response codes here
}))

tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
return &client{
client: &http.Client{Transport: tr},
baseURL: s.URL,
logger: &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
},
}
}

0 comments on commit 28b2350

Please sign in to comment.