Skip to content

Commit

Permalink
Merge pull request #189 from 56quarters/tcp-dns
Browse files Browse the repository at this point in the history
dns: Fix an issue where a buffer was incorrectly reused
  • Loading branch information
56quarters authored Oct 26, 2024
2 parents f074f77 + b6563b7 commit 0201bfd
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 2 deletions.
71 changes: 71 additions & 0 deletions mtop-client/src/dns/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ pub struct TcpConnection {
read: BufReader<Box<dyn AsyncRead + Send + Sync + Unpin>>,
write: BufWriter<Box<dyn AsyncWrite + Send + Sync + Unpin>>,
buffer: Vec<u8>,
bytes_read: AtomicUsize,
bytes_written: AtomicUsize,
}

impl TcpConnection {
Expand All @@ -213,24 +215,33 @@ impl TcpConnection {
read: BufReader::new(Box::new(read)),
write: BufWriter::new(Box::new(write)),
buffer: Vec::with_capacity(size),
bytes_read: AtomicUsize::new(0),
bytes_written: AtomicUsize::new(0),
}
}

pub async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
// Write the message to a local buffer and then send it, prefixed
// with the size of the message.
self.buffer.clear();
msg.write_network_bytes(&mut self.buffer)?;
self.write.write_u16(self.buffer.len() as u16).await?;
self.write.write_all(&self.buffer).await?;
self.write.flush().await?;

// Increment total bytes written including the request size prefix.
self.bytes_written.fetch_add(self.buffer.len() + 2, Ordering::Relaxed);

// Read the prefixed size of the response in big-endian (network)
// order and then read exactly that many bytes into our buffer.
let sz = self.read.read_u16().await?;
self.buffer.clear();
self.buffer.resize(usize::from(sz), 0);
self.read.read_exact(&mut self.buffer).await?;

// Increment total bytes read including the response size prefix.
self.bytes_read.fetch_add(self.buffer.len() + 2, Ordering::Relaxed);

let mut cur = Cursor::new(&self.buffer);
let res = Message::read_network_bytes(&mut cur)?;
if res.id() != msg.id() {
Expand All @@ -243,6 +254,14 @@ impl TcpConnection {
Ok(res)
}
}

pub fn bytes_written(&self) -> usize {
self.bytes_written.load(Ordering::Relaxed)
}

pub fn bytes_read(&self) -> usize {
self.bytes_read.load(Ordering::Relaxed)
}
}

impl fmt::Debug for TcpConnection {
Expand Down Expand Up @@ -486,6 +505,58 @@ mod test {
assert_eq!(ErrorKind::Runtime, err.kind());
}

#[tokio::test]
async fn test_tcp_client_multiple_messages() {
let write = Vec::new();
let read = {
let response1 = new_message_bytes(123, true);
let response2 = new_message_bytes(456, true);
let mut bytes = Vec::new();
bytes.extend(response1);
bytes.extend(response2);
Cursor::new(bytes)
};

let mut client = TcpConnection::new(read, write, 512);

let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
let message1 =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question.clone());
let message2 =
Message::new(MessageId::from(456), Flags::default().set_recursion_desired()).add_question(question.clone());

let res1 = client.exchange(&message1).await.unwrap();
assert_eq!(message1.id(), res1.id());
assert_eq!(message1.questions()[0], res1.questions()[0]);
assert_eq!(
Record::new(
Name::from_str("example.com.").unwrap(),
RecordType::A,
RecordClass::INET,
60,
RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))),
),
res1.answers()[0]
);

let res2 = client.exchange(&message2).await.unwrap();
assert_eq!(message2.id(), res2.id());
assert_eq!(message2.questions()[0], res2.questions()[0]);
assert_eq!(
Record::new(
Name::from_str("example.com.").unwrap(),
RecordType::A,
RecordClass::INET,
60,
RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))),
),
res2.answers()[0]
);

let expected_bytes = message1.size() + message2.size() + 2 + 2;
assert_eq!(expected_bytes, client.bytes_written());
}

