
feature: cpu softmax
ivarflakstad committed Nov 22, 2024
1 parent 6d8fb10 commit 1dda703
Showing 3 changed files with 59 additions and 13 deletions.
3 changes: 2 additions & 1 deletion crates/ratchet-core/src/cpu/mod.rs
@@ -3,6 +3,7 @@ pub mod gemm;
 mod norm;
 pub mod reindex;
 pub mod rope;
+mod softmax;
 mod unary;
 mod utils;
 
@@ -21,7 +22,7 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError>
         LazyOp::Binary(b) => b.apply_cpu(dst),
         LazyOp::Cast(c) => cpu_cast(c, dst),
         LazyOp::Matmul(m) => m.apply_cpu(dst),
-        LazyOp::Softmax(_s) => todo!(),
+        LazyOp::Softmax(s) => s.apply_cpu(dst),
         LazyOp::RoPE(r) => cpu_rope(r, dst),
         LazyOp::Unary(u) => u.apply_cpu(dst),
         LazyOp::Reindex(r) => r.apply_cpu(dst),
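For reference, the apply_cpu calls in this match go through the crate's CPUOperation trait. Its shape can be inferred from the new impl in softmax.rs below; this is a sketch of the inferred signature, not the verbatim definition from the crate:

pub trait CPUOperation {
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError>;
}

Each operation receives the pre-allocated destination tensor and returns it once the result has been written.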
42 changes: 42 additions & 0 deletions crates/ratchet-core/src/cpu/softmax.rs
@@ -0,0 +1,42 @@
+use crate::cpu::utils::cpu_store_result;
+use crate::{CPUOperation, DType, OperationError, Softmax, Tensor, TensorDType};
+use half::{bf16, f16};
+use num::Float;
+use num_traits::NumAssignOps;
+
+impl CPUOperation for Softmax {
+    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+        let Softmax { input, dim } = self;
+        match input.dt() {
+            DType::F32 => softmax::<f32>(input, *dim, &dst)?,
+            DType::F16 => softmax::<f16>(input, *dim, &dst)?,
+            DType::BF16 => softmax::<bf16>(input, *dim, &dst)?,
+            _ => todo!(),
+        }
+
+        Ok(dst)
+    }
+}
+
+fn softmax<T>(input: &Tensor, dim: usize, dst: &Tensor) -> Result<(), OperationError>
+where
+    T: TensorDType + Float + NumAssignOps,
+{
+    let src_shape = input.shape();
+    let mut input = input.to_vec::<T>()?;
+    let N = src_shape[dim];
+    input.chunks_mut(N).for_each(|chunk| {
+        let mut sum = T::zero();
+        for j in 0..N {
+            chunk[j] = chunk[j].exp();
+            sum += chunk[j];
+        }
+        for j in 0..N {
+            chunk[j] /= sum;
+        }
+    });
+
+    cpu_store_result(dst, &input);
+
+    Ok(())
+}
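Two properties of the kernel above are worth flagging. First, it splits the flat buffer with chunks_mut(N), so it computes the softmax over dim only when dim is the innermost, contiguous axis (which holds for the softmax(2) calls in the tests below). Second, it exponentiates raw inputs, whereas softmax is conventionally computed as exp(x - max) per row so that exp cannot overflow, which matters especially for f16/bf16. A minimal standalone sketch of the max-subtracting variant over a plain f32 buffer; this helper is illustrative only and not part of the commit:

fn softmax_rows(data: &mut [f32], n: usize) {
    for row in data.chunks_mut(n) {
        // Subtracting the row max keeps every exponent <= 0, so exp() never overflows.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for x in row.iter_mut() {
            *x = (*x - max).exp();
            sum += *x;
        }
        for x in row.iter_mut() {
            *x /= sum;
        }
    }
}

Shifting by the max leaves the result mathematically unchanged, since the factor exp(-max) cancels between numerator and denominator.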
27 changes: 15 additions & 12 deletions crates/ratchet-core/src/ops/softmax.rs
@@ -13,8 +13,8 @@ use crate::{
 
 #[derive(new, Debug, Clone)]
 pub struct Softmax {
-    input: Tensor,
-    dim: usize,
+    pub(crate) input: Tensor,
+    pub(crate) dim: usize,
 }
 
 #[derive(Debug, derive_new::new, ShaderType, WgslMetadata)]
@@ -322,8 +322,7 @@ def softmax(a):
         run_py_prg(prg.to_string(), &[a], &[], a.dt())
     }
 
-    fn run_softmax_trial(problem: SoftmaxProblem) {
-        let device = Device::request_device(DeviceRequest::GPU).unwrap();
+    fn run_softmax_trial(problem: SoftmaxProblem, device: Device) {
         let SoftmaxProblem { B, M, N } = problem;
         let a = Tensor::randn::<f32>(shape![B, M, N], Device::CPU);
         let ground = ground_truth(&a).unwrap();
@@ -332,8 +331,6 @@ def softmax(a):
         let b = a_gpu.softmax(2).unwrap().resolve().unwrap();
 
         let ours = b.to(&Device::CPU).unwrap();
-        println!("ours = {:?}", ours);
-        println!("ground = {:?}", ground);
         ground.all_close(&ours, 1e-6, 1e-6).unwrap();
     }
 

@@ -347,16 +344,22 @@
         N: usize,
     }
 
-    #[proptest(cases = 8)]
-    fn test_softmax(prob: SoftmaxProblem) {
-        let SoftmaxProblem { B, M, N } = prob;
-        println!("B = {}, M = {}, N = {}", B, M, N);
-        run_softmax_trial(prob);
+    #[proptest(cases = 18)]
+    fn test_softmax_gpu(prob: SoftmaxProblem) {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
+        run_softmax_trial(prob, device);
     }
 
+    #[proptest(cases = 16)]
+    fn test_softmax_cpu(prob: SoftmaxProblem) {
+        let device = Device::request_device(DeviceRequest::CPU).unwrap();
+        run_softmax_trial(prob, device);
+    }
+
     #[test]
     fn dbg_softmax() {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
         let problem = SoftmaxProblem { B: 1, M: 2, N: 128 };
-        run_softmax_trial(problem);
+        run_softmax_trial(problem, device);
     }
 }
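With this commit, the same trial runs on either device. A minimal sketch of exercising the CPU path outside the test harness, assuming the crate re-exports the items used in the tests above (shape!, Device, DeviceRequest, Tensor) under the name ratchet; exact paths may differ:

use ratchet::{shape, Device, DeviceRequest, Tensor};

fn main() {
    // Request the CPU device, mirroring the new test_softmax_cpu setup.
    let device = Device::request_device(DeviceRequest::CPU).unwrap();

    // A random [1, 2, 128] input, softmaxed over its last dimension.
    let a = Tensor::randn::<f32>(shape![1, 2, 128], device);
    let out = a.softmax(2).unwrap().resolve().unwrap();

    // Each length-128 row of the result sums to ~1.
    println!("{:?}", out.to(&Device::CPU).unwrap());
}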
