Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
PIPR-352: Add mxresolv.LookupWithPref
  • Loading branch information
horkhe authored May 20, 2024
2 parents ebe61f9 + e247014 commit 354998d
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 38 deletions.
87 changes: 51 additions & 36 deletions mxresolv/mxresolv.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,54 +39,78 @@ func init() {
lookupResultCache = collections.NewLRUCache(cacheSize)
}

// Lookup performs a DNS lookup of MX records for the specified hostname. It
// Lookup performs a DNS lookup of MX records for the specified domain. It
// returns a prioritised list of MX hostnames, where hostnames with the same
// priority are shuffled. If the second returned value is true, then the host
// does not have explicit MX records, and its A record is returned instead.
//
// It uses an LRU cache with a timeout to reduce the number of network requests.
func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImplicit bool, reterr error) {
if cachedVal, ok := lookupResultCache.Get(hostname); ok {
func Lookup(ctx context.Context, domain string) (mxHosts []string, implicit bool, err error) {
mxRecords, implicit, err := LookupWithPref(ctx, domain)
if err != nil {
return nil, false, err
}
if len(mxRecords) == 1 {
return []string{mxRecords[0].Host}, implicit, err
}
return shuffleMXRecords(mxRecords), false, nil
}

// LookupWithPref performs a DNS lookup of MX records for the specified domain.
// It returns a slice of net.MX records that are ordered by preference. Records
// with the same preference are sorted by hostname to ensure deterministic
// behaviour. If the second returned value is true, then the host does not have
// explicit MX records, and its A record is used instead.
//
// It uses an LRU cache with a timeout to reduce the number of network requests.
func LookupWithPref(ctx context.Context, domainName string) (mxRecords []*net.MX, implicit bool, err error) {
if cachedVal, ok := lookupResultCache.Get(domainName); ok {
cachedLookupResult := cachedVal.(lookupResult)
if cachedLookupResult.shuffled {
reshuffledMXHosts, _ := shuffleMXRecords(cachedLookupResult.mxRecords)
return reshuffledMXHosts, cachedLookupResult.implicit, cachedLookupResult.err
}
return cachedLookupResult.mxHosts, cachedLookupResult.implicit, cachedLookupResult.err
return cachedLookupResult.mxRecords, cachedLookupResult.implicit, cachedLookupResult.err
}

asciiHostname, err := ensureASCII(hostname)
asciiDomainName, err := ensureASCII(domainName)
if err != nil {
return nil, false, errors.Wrap(err, "invalid hostname")
return nil, false, errors.Wrap(err, "invalid domain name")
}
mxRecords, err := lookupMX(Resolver, ctx, asciiHostname)
mxRecords, err = lookupMX(Resolver, ctx, asciiDomainName)
if err != nil {
var timeouter interface{ Timeout() bool }
if errors.As(err, &timeouter) && timeouter.Timeout() {
return nil, false, errors.WithStack(err)
}
var netDNSError *net.DNSError
if errors.As(err, &netDNSError) && netDNSError.IsNotFound {
if _, err := Resolver.LookupIPAddr(ctx, asciiHostname); err != nil {
return cacheAndReturn(hostname, nil, nil, false, false, errors.WithStack(err))
if _, err := Resolver.LookupIPAddr(ctx, asciiDomainName); err != nil {
return cacheAndReturn(domainName, nil, false, errors.WithStack(err))
}
return cacheAndReturn(hostname, []string{asciiHostname}, nil, false, true, nil)
return cacheAndReturn(domainName, []*net.MX{{Host: asciiDomainName, Pref: 1}}, true, nil)
}
if mxRecords == nil {
return cacheAndReturn(hostname, nil, nil, false, false, errors.WithStack(err))
return cacheAndReturn(domainName, nil, false, errors.WithStack(err))
}
}
// Check for "Null MX" record (https://tools.ietf.org/html/rfc7505).
if len(mxRecords) == 1 {
if mxRecords[0].Host == "." {
return cacheAndReturn(hostname, nil, nil, false, false, errNullMXRecord)
return cacheAndReturn(domainName, nil, false, errNullMXRecord)
}
// 0.0.0.0 is not really a "Null MX" record, but some people apparently
// have never heard of RFC7505 and configure it this way.
if strings.HasPrefix(mxRecords[0].Host, "0.0.0.0") {
return cacheAndReturn(hostname, nil, nil, false, false, errNullMXRecord)
return cacheAndReturn(domainName, nil, false, errNullMXRecord)
}
}
// Purge records with non-ASCII characters. we have seen such records in
// production, they are obviously products of human errors.
for i := 0; i < len(mxRecords); {
if isASCII(mxRecords[i].Host) {
i++
continue
}
copy(mxRecords[i:], mxRecords[i+1:])
mxRecords = mxRecords[:len(mxRecords)-1]
}
// If there are no valid records left, then return an error.
if len(mxRecords) == 0 {
return cacheAndReturn(domainName, nil, false, errNoValidMXHosts)
}
// Normalize returned hostnames: drop trailing '.' and lowercase.
for _, mxRecord := range mxRecords {
lastCharIndex := len(mxRecord.Host) - 1
Expand All @@ -100,11 +124,7 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli
return mxRecords[i].Pref < mxRecords[j].Pref ||
(mxRecords[i].Pref == mxRecords[j].Pref && mxRecords[i].Host < mxRecords[j].Host)
})
mxHosts, shuffled := shuffleMXRecords(mxRecords)
if len(mxHosts) == 0 {
return cacheAndReturn(hostname, nil, nil, false, false, errNoValidMXHosts)
}
return cacheAndReturn(hostname, mxHosts, mxRecords, shuffled, false, nil)
return cacheAndReturn(domainName, mxRecords, false, nil)
}

// SetDeterministicInTests sets rand to deterministic seed for testing, and is
Expand All @@ -126,14 +146,13 @@ func ResetCache() {
lookupResultCache = collections.NewLRUCache(1000)
}

func shuffleMXRecords(mxRecords []*net.MX) ([]string, bool) {
func shuffleMXRecords(mxRecords []*net.MX) []string {
// Shuffle the hosts within the preference groups.
var (
mxHosts []string
groupBegin = 0
groupEnd = 0
groupPref uint16
shuffled = false
)
for _, mxRecord := range mxRecords {
// If a hostname has non-ASCII characters then ignore it, for it is
Expand Down Expand Up @@ -165,7 +184,6 @@ func shuffleMXRecords(mxRecords []*net.MX) ([]string, bool) {
// After finding the end of the current preference group, shuffle it.
if groupEnd-groupBegin > 1 {
shuffleHosts(mxHosts[groupBegin:groupEnd])
shuffled = true
}
// Set up the next preference group.
groupBegin = groupEnd
Expand All @@ -175,9 +193,8 @@ func shuffleMXRecords(mxRecords []*net.MX) ([]string, bool) {
// Shuffle the last preference group, if there is one.
if groupEnd-groupBegin > 1 {
shuffleHosts(mxHosts[groupBegin:groupEnd])
shuffled = true
}
return mxHosts, shuffled
return mxHosts
}

func shuffleHosts(hosts []string) {
Expand Down Expand Up @@ -208,15 +225,13 @@ func isASCII(s string) bool {

type lookupResult struct {
mxRecords []*net.MX
mxHosts []string
shuffled bool
implicit bool
err error
}

func cacheAndReturn(hostname string, mxHosts []string, mxRecords []*net.MX, shuffled, implicit bool, err error) (retMxHosts []string, retImplicit bool, reterr error) {
lookupResultCache.AddWithTTL(hostname, lookupResult{mxHosts: mxHosts, mxRecords: mxRecords, shuffled: shuffled, implicit: implicit, err: err}, cacheTTL)
return mxHosts, implicit, err
func cacheAndReturn(hostname string, mxRecords []*net.MX, implicit bool, err error) ([]*net.MX, bool, error) {
lookupResultCache.AddWithTTL(hostname, lookupResult{mxRecords: mxRecords, implicit: implicit, err: err}, cacheTTL)
return mxRecords, implicit, err
}

// lookupMX exposes the respective private function of net.Resolver. The public
Expand Down
74 changes: 72 additions & 2 deletions mxresolv/mxresolv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,76 @@ func TestMain(m *testing.M) {
os.Exit(exitVal)
}

func TestLookupWithPref(t *testing.T) {
for _, tc := range []struct {
desc string
inDomainName string
outMXHosts []*net.MX
outImplicitMX bool
}{{
desc: "MX record preference is respected",
inDomainName: "test-mx.definbox.com",
outMXHosts: []*net.MX{
{Host: "mxa.definbox.com", Pref: 1}, {Host: "mxe.definbox.com", Pref: 1}, {Host: "mxi.definbox.com", Pref: 1},
{Host: "mxc.definbox.com", Pref: 2},
{Host: "mxb.definbox.com", Pref: 3}, {Host: "mxd.definbox.com", Pref: 3}, {Host: "mxf.definbox.com", Pref: 3}, {Host: "mxg.definbox.com", Pref: 3}, {Host: "mxh.definbox.com", Pref: 3},
},
outImplicitMX: false,
}, {
inDomainName: "test-a.definbox.com",
outMXHosts: []*net.MX{{Host: "test-a.definbox.com", Pref: 1}},
outImplicitMX: true,
}, {
inDomainName: "test-cname.definbox.com",
outMXHosts: []*net.MX{{Host: "mxa.ninomail.com", Pref: 10}, {Host: "mxb.ninomail.com", Pref: 10}},
outImplicitMX: false,
}, {
inDomainName: "definbox.com",
outMXHosts: []*net.MX{{Host: "mxa.ninomail.com", Pref: 10}, {Host: "mxb.ninomail.com", Pref: 10}},
outImplicitMX: false,
}, {
desc: "If an MX host returned by the resolver contains non ASCII " +
"characters then it is silently dropped from the returned list",
inDomainName: "test-unicode.definbox.com",
outMXHosts: []*net.MX{{Host: "mxa.definbox.com", Pref: 1}, {Host: "mxb.definbox.com", Pref: 3}},
outImplicitMX: false,
}, {
desc: "Underscore is allowed in domain names",
inDomainName: "test-underscore.definbox.com",
outMXHosts: []*net.MX{{Host: "foo_bar.definbox.com", Pref: 1}},
outImplicitMX: false,
}, {
inDomainName: "test-яндекс.definbox.com",
outMXHosts: []*net.MX{{Host: "xn--test---mofb0ab4b8camvcmn8gxd.definbox.com", Pref: 10}},
outImplicitMX: false,
}, {
inDomainName: "xn--test--xweh4bya7b6j.definbox.com",
outMXHosts: []*net.MX{{Host: "xn--test---mofb0ab4b8camvcmn8gxd.definbox.com", Pref: 10}},
outImplicitMX: false,
}, {
inDomainName: "test-mx-ipv4.definbox.com",
outMXHosts: []*net.MX{{Host: "34.150.176.225", Pref: 10}},
outImplicitMX: false,
}, {
inDomainName: "test-mx-ipv6.definbox.com",
outMXHosts: []*net.MX{{Host: "::ffff:2296:b0e1", Pref: 10}},
outImplicitMX: false,
}} {
t.Run(tc.inDomainName, func(t *testing.T) {
defer mxresolv.SetDeterministicInTests()()

// When
ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second)
defer cancel()
mxRecords, implicitMX, err := mxresolv.LookupWithPref(ctx, tc.inDomainName)
// Then
assert.NoError(t, err)
assert.Equal(t, tc.outMXHosts, mxRecords)
assert.Equal(t, tc.outImplicitMX, implicitMX)
})
}
}

func TestLookup(t *testing.T) {
for _, tc := range []struct {
desc string
Expand Down Expand Up @@ -172,11 +242,11 @@ func TestLookup(t *testing.T) {
// When
ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second)
defer cancel()
mxHosts, explictMX, err := mxresolv.Lookup(ctx, tc.inDomainName)
mxHosts, implicitMX, err := mxresolv.Lookup(ctx, tc.inDomainName)
// Then
assert.NoError(t, err)
assert.Equal(t, tc.outMXHosts, mxHosts)
assert.Equal(t, tc.outImplicitMX, explictMX)
assert.Equal(t, tc.outImplicitMX, implicitMX)
})
}
}
Expand Down

0 comments on commit 354998d

Please sign in to comment.