chore: testing them all
FL33TW00D committed Jan 28, 2024
1 parent f155c7b commit c8b4acd
Showing 3 changed files with 101 additions and 11 deletions.
1 change: 1 addition & 0 deletions crates/ratchet-core/src/lib.rs
@@ -93,6 +93,7 @@ pub mod test_util {
};

/// It's a bit of a hack, but it's useful for testing.
/// If a function name is not provided, it looks for the first function in the program.
#[cfg(feature = "pyo3")]
pub fn run_py_prg(prg: String, args: &[&Tensor]) -> anyhow::Result<Tensor> {
let re = Regex::new(r"def\s+(\w+)\s*\(").unwrap();
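For context, the regex on the last line above is what implements the fallback described in the new doc comment: when no name is given, the first `def` in the Python source becomes the entry point. A minimal sketch of that extraction — `first_fn_name` is a hypothetical helper name, and the rest of `run_py_prg` lies outside this hunk:

use regex::Regex;

// Pull the name of the first `def` out of a Python program, if any.
fn first_fn_name(prg: &str) -> Option<String> {
    let re = Regex::new(r"def\s+(\w+)\s*\(").unwrap();
    re.captures(prg).map(|caps| caps[1].to_string())
}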
80 changes: 80 additions & 0 deletions crates/ratchet-core/src/ops/binary.rs
@@ -114,3 +114,83 @@ impl Operation for Binary {
Ok(BinaryMeta { M, N })
}
}

#[cfg(test)]
mod tests {
use test_strategy::{proptest, Arbitrary};

use crate::{shape, test_util::run_py_prg, Device, DeviceRequest, Tensor};

#[derive(Arbitrary, Debug)]
enum TestBinOp {
Add,
Sub,
Mul,
Div,
}

impl TestBinOp {
fn to_str(&self) -> &'static str {
match self {
TestBinOp::Add => "add",
TestBinOp::Sub => "sub",
TestBinOp::Mul => "mul",
TestBinOp::Div => "div",
}
}
}

#[derive(Arbitrary, Debug)]
struct BinaryProblem {
op: TestBinOp,
#[strategy(1..=4usize)]
B: usize,
#[strategy(1..=512usize)]
M: usize,
//#[strategy(1..=512usize)]
//N: usize,
//TODO: add N support
}
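// Note on `test_strategy`: `#[derive(Arbitrary)]` above lets proptest generate
// random problem instances, and `#[strategy(1..=4usize)]` bounds the range a
// field is drawn from.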

fn ground_truth(a: &Tensor, b: &Tensor, op: &TestBinOp) -> anyhow::Result<Tensor> {
let prg = format!(
r#"
import torch
def {}(a, b):
return torch.{}(torch.from_numpy(a), torch.from_numpy(b)).numpy()
"#,
op.to_str(),
op.to_str()
);
run_py_prg(prg, &[a, b])
}
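// For `TestBinOp::Add`, the generated program is:
//
//     import torch
//     def add(a, b):
//         return torch.add(torch.from_numpy(a), torch.from_numpy(b)).numpy()
//
// and `run_py_prg` picks up `add` as the first function defined
// (see the lib.rs hunk above).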

fn run_binary_trial(device: &Device, prob: BinaryProblem) -> anyhow::Result<()> {
let cpu_device = Device::request_device(DeviceRequest::CPU)?;
let BinaryProblem { op, B, M } = prob;
println!("op: {:?}, B: {}, M: {}", op, B, M);
let a = Tensor::randn::<f32>(shape![B, M], cpu_device.clone());
let b = Tensor::randn::<f32>(shape![1], cpu_device.clone());
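// `b` has shape [1] while `a` has shape [B, M], so every generated case
// exercises the scalar-broadcast path rather than same-shape elementwise.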
let ground = ground_truth(&a, &b, &op)?;

let a_gpu = a.to(&device)?;
let b_gpu = b.to(&device)?;
let c_gpu = match op {
TestBinOp::Add => a_gpu.add(&b_gpu)?,
TestBinOp::Sub => a_gpu.sub(&b_gpu)?,
TestBinOp::Mul => a_gpu.mul(&b_gpu)?,
TestBinOp::Div => a_gpu.div(&b_gpu)?,
};
c_gpu.resolve()?;

let d_gpu = c_gpu.to(&Device::CPU)?;
ground.all_close(&d_gpu, 1e-5, 1e-5)?;
Ok(())
}

#[proptest(cases = 32)]
fn test_binary(prob: BinaryProblem) {
let device = Device::request_device(DeviceRequest::GPU).unwrap();
run_binary_trial(&device, prob).unwrap();
}
}
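Since `run_py_prg` is only compiled with the pyo3 feature (see the lib.rs hunk above), running these tests presumably requires enabling it, plus a Python environment with torch and numpy and a working GPU adapter. Assuming the crate name matches its path, something like:

cargo test -p ratchet-core --features pyo3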
31 changes: 20 additions & 11 deletions crates/ratchet-core/src/tensor.rs
@@ -168,18 +168,27 @@ impl Tensor {
 }
 }
 
-impl Tensor {
-    pub fn add(&self, other: &Tensor) -> anyhow::Result<Tensor> {
-        Binary::check_invariants(&[self, other])?;
+macro_rules! impl_binary_op {
+    ($method_name:ident, $op:expr) => {
+        pub fn $method_name(&self, other: &Tensor) -> anyhow::Result<Tensor> {
+            Binary::check_invariants(&[self, other])?;
 
-        let binary = Binary::new(self.clone(), other.clone(), BinaryOp::Add);
-        let new_view = binary.infer_output(&[self, other])?;
-        Ok(Tensor::lazy(
-            LazyOp::Binary(binary),
-            new_view,
-            self.device.clone(),
-        ))
-    }
+            let binary = Binary::new(self.clone(), other.clone(), $op);
+            let new_view = binary.infer_output(&[self, other])?;
+            Ok(Tensor::lazy(
+                LazyOp::Binary(binary),
+                new_view,
+                self.device.clone(),
+            ))
+        }
+    };
+}
+
+impl Tensor {
+    impl_binary_op!(add, BinaryOp::Add);
+    impl_binary_op!(sub, BinaryOp::Sub);
+    impl_binary_op!(mul, BinaryOp::Mul);
+    impl_binary_op!(div, BinaryOp::Div);
 
     //TODO: switch dim to isize and allow negative indexing
     pub fn softmax(&self, dim: usize) -> anyhow::Result<Tensor> {
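For readers less familiar with `macro_rules!`: each `impl_binary_op!` invocation above expands to an ordinary inherent method on `Tensor`. For example, `impl_binary_op!(sub, BinaryOp::Sub)` expands to roughly:

pub fn sub(&self, other: &Tensor) -> anyhow::Result<Tensor> {
    Binary::check_invariants(&[self, other])?;

    let binary = Binary::new(self.clone(), other.clone(), BinaryOp::Sub);
    let new_view = binary.infer_output(&[self, other])?;
    Ok(Tensor::lazy(
        LazyOp::Binary(binary),
        new_view,
        self.device.clone(),
    ))
}

i.e. exactly the old hand-written `add`, with the method name and `BinaryOp` variant parameterized.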
