Skip to content

Commit

Permalink
Support multiple LDAP hosts
Browse files Browse the repository at this point in the history
Often there is a need to iterate through multiple LDAP hosts in the
event that one has failed. DNS is one possible solution, but there
are cases where that may not be desired. This change allows users to
specify a comma-separated list of LDAP hosts to iterate through in case
one is unreachable.
  • Loading branch information
onetwopunch authored and Ryan Canty committed Jan 6, 2023
1 parent 6dbd14b commit 3b3730d
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 18 deletions.
55 changes: 41 additions & 14 deletions connector/ldap/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,20 @@ import (
"fmt"
"net"
"os"
"strings"
"time"

"github.com/go-ldap/ldap/v3"

"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/log"
)

const (
secureLdapPort = "636"
ldapPort = "389"
)

// Config holds the configuration parameters for the LDAP connector. The LDAP
// connectors require executing two queries, the first to find the user based on
// the username and password given to the connector. The second to use the user
Expand Down Expand Up @@ -65,7 +72,8 @@ type UserMatcher struct {
// Config holds configuration options for LDAP logins.
type Config struct {
// The host and optional port of the LDAP server. If port isn't supplied, it will be
// guessed based on the TLS configuration. 389 or 636.
// guessed based on the TLS configuration. 389 or 636. Can be a comma-separated list
// of hosts that will be tried iteratively.
Host string `json:"host"`

// Required if LDAP host does not use TLS.
Expand Down Expand Up @@ -241,19 +249,8 @@ func (c *Config) openConnector(logger log.Logger) (*ldapConnector, error) {
}
}

var (
host string
err error
)
if host, _, err = net.SplitHostPort(c.Host); err != nil {
host = c.Host
if c.InsecureNoSSL {
c.Host += ":389"
} else {
c.Host += ":636"
}
}

host, port := c.getHostPort()
c.Host = net.JoinHostPort(host, port)
tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: c.InsecureSkipVerify}
if c.RootCA != "" || len(c.RootCAData) != 0 {
data := c.RootCAData
Expand Down Expand Up @@ -635,3 +632,33 @@ func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string,
func (c *ldapConnector) Prompt() string {
return c.UsernamePrompt
}

// getHostPort splits the Host attribute on comma and tests each connection
// if it fails, the next is set. The final host, port is returned if none succeed.
// For single values of LDAP host, this functionality is the same as without
// the commas
func (c *Config) getHostPort() (host string, port string) {
for _, address := range strings.Split(c.Host, ",") {
if address == "" {
// If the user sets c.Host as "localhost,", the last value
// will be empty so return the value already in "host"
return host, port
}
var err error
host, port, err = net.SplitHostPort(address)
if err != nil {
host = address
port = secureLdapPort
if c.InsecureNoSSL {
port = ldapPort
}
}

conn, _ := net.DialTimeout("tcp", net.JoinHostPort(host, port), time.Second)
if conn != nil {
conn.Close()
break
}
}
return host, port
}
73 changes: 69 additions & 4 deletions connector/ldap/ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net"
"os"
"testing"

Expand Down Expand Up @@ -538,17 +539,17 @@ func runTests(t *testing.T, connMethod connectionMethod, config *Config, tests [
// group search configuration.
switch connMethod {
case connectStartTLS:
c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_PORT", "389"))
c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_PORT", ldapPort))
c.RootCA = "testdata/certs/ca.crt"
c.StartTLS = true
case connectLDAPS:
c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_TLS_PORT", "636"))
c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_TLS_PORT", secureLdapPort))
c.RootCA = "testdata/certs/ca.crt"
case connectInsecureSkipVerify:
c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_TLS_PORT", "636"))
c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_TLS_PORT", secureLdapPort))
c.InsecureSkipVerify = true
case connectLDAP:
c.Host = fmt.Sprintf("%s:%s", ldapHost, getenv("DEX_LDAP_PORT", "389"))
c.Host = net.JoinHostPort(ldapHost, getenv("DEX_LDAP_PORT", ldapPort))
c.InsecureNoSSL = true
}

Expand Down Expand Up @@ -614,3 +615,67 @@ func runTests(t *testing.T, connMethod connectionMethod, config *Config, tests [
})
}
}

type ghpTestCase struct {
name string
host string
insecure bool
expectedHost string
expectedPort string
}

func Test_getHostPort(t *testing.T) {
ldapHost := os.Getenv("DEX_LDAP_HOST")
if ldapHost == "" {
t.Skipf(`test environment variable "DEX_LDAP_HOST" not set, skipping`)
}

offlineLdapHost := os.Getenv("DEX_OFFLINE_LDAP_HOST")
if offlineLdapHost == "" {
t.Skipf(`test environment variable "DEX_OFFLINE_LDAP_HOST" not set, skipping`)
}
tests := []ghpTestCase{
{
name: "single without port",
host: ldapHost,
insecure: false,
expectedHost: "localhost",
expectedPort: secureLdapPort,
},
{
name: "multiple without port",
host: fmt.Sprintf("%s,%s", offlineLdapHost, ldapHost),
insecure: true,
expectedHost: ldapHost,
expectedPort: ldapPort,
},
{
name: "multiple with port",
host: fmt.Sprintf("%s:%s,%s:%s", offlineLdapHost, ldapPort, ldapHost, ldapPort),
insecure: true,
expectedHost: ldapHost,
expectedPort: ldapPort,
},
{
name: "single with trailing comma",
host: ldapHost + ",",
insecure: false,
expectedHost: ldapHost,
expectedPort: secureLdapPort,
},
{
name: "single with non-standard port",
host: ldapHost + ":1389",
insecure: false,
expectedHost: ldapHost,
expectedPort: "1389",
},
}
for _, tc := range tests {
c := &Config{Host: tc.host, InsecureNoSSL: tc.insecure}
actualHost, actualPort := c.getHostPort()
if actualHost != tc.expectedHost || actualPort != tc.expectedPort {
t.Errorf("[%s] expected host:port to be %s:%s but was %s:%s", tc.name, tc.expectedHost, tc.expectedPort, actualHost, actualPort)
}
}
}

0 comments on commit 3b3730d

Please sign in to comment.