diff --git a/crates/rsjudge-judger/src/comparer/compare.rs b/crates/rsjudge-judger/src/comparer/compare.rs deleted file mode 100644 index d86ce01..0000000 --- a/crates/rsjudge-judger/src/comparer/compare.rs +++ /dev/null @@ -1,73 +0,0 @@ -use std::{ - future::Future, - io, - marker::PhantomPinned, - pin::Pin, - task::{ready, Context, Poll}, -}; - -use pin_project::pin_project; -use tokio::io::AsyncRead; - -use super::AsyncComparer; -use crate::CompareResult; - -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[pin_project] -pub struct Compare<'a, C: ?Sized, Out, Ans> { - comparer: &'a mut C, - out: Out, - ans: Ans, - - // Make this future `!Unpin` for compatibility with async trait methods. - #[pin] - _pin: PhantomPinned, -} - -pub(super) fn compare<'a, C, Out, Ans>( - comparer: &'a mut C, - out: Out, - ans: Ans, -) -> Compare<'a, C, Out, Ans> -where - C: AsyncComparer + Unpin + ?Sized, - Out: AsyncRead + Unpin, - Ans: AsyncRead + Unpin, -{ - Compare { - comparer, - out, - ans, - _pin: PhantomPinned, - } -} - -fn compare_internal( - mut comparer: Pin<&mut C>, - cx: &mut Context, - out: Out, - ans: Ans, -) -> Poll> -where - C: AsyncComparer + Unpin + ?Sized, - Out: AsyncRead + Unpin, - Ans: AsyncRead + Unpin, -{ - Poll::Ready(ready!(comparer.as_mut().poll_compare(cx, out, ans))) -} - -impl<'c, C, Out, Ans> Future for Compare<'c, C, Out, Ans> -where - C: AsyncComparer + Unpin + ?Sized, - Out: AsyncRead + Unpin, - Ans: AsyncRead + Unpin, -{ - type Output = io::Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let me = self.project(); - - compare_internal(Pin::new(*me.comparer), cx, me.out, me.ans) - } -} diff --git a/crates/rsjudge-judger/src/comparer/default_comparer.rs b/crates/rsjudge-judger/src/comparer/default_comparer.rs new file mode 100644 index 0000000..6941ac5 --- /dev/null +++ b/crates/rsjudge-judger/src/comparer/default_comparer.rs @@ -0,0 +1,180 @@ +//! A default comparer implementation, supporting ignoring trailing whitespace and/or trailing newline. + +use async_trait::async_trait; +use tokio::{ + io::{self, AsyncBufReadExt as _, AsyncRead, BufReader}, + join, +}; +use tokio_stream::{wrappers::SplitStream, StreamExt as _}; + +use crate::{CompareResult, Comparer}; + +pub struct DefaultComparer { + ignore_trailing_whitespace: bool, + ignore_trailing_newline: bool, +} + +impl DefaultComparer { + pub fn new(ignore_trailing_whitespace: bool, ignore_trailing_newline: bool) -> Self { + Self { + ignore_trailing_whitespace, + ignore_trailing_newline, + } + } + fn compare_line(&self, out_line: impl AsRef<[u8]>, ans_line: impl AsRef<[u8]>) -> bool { + fn trim_end(line: &[u8]) -> &[u8] { + if line.is_empty() { + line + } else { + let end = line + .iter() + .rposition(|c| !c.is_ascii_whitespace()) + .unwrap_or_else(|| line.len() - 1); + &line[..=end] + } + } + let out_line = out_line.as_ref(); + let ans_line = ans_line.as_ref(); + let (out_line, ans_line) = if self.ignore_trailing_whitespace { + (trim_end(out_line), trim_end(ans_line)) + } else { + (out_line, ans_line) + }; + out_line == ans_line + } +} + +impl Default for DefaultComparer { + fn default() -> Self { + Self::new(true, true) + } +} + +#[async_trait] +impl Comparer for DefaultComparer { + async fn compare(&self, out: Out, ans: Ans) -> io::Result + where + Out: AsyncRead + Send + Unpin, + Ans: AsyncRead + Send + Unpin, + { + let out = BufReader::new(out); + let ans = BufReader::new(ans); + let mut out_lines = SplitStream::new(out.split(b'\n')).fuse(); + let mut ans_lines = SplitStream::new(ans.split(b'\n')).fuse(); + loop { + match join!(out_lines.next(), ans_lines.next()) { + (Some(out_line), Some(ans_line)) => { + if !self.compare_line(&out_line?, &ans_line?) { + return Ok(CompareResult::WrongAnswer); + } + } + (Some(out_line), None) => { + if !self.ignore_trailing_newline || !self.compare_line(&out_line?, []) { + return Ok(CompareResult::WrongAnswer); + } + } + (None, Some(ans_line)) => { + if !self.ignore_trailing_newline || !self.compare_line([], &ans_line?) { + return Ok(CompareResult::WrongAnswer); + } + } + (None, None) => return Ok(CompareResult::Accepted), + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io; + + use temp_dir::TempDir; + use tokio::{ + fs::File, + io::{empty, AsyncWriteExt as _}, + }; + + use super::{CompareResult, DefaultComparer}; + use crate::Comparer as _; + + #[tokio::test] + async fn compare_empty() -> io::Result<()> { + let comparer = DefaultComparer::default(); + let out = empty(); + let ans = empty(); + let result = comparer.compare(out, ans).await?; + assert_eq!(result, CompareResult::Accepted); + Ok(()) + } + + #[tokio::test] + async fn compare_files() -> io::Result<()> { + let temp_dir = TempDir::new()?; + + let out_path = temp_dir.path().join("out"); + let ans_path = temp_dir.path().join("ans"); + + { + File::create(&out_path) + .await? + .write_all(b"Hello, World!\n") + .await?; + File::create(&ans_path) + .await? + .write_all(b"Hello, World!\n") + .await?; + } + + { + let comparer = DefaultComparer::default(); + + let result = comparer + .compare(File::open(&out_path).await?, File::open(&ans_path).await?) + .await?; + assert_eq!(result, CompareResult::Accepted); + } + + Ok(()) + } + + #[tokio::test] + async fn compare_with_trailing_whitespace() -> io::Result<()> { + let out = b"Hello, World! \n"; + let ans = b"Hello, World!\n"; + let comparer = DefaultComparer::new(true, true); + let result = comparer.compare(&out[..], &ans[..]).await?; + assert_eq!(result, CompareResult::Accepted); + Ok(()) + } + + #[tokio::test] + async fn compare_with_invalid_utf8() -> io::Result<()> { + let out = b"Hello, World! \xFF\n"; + let ans = b"Hello, World!\n"; + let comparer = DefaultComparer::new(true, true); + let result = comparer.compare(&out[..], &ans[..]).await?; + assert_eq!(dbg!(result), CompareResult::WrongAnswer); + + Ok(()) + } + + #[tokio::test] + async fn compare_with_trailing_newline() -> io::Result<()> { + let out = b"Hello, World!\n"; + let ans = b"Hello, World!"; + let comparer = DefaultComparer::new(true, true); + let result = comparer.compare(&out[..], &ans[..]).await?; + assert_eq!(result, CompareResult::Accepted); + Ok(()) + } + + #[tokio::test] + async fn compare_with_trailing_content_after_newline() -> io::Result<()> { + let out = b"Hello, World!\naaa\n"; + let ans = b"Hello, World!"; + let comparer = DefaultComparer::new(true, true); + let result = comparer.compare(&out[..], &ans[..]).await?; + assert_eq!(result, CompareResult::WrongAnswer); + Ok(()) + } +} diff --git a/crates/rsjudge-judger/src/comparer/mod.rs b/crates/rsjudge-judger/src/comparer/mod.rs index 0fd0b33..2c59631 100644 --- a/crates/rsjudge-judger/src/comparer/mod.rs +++ b/crates/rsjudge-judger/src/comparer/mod.rs @@ -1,14 +1,8 @@ -pub mod compare; -use std::{ - pin::Pin, - task::{Context, Poll}, -}; +pub mod default_comparer; use async_trait::async_trait; use tokio::io::{self, AsyncRead}; -use self::compare::{compare, Compare}; - #[derive(Debug, PartialEq)] pub enum CompareResult { Accepted, @@ -24,24 +18,3 @@ pub trait Comparer { Out: AsyncRead + Send + Unpin, Ans: AsyncRead + Send + Unpin; } - -pub trait AsyncComparer { - fn poll_compare( - self: Pin<&mut Self>, - cx: &mut Context, - out: Out, - ans: Ans, - ) -> Poll> - where - Out: AsyncRead + Unpin, - Ans: AsyncRead + Unpin; - - fn compare<'a, Out, Ans>(&'a mut self, out: Out, ans: Ans) -> Compare<'a, Self, Out, Ans> - where - Self: Unpin, - Out: AsyncRead + Send + Unpin, - Ans: AsyncRead + Send + Unpin, - { - compare(self, out, ans) - } -} diff --git a/crates/rsjudge-judger/src/default_comparer.rs b/crates/rsjudge-judger/src/default_comparer.rs deleted file mode 100644 index 0a862b8..0000000 --- a/crates/rsjudge-judger/src/default_comparer.rs +++ /dev/null @@ -1,127 +0,0 @@ -use async_trait::async_trait; -use tokio::{ - io::{self, AsyncBufReadExt as _, AsyncRead, BufReader}, - join, -}; -use tokio_stream::{wrappers::LinesStream, StreamExt as _}; - -use crate::{CompareResult, Comparer}; - -pub struct DefaultComparer { - ignore_trailing_whitespace: bool, - ignore_trailing_newline: bool, -} - -impl DefaultComparer { - pub fn new(ignore_trailing_whitespace: bool, ignore_trailing_newline: bool) -> Self { - Self { - ignore_trailing_whitespace, - ignore_trailing_newline, - } - } - - fn compare_line(&self, out_line: &str, ans_line: &str) -> bool { - let (out_line, ans_line) = if self.ignore_trailing_whitespace { - (out_line.trim_end(), ans_line.trim_end()) - } else { - (out_line, ans_line) - }; - out_line == ans_line - } -} - -impl Default for DefaultComparer { - fn default() -> Self { - Self::new(true, true) - } -} - -#[async_trait] -impl Comparer for DefaultComparer { - async fn compare(&self, out: Out, ans: Ans) -> io::Result - where - Out: AsyncRead + Send + Unpin, - Ans: AsyncRead + Send + Unpin, - { - let out = BufReader::new(out); - let ans = BufReader::new(ans); - - let mut out_lines = LinesStream::new(out.lines()); - let mut ans_lines = LinesStream::new(ans.lines()); - - while let (Some(out_line), Some(ans_line)) = join!(out_lines.next(), ans_lines.next()) { - if !self.compare_line(&out_line?, &ans_line?) { - return Ok(CompareResult::WrongAnswer); - } - } - - if self.ignore_trailing_newline { - while let Some(out_line) = out_lines.next().await { - if !out_line?.trim().is_empty() { - return Ok(CompareResult::WrongAnswer); - } - } - - while let Some(ans_line) = ans_lines.next().await { - if !ans_line?.trim().is_empty() { - return Ok(CompareResult::WrongAnswer); - } - } - } - - Ok(CompareResult::Accepted) - } -} - -#[cfg(test)] -mod tests { - use std::io; - - use temp_dir::TempDir; - use tokio::{ - fs::File, - io::{empty, AsyncWriteExt as _}, - }; - - use super::{CompareResult, Comparer as _, DefaultComparer}; - - #[tokio::test] - async fn compare_empty() -> io::Result<()> { - let comparer = DefaultComparer::default(); - let out = empty(); - let ans = empty(); - let result = comparer.compare(out, ans).await?; - assert_eq!(result, CompareResult::Accepted); - Ok(()) - } - - #[tokio::test] - async fn compare_files() -> io::Result<()> { - let temp_dir = TempDir::new()?; - - let out_path = temp_dir.path().join("out"); - let ans_path = temp_dir.path().join("ans"); - - { - File::create(&out_path) - .await? - .write_all(b"Hello, World!\n") - .await?; - File::create(&ans_path) - .await? - .write_all(b"Hello, World!\n") - .await?; - } - - { - let comparer = DefaultComparer::default(); - - let result = comparer - .compare(File::open(&out_path).await?, File::open(&ans_path).await?) - .await?; - assert_eq!(result, CompareResult::Accepted); - } - - Ok(()) - } -} diff --git a/crates/rsjudge-judger/src/lib.rs b/crates/rsjudge-judger/src/lib.rs index 711483b..19e251e 100644 --- a/crates/rsjudge-judger/src/lib.rs +++ b/crates/rsjudge-judger/src/lib.rs @@ -1,4 +1,3 @@ pub mod comparer; -pub mod default_comparer; -pub use comparer::{CompareResult, Comparer}; +pub use comparer::{default_comparer::DefaultComparer, CompareResult, Comparer};