From dc0c4223db25f249fd5e7d84bf69591c6593f275 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Sat, 27 Jan 2024 21:12:24 +0000 Subject: [PATCH 1/3] chore: sgemm tests work --- crates/ratchet-core/src/ops/matmul.rs | 88 ++++++++++++++++++++++++--- crates/ratchet-core/src/shape.rs | 32 +++++++++- 2 files changed, 109 insertions(+), 11 deletions(-) diff --git a/crates/ratchet-core/src/ops/matmul.rs b/crates/ratchet-core/src/ops/matmul.rs index 5be08a9c..ab2981cd 100644 --- a/crates/ratchet-core/src/ops/matmul.rs +++ b/crates/ratchet-core/src/ops/matmul.rs @@ -6,7 +6,7 @@ use encase::ShaderType; use crate::{ gpu::{BindGroupLayoutDescriptor, WorkgroupCount}, rvec, wgc, DType, Enforcer, InvariantError, KernelElement, OpMetadata, Operation, - OperationError, RVec, Shape, StorageView, Tensor, + OperationError, RVec, Shape, StorageView, Strides, Tensor, }; // Defines a matrix multiplication operation. @@ -185,6 +185,52 @@ pub struct Matmul { rhs: Tensor, } +impl Matmul { + pub fn compute_c_shape(a: &Tensor, b: &Tensor) -> anyhow::Result { + let (mut ashape, mut bshape) = (a.shape().clone(), b.shape().clone()); + + let insert_one_if = |shape: &mut Shape, index: usize, condition: bool| { + if condition { + shape.insert(index, 1); + } + }; + let arank = ashape.rank(); + let brank = bshape.rank(); + insert_one_if(&mut ashape, 0, arank < 2); + insert_one_if(&mut bshape, 1, brank < 2); + + let equalize_rank = |shape: &mut Shape, target_rank: usize| { + while shape.rank() < target_rank { + shape.insert(0, 1); + } + }; + equalize_rank(&mut ashape, bshape.rank()); + equalize_rank(&mut bshape, ashape.rank()); + + let arank = ashape.rank(); + let brank = bshape.rank(); + let (a_prefix, b_prefix) = (&ashape[..arank - 2], &bshape[..brank - 2]); + let c_broadcasted_prefix = + Shape::multi_broadcast(&[a_prefix.into(), b_prefix.into()]).unwrap(); + + let (m, ka) = (ashape[arank - 2], ashape[arank - 1]); + let (kb, n) = (bshape[brank - 2], bshape[brank - 1]); + if ka != kb { + anyhow::bail!("Matmul broadcasting: a: {:?} b: {:?}", ashape, bshape); + } + + let mut c_shape_final = c_broadcasted_prefix; + if ashape.rank() >= 2 { + c_shape_final.push(m); + } + if bshape.rank() >= 2 { + c_shape_final.push(n); + } + + Ok(c_shape_final) + } +} + #[allow(clippy::too_many_arguments)] #[derive(Debug, Clone, ShaderType)] pub struct MatmulMeta { @@ -239,11 +285,10 @@ impl Operation for Matmul { } fn infer_output(&self, srcs: &[&Tensor]) -> Result { - let (_a, _b) = (srcs[0], srcs[1]); - //let c_shape = Matmul::compute_output_shape(a.clone(), b.clone()).unwrap(); - - //TODO: THIS IS WRONG 🚨 - Ok(srcs[0].view().clone()) + let (a, b) = (srcs[0], srcs[1]); + let c_shape = Matmul::compute_c_shape(a, b).unwrap(); + let c_strides = Strides::from(&c_shape); + Ok(StorageView::new(c_shape, a.dt(), c_strides)) } fn check_invariants(srcs: &[&Tensor]) -> Result<(), OperationError> { @@ -304,6 +349,8 @@ impl Operation for Matmul { #[cfg(test)] mod tests { + use test_strategy::{proptest, Arbitrary}; + use crate::test_util::run_py_prg; use crate::{shape, Device, DeviceRequest, Quantization, Quantizer}; @@ -325,12 +372,33 @@ def matmul(a, b): run_py_prg(prg.to_string(), &[a, b]) } - #[test] - fn test_sgemm() -> anyhow::Result<()> { - let (a, b) = matmul_harness()?; + #[derive(Arbitrary, Debug)] + struct SGEMMProblem { + #[strategy(1..=4usize)] + B: usize, + #[strategy(1..=1024usize)] + M: usize, + #[strategy(1..=1024usize)] + K: usize, + #[strategy(1..=1024usize)] + N: usize, + } + + #[proptest(cases = 8)] + fn test_sgemm(prob: SGEMMProblem) { + let device = Device::request_device(DeviceRequest::GPU).unwrap(); + let SGEMMProblem { B, M, K, N } = prob; + println!("Running sgemm: B={} M={} K={} N={}", B, M, K, N); + run_matmul_trial(&device, prob).unwrap(); + } + + fn run_matmul_trial(device: &Device, prob: SGEMMProblem) -> anyhow::Result<()> { + let cpu_device = Device::request_device(DeviceRequest::CPU)?; + let SGEMMProblem { B, M, K, N } = prob; + let a = Tensor::randn::(shape![B, M, K], cpu_device.clone()); + let b = Tensor::randn::(shape![B, K, N], cpu_device.clone()); let ground = ground_truth(&a, &b)?; - let device = Device::request_device(DeviceRequest::GPU)?; let a_gpu = a.to(&device)?; let b_gpu = b.to(&device)?; let c_gpu = a_gpu.matmul(&b_gpu)?; diff --git a/crates/ratchet-core/src/shape.rs b/crates/ratchet-core/src/shape.rs index 9734c950..665ec912 100644 --- a/crates/ratchet-core/src/shape.rs +++ b/crates/ratchet-core/src/shape.rs @@ -1,4 +1,4 @@ -use crate::RVec; +use crate::{shape, RVec}; use encase::impl_wrapper; use std::ops::RangeTo; @@ -44,6 +44,10 @@ impl Shape { self.len() } + pub fn push(&mut self, dim: usize) { + self.0.push(dim); + } + #[inline] pub fn left_pad_to(&mut self, scalar: usize, rank: usize) { while self.0.len() < rank { @@ -68,6 +72,26 @@ impl Shape { pub fn slice(&self, range: std::ops::Range) -> Self { Shape(self.0[range].to_vec().into()) } + + pub fn multi_broadcast(shapes: &[Shape]) -> Option { + let max_rank = shapes.iter().map(|shape| shape.rank()).max()?; + let mut shape: Shape = shape![]; + for i in 0..max_rank { + let mut current_dim_size = 1; + for shape in shapes { + let len = shape.rank(); + let dim = if i < len { &shape[len - i - 1] } else { &1 }; + if dim != &1 { + if current_dim_size != 1 && dim != ¤t_dim_size { + return None; + } + current_dim_size = *dim; + } + } + shape.0.insert(0, current_dim_size) + } + Some(shape) + } } impl std::fmt::Debug for Shape { @@ -119,3 +143,9 @@ impl From> for Shape { Self(shape.into_iter().map(|x| x as usize).collect()) } } + +impl From<&[usize]> for Shape { + fn from(slice: &[usize]) -> Self { + Shape(slice.into()) + } +} From 2306c296fb234a3f089bffaced6858b37b9611ba Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Sat, 27 Jan 2024 22:00:41 +0000 Subject: [PATCH 2/3] chore: weaker tol --- crates/ratchet-core/src/ops/matmul.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ratchet-core/src/ops/matmul.rs b/crates/ratchet-core/src/ops/matmul.rs index ab2981cd..69ac8b3f 100644 --- a/crates/ratchet-core/src/ops/matmul.rs +++ b/crates/ratchet-core/src/ops/matmul.rs @@ -405,7 +405,7 @@ def matmul(a, b): c_gpu.resolve()?; let d_gpu = c_gpu.to(&Device::CPU)?; - ground.all_close(&d_gpu, 1e-4, 1e-4)?; + ground.all_close(&d_gpu, 1e-3, 1e-3)?; Ok(()) } From 018f574fb30924861a5fa30b5cb31fcde5f29bf0 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Sat, 27 Jan 2024 22:10:56 +0000 Subject: [PATCH 3/3] chore: weaker tol --- crates/ratchet-core/src/kernels.rs | 66 ++++++--------------------- crates/ratchet-core/src/ops/matmul.rs | 8 ++-- 2 files changed, 17 insertions(+), 57 deletions(-) diff --git a/crates/ratchet-core/src/kernels.rs b/crates/ratchet-core/src/kernels.rs index dc08801b..01218085 100644 --- a/crates/ratchet-core/src/kernels.rs +++ b/crates/ratchet-core/src/kernels.rs @@ -1,57 +1,17 @@ // This file is generated by build.rs. Do not edit it manually. -use lazy_static::lazy_static; use std::collections::HashMap; +use lazy_static::lazy_static; lazy_static! { - pub static ref KERNELS: HashMap<&'static str, &'static str> = { - let mut m = HashMap::new(); - m.insert( - "qgemm_vec4", - include_str!( - r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/qgemm_vec4.wgsl" - ), - ); - m.insert( - "sgemm_scalar", - include_str!( - r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_scalar.wgsl" - ), - ); - m.insert( - "add_scalar", - include_str!( - r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/add_scalar.wgsl" - ), - ); - m.insert( - "sgemm_vec2", - include_str!( - r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_vec2.wgsl" - ), - ); - m.insert( - "softmax_vec2", - include_str!( - r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/softmax_vec2.wgsl" - ), - ); - m.insert( - "sgemm_vec4", - include_str!( - r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_vec4.wgsl" - ), - ); - m.insert( - "softmax_scalar", - include_str!( - r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/softmax_scalar.wgsl" - ), - ); - m.insert( - "softmax_vec4", - include_str!( - r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/softmax_vec4.wgsl" - ), - ); - m - }; +pub static ref KERNELS: HashMap<&'static str, &'static str> = { + let mut m = HashMap::new(); + m.insert("qgemm_vec4", include_str!(r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/qgemm_vec4.wgsl")); + m.insert("sgemm_scalar", include_str!(r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_scalar.wgsl")); + m.insert("add_scalar", include_str!(r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/add_scalar.wgsl")); + m.insert("sgemm_vec2", include_str!(r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_vec2.wgsl")); + m.insert("softmax_vec2", include_str!(r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/softmax_vec2.wgsl")); + m.insert("sgemm_vec4", include_str!(r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_vec4.wgsl")); + m.insert("softmax_scalar", include_str!(r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/softmax_scalar.wgsl")); + m.insert("softmax_vec4", include_str!(r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/softmax_vec4.wgsl")); + m +}; } diff --git a/crates/ratchet-core/src/ops/matmul.rs b/crates/ratchet-core/src/ops/matmul.rs index 69ac8b3f..32f3b807 100644 --- a/crates/ratchet-core/src/ops/matmul.rs +++ b/crates/ratchet-core/src/ops/matmul.rs @@ -376,11 +376,11 @@ def matmul(a, b): struct SGEMMProblem { #[strategy(1..=4usize)] B: usize, - #[strategy(1..=1024usize)] + #[strategy(1..=512usize)] M: usize, - #[strategy(1..=1024usize)] + #[strategy(1..=512usize)] K: usize, - #[strategy(1..=1024usize)] + #[strategy(1..=512usize)] N: usize, } @@ -405,7 +405,7 @@ def matmul(a, b): c_gpu.resolve()?; let d_gpu = c_gpu.to(&Device::CPU)?; - ground.all_close(&d_gpu, 1e-3, 1e-3)?; + ground.all_close(&d_gpu, 1e-4, 1e-4)?; Ok(()) }