Skip to content

Commit

Permalink
Merge pull request #149 from 56quarters/ping
Browse files Browse the repository at this point in the history
Add a `ping` command to the `dns` binary
  • Loading branch information
56quarters authored Jun 24, 2024
2 parents edd90cf + 068162f commit 79319dc
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 26 deletions.
10 changes: 5 additions & 5 deletions mtop-client/src/dns/name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ impl FromStr for Name {
}

if s.len() > Self::MAX_LENGTH {
return Err(MtopError::runtime(format!(
return Err(MtopError::configuration(format!(
"Name too long; max {} bytes, got {}",
Self::MAX_LENGTH,
s
Expand All @@ -240,7 +240,7 @@ impl FromStr for Name {
for label in s.trim_end_matches('.').split('.') {
let len = label.len();
if len > Self::MAX_LABEL_LENGTH {
return Err(MtopError::runtime(format!(
return Err(MtopError::configuration(format!(
"Name label too long; max {} bytes, got {}",
Self::MAX_LABEL_LENGTH,
label
Expand All @@ -251,17 +251,17 @@ impl FromStr for Name {

for (i, c) in label.char_indices() {
if i == 0 && !c.is_ascii_alphanumeric() && c != '_' {
return Err(MtopError::runtime(format!(
return Err(MtopError::configuration(format!(
"Name label must begin with ASCII letter, number, or underscore; got {}",
label
)));
} else if i == len - 1 && !c.is_ascii_alphanumeric() {
return Err(MtopError::runtime(format!(
return Err(MtopError::configuration(format!(
"Name label must end with ASCII letter or number; got {}",
label
)));
} else if !c.is_ascii_alphanumeric() && c != '-' && c != '_' {
return Err(MtopError::runtime(format!(
return Err(MtopError::configuration(format!(
"Name label must be ASCII letter, number, hyphen, or underscore; got {}",
label
)));
Expand Down
130 changes: 109 additions & 21 deletions mtop/src/bin/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@ use std::fmt::Write;
use std::io::Cursor;
use std::path::PathBuf;
use std::process::ExitCode;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::{task, time};
use tracing::{Instrument, Level};

const DEFAULT_LOG_LEVEL: Level = Level::INFO;
const DEFAULT_RECORD_TYPE: RecordType = RecordType::A;
const DEFAULT_RECORD_CLASS: RecordClass = RecordClass::INET;
const DEFAULT_PING_INTERVAL_SECS: f64 = 1.0;
const MIN_PING_INTERVAL_SECS: f64 = 0.01;

/// dns: Make DNS queries or read/write binary format DNS messages
#[derive(Debug, Parser)]
Expand All @@ -33,11 +38,49 @@ struct DnsConfig {

#[derive(Debug, Subcommand)]
enum Action {
Ping(PingCommand),
Query(QueryCommand),
Read(ReadCommand),
Write(WriteCommand),
}

/// Repeatedly perform a DNS query and display the time taken as ping-like text output.
#[derive(Debug, Args)]
struct PingCommand {
/// How often to run queries, in seconds. Fractional seconds are allowed.
#[arg(long, value_parser = parse_interval, default_value_t = DEFAULT_PING_INTERVAL_SECS)]
interval_secs: f64,

/// Stop after performing `count` queries. Default is to run until interrupted.
#[arg(long, default_value_t = 0)]
count: u64,

/// Path to resolv.conf file for loading DNS configuration information. If this file
/// can't be loaded, default values for DNS configuration are used instead.
#[arg(long, default_value = default_resolv_conf().into_os_string(), value_hint = ValueHint::FilePath)]
resolv_conf: PathBuf,

/// Type of record to request. Supported: A, AAAA, CNAME, NS, SOA, SRV, TXT.
#[arg(long, default_value_t = DEFAULT_RECORD_TYPE)]
rtype: RecordType,

/// Class of record to request. Supported: INET, CHAOS, HESIOD, NONE, ANY.
#[arg(long, default_value_t = DEFAULT_RECORD_CLASS)]
rclass: RecordClass,

/// Domain name to lookup.
#[arg(required = true)]
name: Name,
}

fn parse_interval(s: &str) -> Result<f64, String> {
match s.parse() {
Ok(v) if v >= MIN_PING_INTERVAL_SECS => Ok(v),
Ok(_) => Err(format!("must be at least {}", MIN_PING_INTERVAL_SECS)),
Err(e) => Err(e.to_string()),
}
}

/// Perform a DNS query and display the result as dig-like text output.
#[derive(Debug, Args)]
struct QueryCommand {
Expand All @@ -62,7 +105,7 @@ struct QueryCommand {

/// Domain name to lookup.
#[arg(required = true)]
name: String,
name: Name,
}

fn default_resolv_conf() -> PathBuf {
Expand All @@ -86,7 +129,7 @@ struct WriteCommand {

/// Domain name to lookup.
#[arg(required = true)]
name: String,
name: Name,
}

#[tokio::main]
Expand All @@ -99,6 +142,7 @@ async fn main() -> ExitCode {

let profiling = profile::Writer::default();
let code = match &opts.mode {
Action::Ping(cmd) => run_ping(cmd).await,
Action::Query(cmd) => run_query(cmd).await,
Action::Read(cmd) => run_read(cmd).await,
Action::Write(cmd) => run_write(cmd).await,
Expand All @@ -111,26 +155,78 @@ async fn main() -> ExitCode {
code
}

async fn run_query(cmd: &QueryCommand) -> ExitCode {
async fn run_ping(cmd: &PingCommand) -> ExitCode {
let client = mtop::dns::new_client(&cmd.resolv_conf)
.instrument(tracing::span!(Level::INFO, "dns.new_client"))
.await;
let name = match Name::from_str(&cmd.name) {
Ok(n) => n,
Err(e) => {
tracing::error!(message = "invalid name supplied", name = cmd.name, err = %e);
return ExitCode::FAILURE;

// This command runs until interrupted, so we need to handle SIGINT
// to stop gracefully.
let run = Arc::new(AtomicBool::new(true));
let run_ref = run.clone();
task::spawn(async move {
tokio::select! {
_ = tokio::signal::ctrl_c() => {
run_ref.store(false, Ordering::Release);
}
}
};
});

let mut interval = time::interval(Duration::from_secs_f64(cmd.interval_secs));
let mut iterations = 0;

while run.load(Ordering::Acquire) && (cmd.count == 0 || iterations < cmd.count) {
let _ = interval.tick().await;
// Create our own Instant to measure the time taken to perform the query since
// the one emitted by the interval isn't _immediately_ when the future resolves
// and so skews the measurement of queries.
let start = Instant::now();

match client
.resolve(cmd.name.clone(), cmd.rtype, cmd.rclass)
.instrument(tracing::span!(Level::INFO, "client.resolve"))
.await
{
Ok(r) => {
tracing::info!(
id = %r.id(),
name = %cmd.name,
response_code = ?r.flags().get_response_code(),
num_questions = r.questions().len(),
num_answers = r.answers().len(),
num_authority = r.authority().len(),
num_extra = r.extra().len(),
elapsed = ?start.elapsed(),
);
}
Err(e) => {
tracing::error!(message = "failed to resolve", name = %cmd.name, err = %e);
}
}

iterations += 1;
}

if !run.load(Ordering::Acquire) {
tracing::info!("stopping on SIGINT");
}

ExitCode::SUCCESS
}

async fn run_query(cmd: &QueryCommand) -> ExitCode {
let client = mtop::dns::new_client(&cmd.resolv_conf)
.instrument(tracing::span!(Level::INFO, "dns.new_client"))
.await;

let response = match client
.resolve(name, cmd.rtype, cmd.rclass)
.resolve(cmd.name.clone(), cmd.rtype, cmd.rclass)
.instrument(tracing::span!(Level::INFO, "client.resolve"))
.await
{
Ok(r) => r,
Err(e) => {
tracing::error!(message = "unable to perform DNS query", err = %e);
tracing::error!(message = "unable to perform DNS query", name = %cmd.name, err = %e);
return ExitCode::FAILURE;
}
};
Expand Down Expand Up @@ -168,17 +264,9 @@ async fn run_read(_: &ReadCommand) -> ExitCode {
}

async fn run_write(cmd: &WriteCommand) -> ExitCode {
let name = match Name::from_str(&cmd.name) {
Ok(n) => n,
Err(e) => {
tracing::error!(message = "invalid name supplied", name = cmd.name, err = %e);
return ExitCode::FAILURE;
}
};

let id = MessageId::random();
let msg = Message::new(id, Flags::default().set_query().set_recursion_desired())
.add_question(Question::new(name, cmd.rtype).set_qclass(cmd.rclass));
.add_question(Question::new(cmd.name.clone(), cmd.rtype).set_qclass(cmd.rclass));

write_binary_message(&msg).await
}
Expand Down

0 comments on commit 79319dc

Please sign in to comment.