Skip to content

Commit

Permalink
Merge pull request #188 from 56quarters/dns-client-cleanup
Browse files Browse the repository at this point in the history
Make DefaultDnsClient generic for easier testing
  • Loading branch information
56quarters authored Sep 7, 2024
2 parents efc9725 + d378c60 commit f074f77
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 81 deletions.
155 changes: 84 additions & 71 deletions mtop-client/src/dns/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,25 @@ pub trait DnsClient {
/// operation. This means that a single call to `resolve` make take longer
/// than the timeout since failed network operations are retried.
#[derive(Debug)]
pub struct DefaultDnsClient {
pub struct DefaultDnsClient<U = UdpConnectionFactory, T = TcpConnectionFactory>
where
U: ClientFactory<SocketAddr, UdpConnection> + Send + Sync + 'static,
T: ClientFactory<SocketAddr, TcpConnection> + Send + Sync + 'static,
{
config: DnsClientConfig,
server_idx: AtomicUsize,
udp_pool: ClientPool<SocketAddr, UdpClient, UdpFactory>,
tcp_pool: ClientPool<SocketAddr, TcpClient, TcpFactory>,
udp_pool: ClientPool<SocketAddr, UdpConnection, U>,
tcp_pool: ClientPool<SocketAddr, TcpConnection, T>,
}

impl DefaultDnsClient {
impl<U, T> DefaultDnsClient<U, T>
where
U: ClientFactory<SocketAddr, UdpConnection> + Send + Sync + 'static,
T: ClientFactory<SocketAddr, TcpConnection> + Send + Sync + 'static,
{
/// Create a new DnsClient that will resolve names using UDP or TCP connections
/// and behavior based on a resolv.conf configuration file.
pub fn new(config: DnsClientConfig) -> Self {
pub fn new(config: DnsClientConfig, udp_factory: U, tcp_factory: T) -> Self {
let udp_config = ClientPoolConfig {
name: "dns-udp".to_owned(),
max_idle: config.pool_max_idle,
Expand All @@ -105,10 +113,11 @@ impl DefaultDnsClient {
Self {
config,
server_idx: AtomicUsize::new(0),
udp_pool: ClientPool::new(udp_config, UdpFactory),
tcp_pool: ClientPool::new(tcp_config, TcpFactory),
udp_pool: ClientPool::new(udp_config, udp_factory),
tcp_pool: ClientPool::new(tcp_config, tcp_factory),
}
}

async fn exchange(&self, msg: &Message, attempt: usize) -> Result<Message, MtopError> {
let server = self.nameserver(attempt);

Expand Down Expand Up @@ -155,7 +164,11 @@ impl DefaultDnsClient {
}
}

impl DnsClient for DefaultDnsClient {
impl<U, T> DnsClient for DefaultDnsClient<U, T>
where
U: ClientFactory<SocketAddr, UdpConnection> + Send + Sync + 'static,
T: ClientFactory<SocketAddr, TcpConnection> + Send + Sync + 'static,
{
async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError> {
let full = name.to_fqdn();
let id = MessageId::random();
Expand All @@ -180,48 +193,45 @@ impl DnsClient for DefaultDnsClient {
}
}

/// Client for sending and receiving DNS messages over read and write streams,
/// usually a TCP connection. Messages are sent with a two byte prefix that
/// indicates the size of the message. Responses are expected to have the same
/// prefix. The message ID of responses is checked to ensure it matches the request
/// ID. If it does not, an error is returned.
struct TcpClient {
/// Connection for unconditionally sending and receiving DNS messages using TCP streams.
/// Messages are sent with a two byte prefix that indicates the size of the message.
/// Responses are expected to have the same prefix. The message ID of responses is
/// checked to ensure it matches the request ID. If it does not, an error is returned.
pub struct TcpConnection {
read: BufReader<Box<dyn AsyncRead + Send + Sync + Unpin>>,
write: BufWriter<Box<dyn AsyncWrite + Send + Sync + Unpin>>,
size: usize,
buffer: Vec<u8>,
}

impl TcpClient {
fn new<R, W>(read: R, write: W, size: usize) -> Self
impl TcpConnection {
pub fn new<R, W>(read: R, write: W, size: usize) -> Self
where
R: AsyncRead + Unpin + Sync + Send + 'static,
W: AsyncWrite + Unpin + Sync + Send + 'static,
{
Self {
read: BufReader::new(Box::new(read)),
write: BufWriter::new(Box::new(write)),
size,
buffer: Vec::with_capacity(size),
}
}

async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
let mut buf = Vec::with_capacity(self.size);

pub async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
// Write the message to a local buffer and then send it, prefixed
// with the size of the message.
msg.write_network_bytes(&mut buf)?;
self.write.write_u16(buf.len() as u16).await?;
self.write.write_all(&buf).await?;
msg.write_network_bytes(&mut self.buffer)?;
self.write.write_u16(self.buffer.len() as u16).await?;
self.write.write_all(&self.buffer).await?;
self.write.flush().await?;

// Read the prefixed size of the response in big-endian (network)
// order and then read exactly that many bytes into our buffer.
let sz = self.read.read_u16().await?;
buf.clear();
buf.resize(usize::from(sz), 0);
self.read.read_exact(&mut buf).await?;
self.buffer.clear();
self.buffer.resize(usize::from(sz), 0);
self.read.read_exact(&mut self.buffer).await?;

let mut cur = Cursor::new(buf);
let mut cur = Cursor::new(&self.buffer);
let res = Message::read_network_bytes(&mut cur)?;
if res.id() != msg.id() {
Err(MtopError::runtime(format!(
Expand All @@ -235,56 +245,59 @@ impl TcpClient {
}
}

impl fmt::Debug for TcpClient {
impl fmt::Debug for TcpConnection {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "TcpClient {{ read: ..., write: ..., size: {} }}", self.size)
write!(f, "TcpConnection {{ ... }}")
}
}

/// Client for sending and receiving DNS messages over read and write streams,
/// usually an adapter over a UDP socket. The message ID of responses is checked
/// to ensure it matches the request ID. If it does not, the response is discarded
/// and the client will wait for another response until it gets one with a matching
/// ID.
struct UdpClient {
/// Connection for unconditionally sending and receiving DNS messages using UDP packets.
/// The message ID of responses is checked to ensure it matches the request ID. If it
/// does not, the response is discarded and the client will wait for another response
/// until it gets one with a matching ID.
pub struct UdpConnection {
read: Box<dyn AsyncRead + Send + Sync + Unpin>,
write: Box<dyn AsyncWrite + Send + Sync + Unpin>,
size: usize,
buffer: Vec<u8>,
packet: usize,
}

impl UdpClient {
fn new<R, W>(read: R, write: W, size: usize) -> Self
impl UdpConnection {
pub fn new<R, W>(read: R, write: W, size: usize) -> Self
where
R: AsyncRead + Unpin + Sync + Send + 'static,
W: AsyncWrite + Unpin + Sync + Send + 'static,
{
Self {
read: Box::new(read),
write: Box::new(write),
size,
buffer: Vec::with_capacity(size),
packet: size,
}
}

async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
let mut buf = Vec::with_capacity(self.size);
msg.write_network_bytes(&mut buf)?;
pub async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
self.buffer.clear();
msg.write_network_bytes(&mut self.buffer)?;
// We expect this to be a datagram socket so we only do a single write.
let n = self.write.write(&buf).await?;
if n != buf.len() {
let n = self.write.write(&self.buffer).await?;
if n != self.buffer.len() {
return Err(MtopError::runtime(format!(
"short write to UDP socket. expected {}, got {}",
buf.len(),
self.buffer.len(),
n
)));
}
self.write.flush().await?;

buf.clear();
buf.resize(self.size, 0);
// Resize to our packet size since the .read() call will only read up to
// the size of the buffer at most.
self.buffer.clear();
self.buffer.resize(self.packet, 0);

loop {
let n = self.read.read(&mut buf).await?;
let cur = Cursor::new(&buf[0..n]);
let n = self.read.read(&mut self.buffer).await?;
let cur = Cursor::new(&self.buffer[0..n]);
let res = Message::read_network_bytes(cur)?;
if res.id() == msg.id() {
return Ok(res);
Expand All @@ -293,9 +306,9 @@ impl UdpClient {
}
}

impl fmt::Debug for UdpClient {
impl fmt::Debug for UdpConnection {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "UdpClient {{ read: ..., write: ..., size: {} }}", self.size)
write!(f, "UdpConnection {{ ... }}")
}
}

Expand Down Expand Up @@ -330,38 +343,38 @@ impl AsyncWrite for SocketAdapter {
}
}

/// Implementation of `ClientFactory` for creating concrete `UdpClient` instances
/// that use UDP sockets.
/// Implementation of `ClientFactory` for creating concrete `UdpConnection` instances
/// that use a UDP socket.
#[derive(Debug, Clone, Default)]
struct UdpFactory;
pub struct UdpConnectionFactory;

impl ClientFactory<SocketAddr, UdpClient> for UdpFactory {
async fn make(&self, address: &SocketAddr) -> Result<UdpClient, MtopError> {
impl ClientFactory<SocketAddr, UdpConnection> for UdpConnectionFactory {
async fn make(&self, address: &SocketAddr) -> Result<UdpConnection, MtopError> {
let local = if address.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
let sock = UdpSocket::bind(local).await?;
sock.connect(address).await?;

let adapter = SocketAdapter::new(sock);
let (read, write) = tokio::io::split(adapter);
Ok(UdpClient::new(read, write, DEFAULT_MESSAGE_BUFFER))
Ok(UdpConnection::new(read, write, DEFAULT_MESSAGE_BUFFER))
}
}

/// Implementation of `ClientFactory` for creating concrete `TcpClient` instances
/// that uses a split TCP socket.
/// Implementation of `ClientFactory` for creating concrete `TcpConnection` instances
/// that use a TCP socket.
#[derive(Debug, Clone, Default)]
struct TcpFactory;
pub struct TcpConnectionFactory;

impl ClientFactory<SocketAddr, TcpClient> for TcpFactory {
async fn make(&self, address: &SocketAddr) -> Result<TcpClient, MtopError> {
impl ClientFactory<SocketAddr, TcpConnection> for TcpConnectionFactory {
async fn make(&self, address: &SocketAddr) -> Result<TcpConnection, MtopError> {
let (read, write) = tcp_connect(address).await?;
Ok(TcpClient::new(read, write, DEFAULT_MESSAGE_BUFFER))
Ok(TcpConnection::new(read, write, DEFAULT_MESSAGE_BUFFER))
}
}

#[cfg(test)]
mod test {
use super::{TcpClient, UdpClient};
use super::{TcpConnection, UdpConnection};
use crate::core::ErrorKind;
use crate::dns::core::{RecordClass, RecordType};
use crate::dns::message::{Flags, Message, MessageId, Question, Record};
Expand Down Expand Up @@ -432,7 +445,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = TcpClient::new(read, write, 512);
let mut client = TcpConnection::new(read, write, 512);
let res = client.exchange(&message).await;
let err = res.unwrap_err();

Expand All @@ -450,7 +463,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = TcpClient::new(read, write, 512);
let mut client = TcpConnection::new(read, write, 512);
let res = client.exchange(&message).await;
let err = res.unwrap_err();

Expand All @@ -466,7 +479,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = TcpClient::new(read, write, 512);
let mut client = TcpConnection::new(read, write, 512);
let res = client.exchange(&message).await;
let err = res.unwrap_err();

Expand All @@ -482,7 +495,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = TcpClient::new(read, write, 512);
let mut client = TcpConnection::new(read, write, 512);
let res = client.exchange(&message).await.unwrap();

assert_eq!(message.id(), res.id());
Expand Down Expand Up @@ -542,7 +555,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = UdpClient::new(read, write, 512);
let mut client = UdpConnection::new(read, write, 512);
let res = client.exchange(&message).await.unwrap();

assert_eq!(message.id(), res.id());
Expand Down Expand Up @@ -570,7 +583,7 @@ mod test {
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);

let mut client = UdpClient::new(read, write, 512);
let mut client = UdpConnection::new(read, write, 512);
let res = client.exchange(&message).await.unwrap();

assert_eq!(message.id(), res.id());
Expand Down
8 changes: 4 additions & 4 deletions mtop-client/src/dns/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ impl Message {
}

fn header(&self) -> Header {
assert!(self.questions.len() < u16::MAX as usize);
assert!(self.answers.len() < u16::MAX as usize);
assert!(self.authority.len() < u16::MAX as usize);
assert!(self.extra.len() < u16::MAX as usize);
assert!(self.questions.len() < usize::from(u16::MAX));
assert!(self.answers.len() < usize::from(u16::MAX));
assert!(self.authority.len() < usize::from(u16::MAX));
assert!(self.extra.len() < usize::from(u16::MAX));

Header {
id: self.id,
Expand Down
5 changes: 4 additions & 1 deletion mtop-client/src/dns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ mod name;
mod rdata;
mod resolv;

pub use crate::dns::client::{DefaultDnsClient, DnsClient, DnsClientConfig};
pub use crate::dns::client::{
DefaultDnsClient, DnsClient, DnsClientConfig, TcpConnection, TcpConnectionFactory, UdpConnection,
UdpConnectionFactory,
};
pub use crate::dns::core::{RecordClass, RecordType};
pub use crate::dns::message::{Flags, Message, MessageId, Operation, Question, Record, ResponseCode};
pub use crate::dns::name::Name;
Expand Down
10 changes: 6 additions & 4 deletions mtop-client/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ use std::future::Future;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio::runtime::Handle;
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;

const RESPONSE_VERSION: &str = "VERSION 1.6.22\r\n";
Expand Down Expand Up @@ -68,7 +67,10 @@ where
})
}

async fn handle_client_connection(stream: TlsStream<TcpStream>) {
async fn handle_client_connection<S>(stream: S)
where
S: AsyncRead + AsyncWrite,
{
let (read, write) = tokio::io::split(stream);
let mut read = BufReader::new(read).lines();
let mut write = BufWriter::new(write);
Expand Down
4 changes: 3 additions & 1 deletion mtop/src/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ where
client_config.timeout = t;
}

DefaultDnsClient::new(client_config)
// Use default instances of the UDP and TCP connection factories, alternate
// implementations are only useful for unit testing.
DefaultDnsClient::new(client_config, Default::default(), Default::default())
}

async fn load_config<P>(resolv: P) -> Result<ResolvConf, MtopError>
Expand Down

0 comments on commit f074f77

Please sign in to comment.