Skip to content

Commit

Permalink
Merge pull request #185 from 56quarters/dupes
Browse files Browse the repository at this point in the history
Deduplicate servers by ID when doing service discovery resolution
  • Loading branch information
56quarters authored Aug 21, 2024
2 parents 17d66e8 + 9e1bb52 commit 6ca6368
Showing 1 changed file with 58 additions and 10 deletions.
68 changes: 58 additions & 10 deletions mtop-client/src/discovery.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::core::MtopError;
use crate::dns::{DefaultDnsClient, DnsClient, Name, Record, RecordClass, RecordData, RecordType};
use crate::dns::{DefaultDnsClient, DnsClient, Message, Name, RecordClass, RecordData, RecordType};
use rustls_pki_types::ServerName;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::fmt;
use std::net::{IpAddr, SocketAddr};

Expand Down Expand Up @@ -152,7 +153,7 @@ where
let name = host.parse()?;

let res = self.client.resolve(name, RecordType::SRV, RecordClass::INET).await?;
Ok(Self::servers_from_answers(port, &server_name, res.answers()))
Ok(Self::servers_from_answers(port, &server_name, &res))
}

async fn resolve_a_aaaa(&self, name: &str) -> Result<Vec<Server>, MtopError> {
Expand All @@ -161,10 +162,10 @@ where
let name: Name = host.parse()?;

let res = self.client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?;
let mut out = Self::servers_from_answers(port, &server_name, res.answers());
let mut out = Self::servers_from_answers(port, &server_name, &res);

let res = self.client.resolve(name, RecordType::AAAA, RecordClass::INET).await?;
out.extend(Self::servers_from_answers(port, &server_name, res.answers()));
out.extend(Self::servers_from_answers(port, &server_name, &res));

Ok(out)
}
Expand All @@ -181,10 +182,10 @@ where
Ok(vec![Server::new(ServerID::from((host, port)), server_name)])
}

fn servers_from_answers(port: u16, server_name: &ServerName, answers: &[Record]) -> Vec<Server> {
let mut out = Vec::new();
fn servers_from_answers(port: u16, server_name: &ServerName, message: &Message) -> Vec<Server> {
let mut ids = HashSet::new();

for answer in answers {
for answer in message.answers() {
let server_id = match answer.rdata() {
RecordData::A(data) => ServerID::from(SocketAddr::new(IpAddr::V4(data.addr()), port)),
RecordData::AAAA(data) => ServerID::from(SocketAddr::new(IpAddr::V6(data.addr()), port)),
Expand All @@ -195,11 +196,14 @@ where
}
};

let server = Server::new(server_id, server_name.to_owned());
out.push(server);
// Insert IDs into a HashSet to deduplicate them. We can potentially end up with duplicates
// when a SRV query returns multiple answers per hostname (such as when each host has more
// than a single port). Because we ignore the port number from the SRV answer we need to
// deduplicate here.
ids.insert(server_id);
}

out
ids.into_iter().map(|id| Server::new(id, server_name.to_owned())).collect()
}

fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> {
Expand Down Expand Up @@ -422,6 +426,48 @@ mod test {
assert!(ids.contains(&id2), "expected {:?} to contain {:?}", ids, id2);
}

#[tokio::test]
async fn test_dns_client_resolve_srv_dupes() {
let response = response_with_answers(
RecordType::SRV,
vec![
Record::new(
Name::from_str("_cache.example.com.").unwrap(),
RecordType::SRV,
RecordClass::INET,
300,
RecordData::SRV(RecordDataSRV::new(
100,
10,
11211,
Name::from_str("cache01.example.com.").unwrap(),
)),
),
Record::new(
Name::from_str("_cache.example.com.").unwrap(),
RecordType::SRV,
RecordClass::INET,
300,
RecordData::SRV(RecordDataSRV::new(
100,
10,
9105,
Name::from_str("cache01.example.com.").unwrap(),
)),
),
],
);

let client = MockDnsClient::new(vec![response]);
let discovery = Discovery::new(client);
let servers = discovery.resolve_by_proto("dnssrv+_cache.example.com:11211").await.unwrap();
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id = ServerID::from(("cache01.example.com.", 11211));

assert_eq!(ids, vec![id]);
}

#[tokio::test]
async fn test_dns_client_resolve_socket_addr() {
let name = "127.0.0.2:11211";
Expand All @@ -433,6 +479,7 @@ mod test {
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id = ServerID::from(addr);

assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id);
}

Expand All @@ -446,6 +493,7 @@ mod test {
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id = ServerID::from(("localhost", 11211));

assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id);
}
}

0 comments on commit 6ca6368

Please sign in to comment.