chore: begin unary. split unary gpu and cpu tests
ivarflakstad committed Aug 12, 2024
1 parent 2f5b488 commit 8963e88
Showing 2 changed files with 54 additions and 26 deletions.
78 changes: 53 additions & 25 deletions crates/ratchet-core/src/ops/unary.rs
@@ -10,7 +10,7 @@ use strum_macros::EnumIter;
 
 use crate::{
     gpu::{dtype::WgslDType, BindGroupLayoutDescriptor},
-    rvec, Array, BindingMode, BuiltIn, DType, GPUOperation, Kernel, KernelElement,
+    rvec, Array, BindingMode, BuiltIn, CPUOperation, DType, GPUOperation, Kernel, KernelElement,
     KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Scalar, StorageView,
     Tensor, Vec2, Vec4, WgslKernelBuilder, WgslPrimitive, WorkgroupSize, Workload,
 };
@@ -358,6 +358,30 @@ impl Kernel for UnaryKernels {
     }
 }
 
+impl CPUOperation for Unary {
+    fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+        todo!()
+        /*
+        match self.op {
+            UnaryOp::Gelu => self.gelu(dst),
+            UnaryOp::Tanh => self.tanh(dst),
+            UnaryOp::Exp => self.exp(dst),
+            UnaryOp::Log => self.log(dst),
+            UnaryOp::Sin => self.sin(dst),
+            UnaryOp::Cos => self.cos(dst),
+            UnaryOp::Abs => self.abs(dst),
+            UnaryOp::Sqrt => self.sqrt(dst),
+            UnaryOp::Relu => self.relu(dst),
+            UnaryOp::Floor => self.floor(dst),
+            UnaryOp::Ceil => self.ceil(dst),
+            UnaryOp::Neg => self.neg(dst),
+            UnaryOp::Silu => self.silu(dst),
+            UnaryOp::Sigmoid => self.sigmoid(dst),
+        }
+        */
+    }
+}
+
 #[cfg(all(test, feature = "pyo3"))]
 mod tests {
     use test_strategy::{proptest, Arbitrary};
@@ -404,10 +428,8 @@ def {}(a):
         run_py_prg(prg.to_string(), &[a], &[], a.dt())
     }
 
-    fn run_unary_trial(prob: UnaryProblem) -> anyhow::Result<()> {
-        let device = Device::request_device(DeviceRequest::CPU).unwrap();
+    fn run_unary_trial(prob: UnaryProblem, device: Device) -> anyhow::Result<()> {
         let UnaryProblem { op, B, M, N } = prob;
-        println!("op: {:?}, B: {}, M: {}, N: {}", op, B, M, N);
         let a = Tensor::randn::<f32>(shape![B, M], Device::CPU);
 
         let args = match op {
@@ -416,23 +438,22 @@ def {}(a):
         };
         let ground = ground_truth(&a, &op, args)?;
 
-        let a_gpu = a.to(&device)?;
-        let a_gpu = a;
-        let c_gpu = match op {
-            UnaryOp::Gelu => a_gpu.gelu()?,
-            UnaryOp::Tanh => a_gpu.tanh()?,
-            UnaryOp::Exp => a_gpu.exp()?,
-            UnaryOp::Log => a_gpu.log()?,
-            UnaryOp::Sin => a_gpu.sin()?,
-            UnaryOp::Cos => a_gpu.cos()?,
-            UnaryOp::Abs => a_gpu.abs()?,
-            UnaryOp::Sqrt => a_gpu.sqrt()?,
-            UnaryOp::Relu => a_gpu.relu()?,
-            UnaryOp::Floor => a_gpu.floor()?,
-            UnaryOp::Ceil => a_gpu.ceil()?,
-            UnaryOp::Neg => a_gpu.neg()?,
-            UnaryOp::Silu => a_gpu.silu()?,
-            UnaryOp::Sigmoid => a_gpu.sigmoid()?,
+        let a = a.to(&device)?;
+        let c = match op {
+            UnaryOp::Gelu => a.gelu()?,
+            UnaryOp::Tanh => a.tanh()?,
+            UnaryOp::Exp => a.exp()?,
+            UnaryOp::Log => a.log()?,
+            UnaryOp::Sin => a.sin()?,
+            UnaryOp::Cos => a.cos()?,
+            UnaryOp::Abs => a.abs()?,
+            UnaryOp::Sqrt => a.sqrt()?,
+            UnaryOp::Relu => a.relu()?,
+            UnaryOp::Floor => a.floor()?,
+            UnaryOp::Ceil => a.ceil()?,
+            UnaryOp::Neg => a.neg()?,
+            UnaryOp::Silu => a.silu()?,
+            UnaryOp::Sigmoid => a.sigmoid()?,
         }
         .resolve()?;
 
@@ -441,13 +462,20 @@ def {}(a):
             _ => (1e-4, 1e-4),
         };
 
-        let d_gpu = c_gpu.to(&Device::CPU)?;
-        ground.all_close(&d_gpu, atol, rtol)?;
+        let d = c.to(&Device::CPU)?;
+        ground.all_close(&d, atol, rtol)?;
         Ok(())
     }
 
     #[proptest(cases = 256)]
-    fn test_unary(prob: UnaryProblem) {
-        run_unary_trial(prob).unwrap();
+    fn test_unary_gpu(prob: UnaryProblem) {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
+        run_unary_trial(prob, device).unwrap();
+    }
+
+    #[proptest(cases = 256)]
+    fn test_unary_cpu(prob: UnaryProblem) {
+        let device = Device::request_device(DeviceRequest::CPU).unwrap();
+        run_unary_trial(prob, device).unwrap();
+    }
 }
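
The commented-out match in the new CPUOperation impl sketches the eventual per-op CPU dispatch, but the methods it calls (self.gelu(dst) and friends) do not exist yet, so apply panics via todo!(). As a rough illustration of the direction, here is a minimal, self-contained sketch of an element-wise CPU unary pass over f32 data; the helper name apply_unary and its slice-based signature are hypothetical, not part of ratchet's API:

// Hypothetical helper, not ratchet API: applies a unary op element-wise.
fn apply_unary(src: &[f32], op: impl Fn(f32) -> f32) -> Vec<f32> {
    src.iter().copied().map(op).collect()
}

fn main() {
    let src = [0.0_f32, 1.0, -2.0];
    // Dispatch would mirror the commented-out `match self.op` above.
    let tanh = apply_unary(&src, f32::tanh);
    let relu = apply_unary(&src, |x| x.max(0.0));
    let silu = apply_unary(&src, |x| x * (1.0 / (1.0 + (-x).exp())));
    println!("tanh: {tanh:?}\nrelu: {relu:?}\nsilu: {silu:?}");
}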
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/tensor.rs
@@ -748,7 +748,7 @@ impl Tensor {
             LazyOp::Matmul(m) => todo!(),
             LazyOp::Softmax(s) => todo!(),
             LazyOp::RoPE(r) => todo!(),
-            LazyOp::Unary(u) => todo!(),
+            LazyOp::Unary(u) => u.apply(dst).ok(),
             LazyOp::Reindex(r) => todo!(),
             LazyOp::Concat(c) => todo!(),
             LazyOp::Norm(n) => todo!(),
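
One thing to note about u.apply(dst).ok(): Result::ok converts the Result<Tensor, OperationError> into an Option<Tensor>, so any error from apply is silently discarded at this call site; and since apply is still todo!(), taking this path will panic until the CPU ops land. A quick illustration of the .ok() conversion with std types:

fn main() {
    let ok: Result<i32, String> = Ok(1);
    let err: Result<i32, String> = Err("boom".into());
    assert_eq!(ok.ok(), Some(1));
    assert_eq!(err.ok(), None); // the error value is dropped here
}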
