From 3562de8f21f355d87953551bc8e1fda69ea43be1 Mon Sep 17 00:00:00 2001 From: Ryan Canty Date: Thu, 5 Jan 2023 14:42:32 -0800 Subject: [PATCH] Support multiple LDAP hosts Often there is a need to iterate through multiple LDAP hosts in the event that one has failed. DNS is one possible solution, but there are cases where that may not be desired. This change allows users to specify a comma-separated list of LDAP hosts to iterate through in case one is unreachable. Signed-off-by: Ryan Canty --- connector/ldap/ldap.go | 55 +++++++++++++++++++++------- connector/ldap/ldap_test.go | 73 +++++++++++++++++++++++++++++++++++-- 2 files changed, 110 insertions(+), 18 deletions(-) diff --git a/connector/ldap/ldap.go b/connector/ldap/ldap.go index 543402718c..a009128cf6 100644 --- a/connector/ldap/ldap.go +++ b/connector/ldap/ldap.go @@ -9,6 +9,8 @@ import ( "fmt" "net" "os" + "strings" + "time" "github.com/go-ldap/ldap/v3" @@ -16,6 +18,11 @@ import ( "github.com/dexidp/dex/pkg/log" ) +const ( + secureLdapPort = "636" + ldapPort = "389" +) + // Config holds the configuration parameters for the LDAP connector. The LDAP // connectors require executing two queries, the first to find the user based on // the username and password given to the connector. The second to use the user @@ -65,7 +72,8 @@ type UserMatcher struct { // Config holds configuration options for LDAP logins. type Config struct { // The host and optional port of the LDAP server. If port isn't supplied, it will be - // guessed based on the TLS configuration. 389 or 636. + // guessed based on the TLS configuration. 389 or 636. Can be a comma-separated list + // of hosts that will be tried iteratively. Host string `json:"host"` // Required if LDAP host does not use TLS. @@ -241,19 +249,8 @@ func (c *Config) openConnector(logger log.Logger) (*ldapConnector, error) { } } - var ( - host string - err error - ) - if host, _, err = net.SplitHostPort(c.Host); err != nil { - host = c.Host - if c.InsecureNoSSL { - c.Host += ":389" - } else { - c.Host += ":636" - } - } - + host, port := c.getHostPort() + c.Host = net.JoinHostPort(host, port) tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: c.InsecureSkipVerify} if c.RootCA != "" || len(c.RootCAData) != 0 { data := c.RootCAData @@ -635,3 +632,33 @@ func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, func (c *ldapConnector) Prompt() string { return c.UsernamePrompt } + +// getHostPort splits the Host attribute on comma and tests each connection +// if it fails, the next is set. The final host, port is returned if none succeed. +// For single values of LDAP host, this functionality is the same as without +// the commas +func (c *Config) getHostPort() (host string, port string) { + for _, address := range strings.Split(c.Host, ",") { + if address == "" { + // If the user sets c.Host as "localhost,", the last value + // will be empty so return the value already in "host" + return host, port + } + var err error + host, port, err = net.SplitHostPort(address) + if err != nil { + host = address + port = secureLdapPort + if c.InsecureNoSSL { + port = ldapPort + } + } + + conn, _ := net.DialTimeout("tcp", net.JoinHostPort(host, port), time.Second) + if conn != nil { + conn.Close() + break + } + } + return host, port +} diff --git a/connector/ldap/ldap_test.go b/connector/ldap/ldap_test.go index 83f9f4790c..b107f16b9e 100644 --- a/connector/ldap/ldap_test.go +++ b/connector/ldap/ldap_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net" "os" "testing" @@ -538,17 +539,17 @@ func runTests(t *testing.T, connMethod connectionMethod, config *Config, tests [ // group search configuration. switch connMethod { case connectStartTLS: - c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_PORT", "389")) + c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_PORT", ldapPort)) c.RootCA = "testdata/certs/ca.crt" c.StartTLS = true case connectLDAPS: - c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_TLS_PORT", "636")) + c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_TLS_PORT", secureLdapPort)) c.RootCA = "testdata/certs/ca.crt" case connectInsecureSkipVerify: - c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_TLS_PORT", "636")) + c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_TLS_PORT", secureLdapPort)) c.InsecureSkipVerify = true case connectLDAP: - c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_PORT", "389")) + c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_PORT", ldapPort)) c.InsecureNoSSL = true } @@ -614,3 +615,67 @@ func runTests(t *testing.T, connMethod connectionMethod, config *Config, tests [ }) } } + +type ghpTestCase struct { + name string + host string + insecure bool + expectedHost string + expectedPort string +} + +func Test_getHostPort(t *testing.T) { + ldapHost := os.Getenv("DEX_LDAP_HOST") + if ldapHost == "" { + t.Skipf(`test environment variable "DEX_LDAP_HOST" not set, skipping`) + } + + offlineLdapHost := os.Getenv("DEX_OFFLINE_LDAP_HOST") + if offlineLdapHost == "" { + t.Skipf(`test environment variable "DEX_OFFLINE_LDAP_HOST" not set, skipping`) + } + tests := []ghpTestCase{ + { + name: "single without port", + host: ldapHost, + insecure: false, + expectedHost: "localhost", + expectedPort: secureLdapPort, + }, + { + name: "multiple without port", + host: fmt.Sprintf("%s,%s", offlineLdapHost, ldapHost), + insecure: true, + expectedHost: ldapHost, + expectedPort: ldapPort, + }, + { + name: "multiple with port", + host: fmt.Sprintf("%s:%s,%s:%s", offlineLdapHost, ldapPort, ldapHost, ldapPort), + insecure: true, + expectedHost: ldapHost, + expectedPort: ldapPort, + }, + { + name: "single with trailing comma", + host: ldapHost + ",", + insecure: false, + expectedHost: ldapHost, + expectedPort: secureLdapPort, + }, + { + name: "single with non-standard port", + host: ldapHost + ":1389", + insecure: false, + expectedHost: ldapHost, + expectedPort: "1389", + }, + } + for _, tc := range tests { + c := &Config{Host: tc.host, InsecureNoSSL: tc.insecure} + actualHost, actualPort := c.getHostPort() + if actualHost != tc.expectedHost || actualPort != tc.expectedPort { + t.Errorf("[%s] expected host:port to be %s:%s but was %s:%s", tc.name, tc.expectedHost, tc.expectedPort, actualHost, actualPort) + } + } +}