diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs
index d07af976..ae9ecb7b 100644
--- a/crates/ratchet-core/src/cpu/mod.rs
+++ b/crates/ratchet-core/src/cpu/mod.rs
@@ -3,6 +3,7 @@ pub mod gemm;
 mod norm;
 pub mod reindex;
 pub mod rope;
+mod softmax;
 mod unary;
 mod utils;
 
@@ -21,7 +22,7 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError> {
         LazyOp::Binary(b) => b.apply_cpu(dst),
         LazyOp::Cast(c) => cpu_cast(c, dst),
         LazyOp::Matmul(m) => m.apply_cpu(dst),
-        LazyOp::Softmax(_s) => todo!(),
+        LazyOp::Softmax(s) => s.apply_cpu(dst),
         LazyOp::RoPE(r) => cpu_rope(r, dst),
         LazyOp::Unary(u) => u.apply_cpu(dst),
         LazyOp::Reindex(r) => r.apply_cpu(dst),
diff --git a/crates/ratchet-core/src/cpu/softmax.rs b/crates/ratchet-core/src/cpu/softmax.rs
new file mode 100644
index 00000000..1d6a3df0
--- /dev/null
+++ b/crates/ratchet-core/src/cpu/softmax.rs
@@ -0,0 +1,42 @@
+use crate::cpu::utils::cpu_store_result;
+use crate::{CPUOperation, DType, OperationError, Softmax, Tensor, TensorDType};
+use half::{bf16, f16};
+use num::Float;
+use num_traits::NumAssignOps;
+
+impl CPUOperation for Softmax {
+    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+        let Softmax { input, dim } = self;
+        match input.dt() {
+            DType::F32 => softmax::<f32>(input, *dim, &dst)?,
+            DType::F16 => softmax::<f16>(input, *dim, &dst)?,
+            DType::BF16 => softmax::<bf16>(input, *dim, &dst)?,
+            _ => todo!(),
+        }
+
+        Ok(dst)
+    }
+}
+
+fn softmax<T>(input: &Tensor, dim: usize, dst: &Tensor) -> Result<(), OperationError>
+where
+    T: TensorDType + Float + NumAssignOps,
+{
+    let src_shape = input.shape();
+    let mut input = input.to_vec::<T>()?;
+    let N = src_shape[dim];
+    input.chunks_mut(N).for_each(|chunk| {
+        let mut sum = T::zero();
+        for j in 0..N {
+            chunk[j] = chunk[j].exp();
+            sum += chunk[j];
+        }
+        for j in 0..N {
+            chunk[j] /= sum;
+        }
+    });
+
+    cpu_store_result(dst, &input);
+
+    Ok(())
+}
diff --git a/crates/ratchet-core/src/ops/softmax.rs b/crates/ratchet-core/src/ops/softmax.rs
index 28e19987..68d4824b 100644
--- a/crates/ratchet-core/src/ops/softmax.rs
+++ b/crates/ratchet-core/src/ops/softmax.rs
@@ -13,8 +13,8 @@ use crate::{
 
 #[derive(new, Debug, Clone)]
 pub struct Softmax {
-    input: Tensor,
-    dim: usize,
+    pub(crate) input: Tensor,
+    pub(crate) dim: usize,
 }
 
 #[derive(Debug, derive_new::new, ShaderType, WgslMetadata)]
@@ -322,8 +322,7 @@ def softmax(a):
         run_py_prg(prg.to_string(), &[a], &[], a.dt())
     }
 
-    fn run_softmax_trial(problem: SoftmaxProblem) {
-        let device = Device::request_device(DeviceRequest::GPU).unwrap();
+    fn run_softmax_trial(problem: SoftmaxProblem, device: Device) {
         let SoftmaxProblem { B, M, N } = problem;
         let a = Tensor::randn::<f32>(shape![B, M, N], Device::CPU);
         let ground = ground_truth(&a).unwrap();
@@ -332,8 +331,6 @@ def softmax(a):
         let b = a_gpu.softmax(2).unwrap().resolve().unwrap();
         let ours = b.to(&Device::CPU).unwrap();
 
-        println!("ours = {:?}", ours);
-        println!("ground = {:?}", ground);
         ground.all_close(&ours, 1e-6, 1e-6).unwrap();
     }
 
@@ -347,16 +344,22 @@ def softmax(a):
         N: usize,
     }
 
-    #[proptest(cases = 8)]
-    fn test_softmax(prob: SoftmaxProblem) {
-        let SoftmaxProblem { B, M, N } = prob;
-        println!("B = {}, M = {}, N = {}", B, M, N);
-        run_softmax_trial(prob);
+    #[proptest(cases = 18)]
+    fn test_softmax_gpu(prob: SoftmaxProblem) {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
+        run_softmax_trial(prob, device);
+    }
+
+    #[proptest(cases = 16)]
+    fn test_softmax_cpu(prob: SoftmaxProblem) {
+        let device = Device::request_device(DeviceRequest::CPU).unwrap();
+        run_softmax_trial(prob, device);
     }
 
     #[test]
     fn dbg_softmax() {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
         let problem = SoftmaxProblem { B: 1, M: 2, N: 128 };
-        run_softmax_trial(problem);
+        run_softmax_trial(problem, device);
     }
 }
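
Note on numerical stability: the CPU kernel in crates/ratchet-core/src/cpu/softmax.rs normalizes by the sum of raw exponentials without first subtracting the row max, so large logits can overflow exp() (to inf in f32, and far sooner in f16/bf16). The standard remedy uses the shift invariance softmax(x) = softmax(x - c). The sketch below is illustrative only and is not part of the patch; softmax_row is a hypothetical helper that assumes the same per-row chunks_mut(N) layout and the same Float + NumAssignOps bounds as the kernel above:

use num::Float;
use num_traits::NumAssignOps;

// Numerically stable softmax over a single row, in place.
// Shifting by the row max keeps every exponent <= 0, so exp() <= 1.
fn softmax_row<T>(chunk: &mut [T])
where
    T: Float + NumAssignOps,
{
    // Row max; rows here always have N >= 1 elements, so the fold is safe.
    let max = chunk.iter().fold(T::neg_infinity(), |m, &x| m.max(x));
    let mut sum = T::zero();
    for x in chunk.iter_mut() {
        *x = (*x - max).exp();
        sum += *x;
    }
    for x in chunk.iter_mut() {
        *x /= sum;
    }
}

Dropping this in would replace the body of the chunks_mut(N) closure. The Tensor::randn test inputs stay well within exp()'s range either way, which is why the unshifted version still passes the 1e-6 all_close checks.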