From 8963e881451bec094c60d9c7e89892aaaca71692 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 12 Aug 2024 17:02:44 +0200 Subject: [PATCH] chore: begin unary. split unary gpu and cpu tests --- crates/ratchet-core/src/ops/unary.rs | 78 +++++++++++++++++++--------- crates/ratchet-core/src/tensor.rs | 2 +- 2 files changed, 54 insertions(+), 26 deletions(-) diff --git a/crates/ratchet-core/src/ops/unary.rs b/crates/ratchet-core/src/ops/unary.rs index 9233da80..f20e1c36 100644 --- a/crates/ratchet-core/src/ops/unary.rs +++ b/crates/ratchet-core/src/ops/unary.rs @@ -10,7 +10,7 @@ use strum_macros::EnumIter; use crate::{ gpu::{dtype::WgslDType, BindGroupLayoutDescriptor}, - rvec, Array, BindingMode, BuiltIn, DType, GPUOperation, Kernel, KernelElement, + rvec, Array, BindingMode, BuiltIn, CPUOperation, DType, GPUOperation, Kernel, KernelElement, KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Scalar, StorageView, Tensor, Vec2, Vec4, WgslKernelBuilder, WgslPrimitive, WorkgroupSize, Workload, }; @@ -358,6 +358,30 @@ impl Kernel for UnaryKernels { } } +impl CPUOperation for Unary { + fn apply(&self, dst: Tensor) -> Result { + todo!() + /* + match self.op { + UnaryOp::Gelu => self.gelu(dst), + UnaryOp::Tanh => self.tanh(dst), + UnaryOp::Exp => self.exp(dst), + UnaryOp::Log => self.log(dst), + UnaryOp::Sin => self.sin(dst), + UnaryOp::Cos => self.cos(dst), + UnaryOp::Abs => self.abs(dst), + UnaryOp::Sqrt => self.sqrt(dst), + UnaryOp::Relu => self.relu(dst), + UnaryOp::Floor => self.floor(dst), + UnaryOp::Ceil => self.ceil(dst), + UnaryOp::Neg => self.neg(dst), + UnaryOp::Silu => self.silu(dst), + UnaryOp::Sigmoid => self.sigmoid(dst), + } + */ + } +} + #[cfg(all(test, feature = "pyo3"))] mod tests { use test_strategy::{proptest, Arbitrary}; @@ -404,10 +428,8 @@ def {}(a): run_py_prg(prg.to_string(), &[a], &[], a.dt()) } - fn run_unary_trial(prob: UnaryProblem) -> anyhow::Result<()> { - let device = Device::request_device(DeviceRequest::CPU).unwrap(); + fn run_unary_trial(prob: UnaryProblem, device: Device) -> anyhow::Result<()> { let UnaryProblem { op, B, M, N } = prob; - println!("op: {:?}, B: {}, M: {}, N: {}", op, B, M, N); let a = Tensor::randn::(shape![B, M], Device::CPU); let args = match op { @@ -416,23 +438,22 @@ def {}(a): }; let ground = ground_truth(&a, &op, args)?; - let a_gpu = a.to(&device)?; - let a_gpu = a; - let c_gpu = match op { - UnaryOp::Gelu => a_gpu.gelu()?, - UnaryOp::Tanh => a_gpu.tanh()?, - UnaryOp::Exp => a_gpu.exp()?, - UnaryOp::Log => a_gpu.log()?, - UnaryOp::Sin => a_gpu.sin()?, - UnaryOp::Cos => a_gpu.cos()?, - UnaryOp::Abs => a_gpu.abs()?, - UnaryOp::Sqrt => a_gpu.sqrt()?, - UnaryOp::Relu => a_gpu.relu()?, - UnaryOp::Floor => a_gpu.floor()?, - UnaryOp::Ceil => a_gpu.ceil()?, - UnaryOp::Neg => a_gpu.neg()?, - UnaryOp::Silu => a_gpu.silu()?, - UnaryOp::Sigmoid => a_gpu.sigmoid()?, + let a = a.to(&device)?; + let c = match op { + UnaryOp::Gelu => a.gelu()?, + UnaryOp::Tanh => a.tanh()?, + UnaryOp::Exp => a.exp()?, + UnaryOp::Log => a.log()?, + UnaryOp::Sin => a.sin()?, + UnaryOp::Cos => a.cos()?, + UnaryOp::Abs => a.abs()?, + UnaryOp::Sqrt => a.sqrt()?, + UnaryOp::Relu => a.relu()?, + UnaryOp::Floor => a.floor()?, + UnaryOp::Ceil => a.ceil()?, + UnaryOp::Neg => a.neg()?, + UnaryOp::Silu => a.silu()?, + UnaryOp::Sigmoid => a.sigmoid()?, } .resolve()?; @@ -441,13 +462,20 @@ def {}(a): _ => (1e-4, 1e-4), }; - let d_gpu = c_gpu.to(&Device::CPU)?; - ground.all_close(&d_gpu, atol, rtol)?; + let d = c.to(&Device::CPU)?; + ground.all_close(&d, atol, rtol)?; Ok(()) } #[proptest(cases = 256)] - fn test_unary(prob: UnaryProblem) { - run_unary_trial(prob).unwrap(); + fn test_unary_gpu(prob: UnaryProblem) { + let device = Device::request_device(DeviceRequest::GPU).unwrap(); + run_unary_trial(prob, device).unwrap(); + } + + #[proptest(cases = 256)] + fn test_unary_cpu(prob: UnaryProblem) { + let device = Device::request_device(DeviceRequest::CPU).unwrap(); + run_unary_trial(prob, device).unwrap(); } } diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index 36a425b3..371750a5 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -748,7 +748,7 @@ impl Tensor { LazyOp::Matmul(m) => todo!(), LazyOp::Softmax(s) => todo!(), LazyOp::RoPE(r) => todo!(), - LazyOp::Unary(u) => todo!(), + LazyOp::Unary(u) => u.apply(dst).ok(), LazyOp::Reindex(r) => todo!(), LazyOp::Concat(c) => todo!(), LazyOp::Norm(n) => todo!(),