From 5827995ccbaebf99da1cd7c73a2d6e03b541a944 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Thu, 14 Nov 2024 02:04:37 +0000 Subject: [PATCH] Fix issue with loadbalancer failover to default server The loadbalancer should only fail over to the default server if all other server have failed, and it should force fail-back to a preferred server as soon as one passes health checks. The loadbalancer tests have been improved to ensure that this occurs. Signed-off-by: Brad Davidson --- pkg/agent/loadbalancer/loadbalancer.go | 2 + pkg/agent/loadbalancer/loadbalancer_test.go | 186 +++++++++++++++++--- pkg/agent/loadbalancer/servers.go | 47 +++-- pkg/agent/loadbalancer/utility.go | 3 + pkg/util/apierrors.go | 3 +- 5 files changed, 200 insertions(+), 41 deletions(-) diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go index c75ea5fec4f2..6689a9e7ca39 100644 --- a/pkg/agent/loadbalancer/loadbalancer.go +++ b/pkg/agent/loadbalancer/loadbalancer.go @@ -179,6 +179,8 @@ func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net if !allChecksFailed { defer server.closeAll() } + } else { + logrus.Debugf("Dial health check failed for %s", targetServer) } newServer, err := lb.nextServer(targetServer) diff --git a/pkg/agent/loadbalancer/loadbalancer_test.go b/pkg/agent/loadbalancer/loadbalancer_test.go index 1cb26736e07c..cbfdf982c690 100644 --- a/pkg/agent/loadbalancer/loadbalancer_test.go +++ b/pkg/agent/loadbalancer/loadbalancer_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - "github.com/k3s-io/k3s/pkg/cli/cmds" "github.com/sirupsen/logrus" ) @@ -24,7 +23,7 @@ type testServer struct { prefix string } -func createServer(prefix string) (*testServer, error) { +func createServer(ctx context.Context, prefix string) (*testServer, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, err @@ -34,6 +33,10 @@ func createServer(prefix string) (*testServer, error) { listener: listener, } go s.serve() + go func() { + <-ctx.Done() + s.close() + }() return s, nil } @@ -49,6 +52,7 @@ func (s *testServer) serve() { } func (s *testServer) close() { + logrus.Printf("testServer %s closing", s.prefix) s.listener.Close() for _, conn := range s.conns { conn.Close() @@ -65,6 +69,10 @@ func (s *testServer) echo(conn net.Conn) { } } +func (s *testServer) address() string { + return s.listener.Addr().String() +} + func ping(conn net.Conn) (string, error) { fmt.Fprintf(conn, "ping\n") result, err := bufio.NewReader(conn).ReadString('\n') @@ -74,25 +82,31 @@ func ping(conn net.Conn) (string, error) { return strings.TrimSpace(result), nil } +// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint) +// and then adds a new server (a node). The node server is then closed, and it is confirmed +// that new connections use the default server. func Test_UnitFailOver(t *testing.T) { tmpDir := t.TempDir() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - ogServe, err := createServer("og") + defaultServer, err := createServer(ctx, "default") if err != nil { - t.Fatalf("createServer(og) failed: %v", err) + t.Fatalf("createServer(default) failed: %v", err) } - lbServe, err := createServer("lb") + node1Server, err := createServer(ctx, "node1") if err != nil { - t.Fatalf("createServer(lb) failed: %v", err) + t.Fatalf("createServer(node1) failed: %v", err) } - cfg := cmds.Agent{ - ServerURL: fmt.Sprintf("http://%s/", ogServe.listener.Addr().String()), - DataDir: tmpDir, + node2Server, err := createServer(ctx, "node2") + if err != nil { + t.Fatalf("createServer(node2) failed: %v", err) } - lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false) + // start the loadbalancer with the default server as the only server + lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false) if err != nil { t.Fatalf("New() failed: %v", err) } @@ -103,50 +117,123 @@ func Test_UnitFailOver(t *testing.T) { } localAddress := parsedURL.Host - lb.Update([]string{lbServe.listener.Addr().String()}) + // add the node as a new server address. + lb.Update([]string{node1Server.address()}) + // make sure connections go to the node conn1, err := net.Dial("tcp", localAddress) if err != nil { t.Fatalf("net.Dial failed: %v", err) } - result1, err := ping(conn1) - if err != nil { + if result, err := ping(conn1); err != nil { t.Fatalf("ping(conn1) failed: %v", err) + } else if result != "node1:ping" { + t.Fatalf("Unexpected ping(conn1) result: %v", result) } - if result1 != "lb:ping" { - t.Fatalf("Unexpected ping result: %v", result1) - } - lbServe.close() + t.Log("conn1 tested OK") + + // set failing health check for node 1 + lb.SetHealthCheck(node1Server.address(), func() bool { return false }) + + // Server connections are checked every second, now that node 1 is failed + // the connections to it should be closed. + time.Sleep(2 * time.Second) - _, err = ping(conn1) - if err == nil { + if _, err := ping(conn1); err == nil { t.Fatal("Unexpected successful ping on closed connection conn1") } + t.Log("conn1 closed on failure OK") + + // make sure connection still goes to the first node - it is failing health checks but so + // is the default endpoint, so it should be tried first with health checks disabled, + // before failing back to the default. conn2, err := net.Dial("tcp", localAddress) if err != nil { t.Fatalf("net.Dial failed: %v", err) } - result2, err := ping(conn2) - if err != nil { + if result, err := ping(conn2); err != nil { t.Fatalf("ping(conn2) failed: %v", err) + } else if result != "node1:ping" { + t.Fatalf("Unexpected ping(conn2) result: %v", result) } - if result2 != "og:ping" { - t.Fatalf("Unexpected ping result: %v", result2) + + t.Log("conn2 tested OK") + + // make sure the health checks don't close the connection we just made - + // connections should only be closed when it transitions from health to unhealthy. + time.Sleep(2 * time.Second) + + if result, err := ping(conn2); err != nil { + t.Fatalf("ping(conn2) failed: %v", err) + } else if result != "node1:ping" { + t.Fatalf("Unexpected ping(conn2) result: %v", result) } + + t.Log("conn2 tested OK again") + + // shut down the first node server to force failover to the default + node1Server.close() + + // make sure new connections go to the default, and existing connections are closed + conn3, err := net.Dial("tcp", localAddress) + if err != nil { + t.Fatalf("net.Dial failed: %v", err) + + } + if result, err := ping(conn3); err != nil { + t.Fatalf("ping(conn3) failed: %v", err) + } else if result != "default:ping" { + t.Fatalf("Unexpected ping(conn3) result: %v", result) + } + + t.Log("conn3 tested OK") + + if _, err := ping(conn2); err == nil { + t.Fatal("Unexpected successful ping on closed connection conn2") + } + + t.Log("conn2 closed on failure OK") + + // add the second node as a new server address. + lb.Update([]string{node2Server.address()}) + + // make sure connection now goes to the second node, + // and connections to the default are closed. + conn4, err := net.Dial("tcp", localAddress) + if err != nil { + t.Fatalf("net.Dial failed: %v", err) + + } + if result, err := ping(conn4); err != nil { + t.Fatalf("ping(conn4) failed: %v", err) + } else if result != "node2:ping" { + t.Fatalf("Unexpected ping(conn4) result: %v", result) + } + + t.Log("conn4 tested OK") + + // Server connections are checked every second, now that we have a healthy + // server, connections to the default server should be closed + time.Sleep(2 * time.Second) + + if _, err := ping(conn3); err == nil { + t.Fatal("Unexpected successful ping on connection conn3") + } + + t.Log("conn3 closed on failure OK") } +// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly func Test_UnitFailFast(t *testing.T) { tmpDir := t.TempDir() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cfg := cmds.Agent{ - ServerURL: "http://127.0.0.1:0/", - DataDir: tmpDir, - } - - lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false) + serverURL := "http://127.0.0.1:0/" + lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false) if err != nil { t.Fatalf("New() failed: %v", err) } @@ -172,3 +259,44 @@ func Test_UnitFailFast(t *testing.T) { t.Fatal("Test timed out") } } + +// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail +// within the expected duration +func Test_UnitFailUnreachable(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow test in short mode.") + } + tmpDir := t.TempDir() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverAddr := "192.0.2.1:6443" + lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + + // Set failing health check to reduce retries + lb.SetHealthCheck(serverAddr, func() bool { return false }) + + conn, err := net.Dial("tcp", lb.localAddress) + if err != nil { + t.Fatalf("net.Dial failed: %v", err) + } + + done := make(chan error) + go func() { + _, err = ping(conn) + done <- err + }() + timeout := time.After(11 * time.Second) + + select { + case err := <-done: + if err == nil { + t.Fatal("Unexpected successful ping from unreachable address") + } + case <-timeout: + t.Fatal("Test timed out") + } +} diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go index 6b7f25606064..660810525470 100644 --- a/pkg/agent/loadbalancer/servers.go +++ b/pkg/agent/loadbalancer/servers.go @@ -7,6 +7,7 @@ import ( "net" "net/url" "os" + "slices" "strconv" "time" @@ -21,7 +22,10 @@ import ( "k8s.io/apimachinery/pkg/util/wait" ) -var defaultDialer proxy.Dialer = &net.Dialer{} +var defaultDialer proxy.Dialer = &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, +} // SetHTTPProxy configures a proxy-enabled dialer to be used for all loadbalancer connections, // if the agent has been configured to allow use of a HTTP proxy, and the environment has been configured @@ -48,7 +52,7 @@ func SetHTTPProxy(address string) error { return nil } - dialer, err := proxyDialer(proxyURL) + dialer, err := proxyDialer(proxyURL, defaultDialer) if err != nil { return errors.Wrapf(err, "failed to create proxy dialer for %s", proxyURL) } @@ -59,7 +63,7 @@ func SetHTTPProxy(address string) error { } func (lb *LoadBalancer) setServers(serverAddresses []string) bool { - serverAddresses, hasOriginalServer := sortServers(serverAddresses, lb.defaultServerAddress) + serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress) if len(serverAddresses) == 0 { return false } @@ -102,8 +106,16 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool { rand.Shuffle(len(lb.randomServers), func(i, j int) { lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i] }) - if !hasOriginalServer { + // If the current server list does not contain the default server address, + // we want to include it in the random server list so that it can be tried if necessary. + // However, it should be treated as always failing health checks so that it is only + // used if all other endpoints are unavailable. + if !hasDefaultServer { lb.randomServers = append(lb.randomServers, lb.defaultServerAddress) + if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok { + defaultServer.healthCheck = func() bool { return false } + lb.servers[lb.defaultServerAddress] = defaultServer + } } lb.currentServerAddress = lb.randomServers[0] lb.nextServerIndex = 1 @@ -163,14 +175,14 @@ func (s *server) dialContext(ctx context.Context, network, address string) (net. } // proxyDialer creates a new proxy.Dialer that routes connections through the specified proxy. -func proxyDialer(proxyURL *url.URL) (proxy.Dialer, error) { +func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) { if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { // Create a new HTTP proxy dialer - httpProxyDialer := http_dialer.New(proxyURL) + httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer))) return httpProxyDialer, nil } else if proxyURL.Scheme == "socks5" { // For SOCKS5 proxies, use the proxy package's FromURL - return proxy.FromURL(proxyURL, proxy.Direct) + return proxy.FromURL(proxyURL, forward) } return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) } @@ -204,17 +216,18 @@ func (lb *LoadBalancer) SetDefault(serverAddress string) { lb.mutex.Lock() defer lb.mutex.Unlock() - _, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress) + hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress) // if the old default server is not currently in use, remove it from the server map - if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer { + if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer { defer server.closeAll() delete(lb.servers, lb.defaultServerAddress) } - // if the new default server doesn't have an entry in the map, add one + // if the new default server doesn't have an entry in the map, add one - but + // with a failing health check so that it is only used as a last resort. if _, ok := lb.servers[serverAddress]; !ok { lb.servers[serverAddress] = &server{ address: serverAddress, - healthCheck: func() bool { return true }, + healthCheck: func() bool { return false }, connections: make(map[net.Conn]struct{}), } } @@ -243,8 +256,10 @@ func (lb *LoadBalancer) runHealthChecks(ctx context.Context) { wait.Until(func() { lb.mutex.RLock() defer lb.mutex.RUnlock() + var healthyServerExists bool for address, server := range lb.servers { status := server.healthCheck() + healthyServerExists = healthyServerExists || status if status == false && previousStatus[address] == true { // Only close connections when the server transitions from healthy to unhealthy; // we don't want to re-close all the connections every time as we might be ignoring @@ -253,6 +268,16 @@ func (lb *LoadBalancer) runHealthChecks(ctx context.Context) { } previousStatus[address] = status } + + // If there is at least one healthy server, and the default server is not in the server list, + // close all the connections to the default server so that clients reconnect and switch over + // to a preferred server. + hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress) + if healthyServerExists && !hasDefaultServer { + if server, ok := lb.servers[lb.defaultServerAddress]; ok { + defer server.closeAll() + } + } }, time.Second, ctx.Done()) logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName) } diff --git a/pkg/agent/loadbalancer/utility.go b/pkg/agent/loadbalancer/utility.go index a462da2e2349..7ecff5464412 100644 --- a/pkg/agent/loadbalancer/utility.go +++ b/pkg/agent/loadbalancer/utility.go @@ -28,6 +28,9 @@ func parseURL(serverURL, newHost string) (string, string, error) { return address, parsedURL.String(), nil } +// sortServers returns a sorted, unique list of strings, with any +// empty values removed. The returned bool is true if the list +// contains the search string. func sortServers(input []string, search string) ([]string, bool) { result := []string{} found := false diff --git a/pkg/util/apierrors.go b/pkg/util/apierrors.go index ec61ecea5465..8650dbe01d14 100644 --- a/pkg/util/apierrors.go +++ b/pkg/util/apierrors.go @@ -40,7 +40,7 @@ func SendError(err error, resp http.ResponseWriter, req *http.Request, status .. // Don't log "apiserver not ready" or "apiserver disabled" errors, they are frequent during startup if !errors.Is(err, ErrAPINotReady) && !errors.Is(err, ErrAPIDisabled) { - logrus.Errorf("Sending HTTP %d response to %s: %v", code, req.RemoteAddr, err) + logrus.Errorf("Sending %s %d response to %s: %v", req.Proto, code, req.RemoteAddr, err) } var serr *apierrors.StatusError @@ -61,6 +61,7 @@ func SendError(err error, resp http.ResponseWriter, req *http.Request, status .. serr = apierrors.NewGenericServerResponse(code, req.Method, schema.GroupResource{}, req.URL.Path, err.Error(), 0, true) } + resp.Header().Add("Connection", "close") responsewriters.ErrorNegotiated(serr, scheme.Codecs.WithoutConversion(), schema.GroupVersion{}, resp, req) }