From e007fd12fdc544e359d1b0e4ae10f3a345d7de96 Mon Sep 17 00:00:00 2001 From: Ben Johnson Date: Wed, 26 Jul 2023 08:28:57 -0600 Subject: [PATCH] Retry proxy on connection refused error (#368) --- http/proxy_server.go | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/http/proxy_server.go b/http/proxy_server.go index e01acd7..7d6a3ff 100644 --- a/http/proxy_server.go +++ b/http/proxy_server.go @@ -2,12 +2,14 @@ package http import ( "context" + "errors" "fmt" "io" "log" "net" "net/http" "regexp" + "syscall" "time" "github.com/superfly/litefs" @@ -61,6 +63,8 @@ type ProxyServer struct { // Time before cookie expires on client. CookieExpiry time.Duration + + HTTPTransport *http.Transport } // NewProxyServer returns a new instance of ProxyServer. @@ -79,6 +83,19 @@ func NewProxyServer(store *litefs.Store) *ProxyServer { Handler: http.HandlerFunc(s.serveHTTP), } + s.HTTPTransport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialContextWithRetry(&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }), + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + return s } @@ -238,7 +255,7 @@ func (s *ProxyServer) proxyToTarget(w http.ResponseWriter, r *http.Request, pass r.URL.Scheme = "http" r.URL.Host = s.Target - resp, err := http.DefaultTransport.RoundTrip(r) + resp, err := s.HTTPTransport.RoundTrip(r) if err != nil { http.Error(w, "Proxy error: "+err.Error(), http.StatusBadGateway) return @@ -295,3 +312,26 @@ func (s *ProxyServer) logf(format string, v ...any) { log.Printf(format, v...) } } + +// dialContextWithRetry returns a function that will retry +func dialContextWithRetry(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + timeout := time.NewTimer(dialer.Timeout) + defer timeout.Stop() + + for { + conn, err := dialer.DialContext(ctx, network, address) + if !errors.Is(err, syscall.ECONNREFUSED) { + return conn, err + } + + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case <-timeout.C: + return nil, err + case <-time.After(100 * time.Millisecond): + } + } + } +}