#[tokio::test]
async fn test_tcp_client_success() {
let write = Vec::new();
Expand Down
12 changes: 12 additions & 0 deletions mtop-client/src/dns/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ pub enum RecordType {
Unknown(u16),
}

impl RecordType {
pub fn size(&self) -> usize {
2
}
}

impl From<u16> for RecordType {
fn from(value: u16) -> Self {
match value {
Expand Down Expand Up @@ -95,6 +101,12 @@ pub enum RecordClass {
Unknown(u16),
}

impl RecordClass {
pub fn size(&self) -> usize {
2
}
}

impl From<u16> for RecordClass {
fn from(value: u16) -> Self {
match value {
Expand Down
42 changes: 40 additions & 2 deletions mtop-client/src/dns/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ impl MessageId {
pub fn random() -> Self {
Self(rand::random())
}

pub fn size(&self) -> usize {
2
}
}

impl From<u16> for MessageId {
Expand Down Expand Up @@ -57,6 +61,16 @@ impl Message {
}
}

pub fn size(&self) -> usize {
self.id.size()
+ self.flags.size()
+ (2 * 4) // lengths of questions, answers, authority, extra
+ self.questions.iter().map(|q| q.size()).sum::<usize>()
+ self.answers.iter().map(|r| r.size()).sum::<usize>()
+ self.authority.iter().map(|r| r.size()).sum::<usize>()
+ self.extra.iter().map(|r| r.size()).sum::<usize>()
}

pub fn id(&self) -> MessageId {
self.id
}
Expand Down Expand Up @@ -246,6 +260,10 @@ impl Flags {
const OFFSET_RA: usize = 7;
const OFFSET_RC: usize = 0;

pub fn size(&self) -> usize {
2
}

pub fn is_query(&self) -> bool {
!(self.0 & Self::MASK_QR) > 0
}
Expand Down Expand Up @@ -445,6 +463,10 @@ impl Question {
}
}

pub fn size(&self) -> usize {
self.name.size() + self.qtype.size() + self.qclass.size()
}

pub fn set_qclass(mut self, qclass: RecordClass) -> Self {
self.qclass = qclass;
self
Expand Down Expand Up @@ -502,6 +524,15 @@ impl Record {
}
}

pub fn size(&self) -> usize {
self.name.size()
+ self.rtype.size()
+ self.rclass.size()
+ 4 // ttl
+ 2 // rdata length
+ self.rdata.size()
}

pub fn name(&self) -> &Name {
&self.name
}
Expand Down Expand Up @@ -875,6 +906,7 @@ mod test {
#[test]
fn test_question_write_network_bytes() {
let q = Question::new(Name::from_str("example.com.").unwrap(), RecordType::AAAA);
let size = q.size();
let mut cur = Cursor::new(Vec::new());
q.write_network_bytes(&mut cur).unwrap();
let buf = cur.into_inner();
Expand All @@ -891,6 +923,7 @@ mod test {
],
buf,
);
assert_eq!(size, buf.len());
}

#[rustfmt::skip]
Expand All @@ -906,10 +939,12 @@ mod test {
0, 1, // INET class
]);

let size = cur.get_ref().len();
let q = Question::read_network_bytes(cur).unwrap();
assert_eq!("example.com.", q.name().to_string());
assert_eq!(RecordType::AAAA, q.qtype());
assert_eq!(RecordClass::INET, q.qclass());
assert_eq!(size, q.size());
}

#[rustfmt::skip]
Expand All @@ -922,6 +957,7 @@ mod test {
300,
RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))),
);
let size = rr.size();
let mut cur = Cursor::new(Vec::new());
rr.write_network_bytes(&mut cur).unwrap();
let buf = cur.into_inner();
Expand All @@ -942,7 +978,8 @@ mod test {
127, 0, 0, 100, // rdata, A address
],
buf,
)
);
assert_eq!(size, buf.len());
}

#[rustfmt::skip]
Expand All @@ -963,16 +1000,17 @@ mod test {
127, 0, 0, 100, // rdata, A address
]);

let size = cur.get_ref().len();
let rr = Record::read_network_bytes(cur).unwrap();
assert_eq!("www.example.com.", rr.name().to_string());
assert_eq!(RecordType::A, rr.rtype());
assert_eq!(RecordClass::INET, rr.rclass());
assert_eq!(300, rr.ttl());

if let RecordData::A(rd) = rr.rdata() {
assert_eq!(Ipv4Addr::new(127, 0, 0, 100), rd.addr());
} else {
panic!("unexpected rdata type: {:?}", rr.rdata());
}
assert_eq!(size, rr.size());
}
}

0 comments on commit 0201bfd

Please sign in to comment.