Skip to content

Commit

Permalink
Make DefaultDnsClient generic for easier testing
Browse files Browse the repository at this point in the history
Make the UDP and TCP client factories generic parameters for the DNS client
to allow them to be replaced for unit testing. This change also renames the
UDP and TCP "clients" to "connections" to emphasize that they are low level
and just send messages back and forth.
  • Loading branch information
56quarters committed Sep 2, 2024
1 parent efc9725 commit d378c60
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 81 deletions.
155 changes: 84 additions & 71 deletions mtop-client/src/dns/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,25 @@ pub trait DnsClient {
/// operation. This means that a single call to `resolve` make take longer
/// than the timeout since failed network operations are retried.
#[derive(Debug)]
pub struct DefaultDnsClient {
pub struct DefaultDnsClient<U = UdpConnectionFactory, T = TcpConnectionFactory>
where
U: ClientFactory<SocketAddr, UdpConnection> + Send + Sync + 'static,
T: ClientFactory<SocketAddr, TcpConnection> + Send + Sync + 'static,
{
config: DnsClientConfig,
server_idx: AtomicUsize,
udp_pool: ClientPool<SocketAddr, UdpClient, UdpFactory>,
tcp_pool: ClientPool<SocketAddr, TcpClient, TcpFactory>,
udp_pool: ClientPool<SocketAddr, UdpConnection, U>,
tcp_pool: ClientPool<SocketAddr, TcpConnection, T>,
}

impl DefaultDnsClient {
impl<U, T> DefaultDnsClient<U, T>
where
U: ClientFactory<SocketAddr, UdpConnection> + Send + Sync + 'static,
T: ClientFactory<SocketAddr, TcpConnection> + Send + Sync + 'static,
{
/// Create a new DnsClient that will resolve names using UDP or TCP connections
/// and behavior based on a resolv.conf configuration file.
pub fn new(config: DnsClientConfig) -> Self {
pub fn new(config: DnsClientConfig, udp_factory: U, tcp_factory: T) -> Self {
let udp_config = ClientPoolConfig {
name: "dns-udp".to_owned(),
max_idle: config.pool_max_idle,
Expand All @@ -105,10 +113,11 @@ impl DefaultDnsClient {
Self {
config,
server_idx: AtomicUsize::new(0),
udp_pool: ClientPool::new(udp_config, UdpFactory),
tcp_pool: ClientPool::new(tcp_config, TcpFactory),
udp_pool: ClientPool::new(udp_config, udp_factory),
tcp_pool: ClientPool::new(tcp_config, tcp_factory),
}
}

async fn exchange(&self, msg: &Message, attempt: usize) -> Result<Message, MtopError> {
let server = self.nameserver(attempt);

Expand Down Expand Up @@ -155,7 +164,11 @@ impl DefaultDnsClient {
}
}

impl DnsClient for DefaultDnsClient {
impl<U, T> DnsClient for DefaultDnsClient<U, T>
where
U: ClientFactory<SocketAddr, UdpConnection> + Send + Sync + 'static,
T: ClientFactory<SocketAddr, TcpConnection> + Send + Sync + 'static,
{
async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError> {
let full = name.to_fqdn();
let id = MessageId::random();
Expand All @@ -180,48 +193,45 @@ impl DnsClient for DefaultDnsClient {
}
}

/// Client for sending and receiving DNS messages over read and write streams,
/// usually a TCP connection. Messages are sent with a two byte prefix that
/// indicates the size of the message. Responses are expected to have the same
/// prefix. The message ID of responses is checked to ensure it matches the request
/// ID. If it does not, an error is returned.
struct TcpClient {
/// Connection for unconditionally sending and receiving DNS messages using TCP streams.
/// Messages are sent with a two byte prefix that indicates the size of the message.
/// Responses are expected to have the same prefix. The message ID of responses is
/// checked to ensure it matches the request ID. If it does not, an error is returned.
pub struct TcpConnection {
read: BufReader<Box<dyn AsyncRead + Send + Sync + Unpin>>,
write: BufWriter<Box<dyn AsyncWrite + Send + Sync + Unpin>>,
size: usize,
buffer: Vec<u8>,
}

impl TcpClient {
fn new<R, W>(read: R, write: W, size: usize) -> Self
impl TcpConnection {
pub fn new<R, W>(read: R, write: W, size: usize) -> Self
where
R: AsyncRead + Unpin + Sync + Send + 'static,
W: AsyncWrite + Unpin + Sync + Send + 'static,
{
Self {
read: BufReader::new(Box::new(read)),
write: BufWriter::new(Box::new(write)),
size,
buffer: Vec::with_capacity(size),
}
}

async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
let mut buf = Vec::with_capacity(self.size);

pub async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
// Write the message to a local buffer and then send it, prefixed
// with the size of the message.
msg.write_network_bytes(&mut buf)?;
self.write.write_u16(buf.len() as u16).await?;
self.write.write_all(&buf).await?;
msg.write_network_bytes(&mut self.buffer)?;
self.write.write_u16(self.buffer.len() as u16).await?;
self.write.write_all(&self.buffer).await?;
self.write.flush().await?;

// Read the prefixed size of the response in big-endian (network)
// order and then read exactly that many bytes into our buffer.
let sz = self.read.read_u16().await?;
buf.clear();
buf.resize(usize::from(sz), 0);
self.read.read_exact(&mut buf).await?;
self.buffer.clear();
self.buffer.resize(usize::from(sz), 0);
self.read.read_exact(&mut self.buffer).await?;

let mut cur = Cursor::new(buf);
let mut cur = Cursor::new(&self.buffer);
let res = Message::read_network_bytes(&mut cur)?;
if res.id() != msg.id() {
Err(MtopError::runtime(format!(
Expand All @@ -235,56 +245,59 @@ impl TcpClient {
}
}

impl fmt::Debug for TcpClient {
impl fmt::Debug for TcpConnection {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "TcpClient {{ read: ..., write: ..., size: {} }}", self.size)
write!(f, "TcpConnection {{ ... }}")
}
}

/// Client for sending and receiving DNS messages over read and write streams,
/// usually an adapter over a UDP socket. The message ID of responses is checked
/// to ensure it matches the request ID. If it does not, the response is discarded
/// and the client will wait for another response until it gets one with a matching
/// ID.
struct UdpClient {
/// Connection for unconditionally sending and receiving DNS messages using UDP packets.
/// The message ID of responses is checked to ensure it matches the request ID. If it
/// does not, the response is discarded and the client will wait for another response
/// until it gets one with a matching ID.
pub struct UdpConnection {
read: Box<dyn AsyncRead + Send + Sync + Unpin>,
write: Box<dyn AsyncWrite + Send + Sync + Unpin>,
size: usize,
buffer: Vec<u8>,
packet: usize,
}

impl UdpClient {
fn new<R, W>(read: R, write: W, size: usize) -> Self
impl UdpConnection {
pub fn new<R, W>(read: R, write: W, size: usize) -> Self
where
R: AsyncRead + Unpin + Sync + Send + 'static,
W: AsyncWrite + Unpin + Sync + Send + 'static,
{
Self {
read: Box::new(read),
write: Box::new(write),
size,
buffer: Vec::with_capacity(size),
packet: size,
}
}

async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
let mut buf = Vec::with_capacity(self.size);
msg.write_network_bytes(&mut buf)?;
pub async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
self.buffer.clear();
msg.write_network_bytes(&mut self.buffer)?;
// We expect this to be a datagram socket so we only do a single write.
let n = self.write.write(&buf).await?;
if n != buf.len() {
let n = self.write.write(&self.buffer).await?;
if n != self.buffer.len() {
return Err(MtopError::runtime(format!(
"short write to UDP socket. expected {}, got {}",
buf.len(),
self.buffer.len(),
n
)));
}
self.write.flush().await?;

buf.clear();
buf.resize(self.size, 0);
// Resize to our packet size since the .read() call will only read up to
// the size of the buffer at most.
self.buffer.clear();
self.buffer.resize(self.packet, 0);

loop {
let n = self.read.read(&mut buf).await?;
let cur = Cursor::new(&buf[0..n]);
let n = self.read.read(&mut self.buffer).await?;
let cur = Cursor::new(&self.buffer[0..n]);
let res = Message::read_network_bytes(cur)?;
if res.id() == msg.id() {
return Ok(res);
Expand All @@ -293,9 +306,9 @@ impl UdpClient {
}
}

impl fmt::Debug for UdpClient {
impl fmt::Debug for UdpConnection {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "UdpClient {{ read: ..., write: ..., size: {} }}", self.size)
write!(f, "UdpConnection {{ ... }}")
}
}

Expand Down Expand Up @@ -330,38 +343,38 @@ impl AsyncWrite for SocketAdapter {
}
}

/// Implementation of `ClientFactory` for creating concrete `UdpClient` instances
/// that use UDP sockets.
/// Implementation of `ClientFactory` for creating concrete `UdpConnection` instances
/// that use a UDP socket.
#[derive(Debug, Clone, Default)]
struct UdpFactory;
pub struct UdpConnectionFactory;

impl ClientFactory<SocketAddr, UdpClient> for UdpFactory {
async fn make(&self, address: &SocketAddr) -> Result<UdpClient, MtopError> {
impl ClientFactory<SocketAddr, UdpConnection> for UdpConnectionFactory {
async fn make(&self, address: &SocketAddr) -> Result<UdpConnection, MtopError> {
let local = if address.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
let sock = UdpSocket::bind(local).await?;
sock.connect(address).await?;

let adapter = SocketAdapter::new(sock);
let (read, write) = tokio::io::split(adapter);
Ok(UdpClient::new(read, write, DEFAULT_MESSAGE_BUFFER))
Ok(UdpConnection::new(read, write, DEFAULT_MESSAGE_BUFFER))
}
}

/// Implementation of `ClientFactory` for creating concrete `TcpClient` instances
/// that uses a split TCP socket.
/// Implementation of `ClientFactory` for creating concrete `TcpConnection` instances
/// that use a TCP socket.
#[derive(Debug, Clone, Default)]
struct TcpFactory;
pub struct TcpConnectionFactory;

impl ClientFactory<SocketAddr, TcpClient> for TcpFactory {
async fn make(&self, address: &SocketAddr) -> Result<TcpClient, MtopError> {
impl ClientFactory<SocketAddr, TcpConnection> for TcpConnectionFactory {
async fn make(&self, address: &SocketAddr) -> Result<TcpConnection, MtopError> {
let (read, write) = tcp_connect(address).await?;
Ok(TcpClient::new(read, write, DEFAULT_MESSAGE_BUFFER))
Ok(TcpConnection::new(read, write, DEFAULT_MESSAGE_BUFFER))
}
}

#[cfg(test)]
mod test {
use super::{TcpClient, UdpClient};
use super::{TcpConnection, UdpConnection};
use crate::core::ErrorKind;
use crate::dns::core::{RecordClass, RecordType};
use crate::dns::message::{Flags, Message, MessageId, Question, Record};
Expand Down Expand Up @@ -432,7 +445,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = TcpClient::new(read, write, 512);
let mut client = TcpConnection::new(read, write, 512);
let res = client.exchange(&message).await;
let err = res.unwrap_err();

Expand All @@ -450,7 +463,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = TcpClient::new(read, write, 512);
let mut client = TcpConnection::new(read, write, 512);
let res = client.exchange(&message).await;
let err = res.unwrap_err();

Expand All @@ -466,7 +479,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = TcpClient::new(read, write, 512);
let mut client = TcpConnection::new(read, write, 512);
let res = client.exchange(&message).await;
let err = res.unwrap_err();

Expand All @@ -482,7 +495,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = TcpClient::new(read, write, 512);
let mut client = TcpConnection::new(read, write, 512);
let res = client.exchange(&message).await.unwrap();

assert_eq!(message.id(), res.id());
Expand Down Expand Up @@ -542,7 +555,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = UdpClient::new(read, write, 512);
let mut client = UdpConnection::new(read, write, 512);
let res = client.exchange(&message).await.unwrap();

assert_eq!(message.id(), res.id());
Expand Down Expand Up @@ -570,7 +583,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = UdpClient::new(read, write, 512);
let mut client = UdpConnection::new(read, write, 512);
let res = client.exchange(&message).await.unwrap();

assert_eq!(message.id(), res.id());
Expand Down
8 changes: 4 additions & 4 deletions mtop-client/src/dns/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ impl Message {
}

fn header(&self) -> Header {
assert!(self.questions.len() < u16::MAX as usize);
assert!(self.answers.len() < u16::MAX as usize);
assert!(self.authority.len() < u16::MAX as usize);
assert!(self.extra.len() < u16::MAX as usize);
assert!(self.questions.len() < usize::from(u16::MAX));
assert!(self.answers.len() < usize::from(u16::MAX));
assert!(self.authority.len() < usize::from(u16::MAX));
assert!(self.extra.len() < usize::from(u16::MAX));

Header {
id: self.id,
Expand Down
5 changes: 4 additions & 1 deletion mtop-client/src/dns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ mod name;
mod rdata;
mod resolv;

pub use crate::dns::client::{DefaultDnsClient, DnsClient, DnsClientConfig};
pub use crate::dns::client::{
DefaultDnsClient, DnsClient, DnsClientConfig, TcpConnection, TcpConnectionFactory, UdpConnection,
UdpConnectionFactory,
};
pub use crate::dns::core::{RecordClass, RecordType};
pub use crate::dns::message::{Flags, Message, MessageId, Operation, Question, Record, ResponseCode};
pub use crate::dns::name::Name;
Expand Down
10 changes: 6 additions & 4 deletions mtop-client/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ use std::future::Future;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio::runtime::Handle;
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;

const RESPONSE_VERSION: &str = "VERSION 1.6.22\r\n";
Expand Down Expand Up @@ -68,7 +67,10 @@ where
})
}

async fn handle_client_connection(stream: TlsStream<TcpStream>) {
async fn handle_client_connection<S>(stream: S)
where
S: AsyncRead + AsyncWrite,
{
let (read, write) = tokio::io::split(stream);
let mut read = BufReader::new(read).lines();
let mut write = BufWriter::new(write);
Expand Down
4 changes: 3 additions & 1 deletion mtop/src/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ where
client_config.timeout = t;
}

DefaultDnsClient::new(client_config)
// Use default instances of the UDP and TCP connection factories, alternate
// implementations are only useful for unit testing.
DefaultDnsClient::new(client_config, Default::default(), Default::default())
}

async fn load_config<P>(resolv: P) -> Result<ResolvConf, MtopError>
Expand Down

0 comments on commit d378c60

Please sign in to comment.