diff --git a/connector/ldap/ldap.go b/connector/ldap/ldap.go index 543402718c..a009128cf6 100644 --- a/connector/ldap/ldap.go +++ b/connector/ldap/ldap.go @@ -9,6 +9,8 @@ import ( "fmt" "net" "os" + "strings" + "time" "github.com/go-ldap/ldap/v3" @@ -16,6 +18,11 @@ import ( "github.com/dexidp/dex/pkg/log" ) +const ( + secureLdapPort = "636" + ldapPort = "389" +) + // Config holds the configuration parameters for the LDAP connector. The LDAP // connectors require executing two queries, the first to find the user based on // the username and password given to the connector. The second to use the user @@ -65,7 +72,8 @@ type UserMatcher struct { // Config holds configuration options for LDAP logins. type Config struct { // The host and optional port of the LDAP server. If port isn't supplied, it will be - // guessed based on the TLS configuration. 389 or 636. + // guessed based on the TLS configuration. 389 or 636. Can be a comma-separated list + // of hosts that will be tried iteratively. Host string `json:"host"` // Required if LDAP host does not use TLS. @@ -241,19 +249,8 @@ func (c *Config) openConnector(logger log.Logger) (*ldapConnector, error) { } } - var ( - host string - err error - ) - if host, _, err = net.SplitHostPort(c.Host); err != nil { - host = c.Host - if c.InsecureNoSSL { - c.Host += ":389" - } else { - c.Host += ":636" - } - } - + host, port := c.getHostPort() + c.Host = net.JoinHostPort(host, port) tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: c.InsecureSkipVerify} if c.RootCA != "" || len(c.RootCAData) != 0 { data := c.RootCAData @@ -635,3 +632,33 @@ func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, func (c *ldapConnector) Prompt() string { return c.UsernamePrompt } + +// getHostPort splits the Host attribute on comma and tests each connection +// if it fails, the next is set. The final host, port is returned if none succeed. +// For single values of LDAP host, this functionality is the same as without +// the commas +func (c *Config) getHostPort() (host string, port string) { + for _, address := range strings.Split(c.Host, ",") { + if address == "" { + // If the user sets c.Host as "localhost,", the last value + // will be empty so return the value already in "host" + return host, port + } + var err error + host, port, err = net.SplitHostPort(address) + if err != nil { + host = address + port = secureLdapPort + if c.InsecureNoSSL { + port = ldapPort + } + } + + conn, _ := net.DialTimeout("tcp", net.JoinHostPort(host, port), time.Second) + if conn != nil { + conn.Close() + break + } + } + return host, port +} diff --git a/connector/ldap/ldap_test.go b/connector/ldap/ldap_test.go index 83f9f4790c..b107f16b9e 100644 --- a/connector/ldap/ldap_test.go +++ b/connector/ldap/ldap_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net" "os" "testing" @@ -538,17 +539,17 @@ func runTests(t *testing.T, connMethod connectionMethod, config *Config, tests [ // group search configuration. switch connMethod { case connectStartTLS: - c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_PORT", "389")) + c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_PORT", ldapPort)) c.RootCA = "testdata/certs/ca.crt" c.StartTLS = true case connectLDAPS: - c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_TLS_PORT", "636")) + c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_TLS_PORT", secureLdapPort)) c.RootCA = "testdata/certs/ca.crt" case connectInsecureSkipVerify: - c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_TLS_PORT", "636")) + c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_TLS_PORT", secureLdapPort)) c.InsecureSkipVerify = true case connectLDAP: - c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_PORT", "389")) + c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_PORT", ldapPort)) c.InsecureNoSSL = true } @@ -614,3 +615,67 @@ func runTests(t *testing.T, connMethod connectionMethod, config *Config, tests [ }) } } + +type ghpTestCase struct { + name string + host string + insecure bool + expectedHost string + expectedPort string +} + +func Test_getHostPort(t *testing.T) { + ldapHost := os.Getenv("DEX_LDAP_HOST") + if ldapHost == "" { + t.Skipf(`test environment variable "DEX_LDAP_HOST" not set, skipping`) + } + + offlineLdapHost := os.Getenv("DEX_OFFLINE_LDAP_HOST") + if offlineLdapHost == "" { + t.Skipf(`test environment variable "DEX_OFFLINE_LDAP_HOST" not set, skipping`) + } + tests := []ghpTestCase{ + { + name: "single without port", + host: ldapHost, + insecure: false, + expectedHost: "localhost", + expectedPort: secureLdapPort, + }, + { + name: "multiple without port", + host: fmt.Sprintf("%s,%s", offlineLdapHost, ldapHost), + insecure: true, + expectedHost: ldapHost, + expectedPort: ldapPort, + }, + { + name: "multiple with port", + host: fmt.Sprintf("%s:%s,%s:%s", offlineLdapHost, ldapPort, ldapHost, ldapPort), + insecure: true, + expectedHost: ldapHost, + expectedPort: ldapPort, + }, + { + name: "single with trailing comma", + host: ldapHost + ",", + insecure: false, + expectedHost: ldapHost, + expectedPort: secureLdapPort, + }, + { + name: "single with non-standard port", + host: ldapHost + ":1389", + insecure: false, + expectedHost: ldapHost, + expectedPort: "1389", + }, + } + for _, tc := range tests { + c := &Config{Host: tc.host, InsecureNoSSL: tc.insecure} + actualHost, actualPort := c.getHostPort() + if actualHost != tc.expectedHost || actualPort != tc.expectedPort { + t.Errorf("[%s] expected host:port to be %s:%s but was %s:%s", tc.name, tc.expectedHost, tc.expectedPort, actualHost, actualPort) + } + } +}