chore: poor generation util
FL33TW00D committed Jan 28, 2024
2 parents 3d3bc8e + f0367c6 commit f155c7b
Showing 2 changed files with 109 additions and 11 deletions.
88 changes: 78 additions & 10 deletions crates/ratchet-core/src/ops/matmul.rs
@@ -6,7 +6,7 @@ use encase::ShaderType;
use crate::{
gpu::{BindGroupLayoutDescriptor, WorkgroupCount},
rvec, wgc, DType, Enforcer, InvariantError, KernelElement, OpMetadata, Operation,
OperationError, RVec, Shape, StorageView, Tensor,
OperationError, RVec, Shape, StorageView, Strides, Tensor,
};

// Defines a matrix multiplication operation.
@@ -185,6 +185,52 @@ pub struct Matmul {
rhs: Tensor,
}

impl Matmul {
pub fn compute_c_shape(a: &Tensor, b: &Tensor) -> anyhow::Result<Shape> {
let (mut ashape, mut bshape) = (a.shape().clone(), b.shape().clone());

let insert_one_if = |shape: &mut Shape, index: usize, condition: bool| {
if condition {
shape.insert(index, 1);
}
};
let arank = ashape.rank();
let brank = bshape.rank();
insert_one_if(&mut ashape, 0, arank < 2);
insert_one_if(&mut bshape, 1, brank < 2);

let equalize_rank = |shape: &mut Shape, target_rank: usize| {
while shape.rank() < target_rank {
shape.insert(0, 1);
}
};
equalize_rank(&mut ashape, bshape.rank());
equalize_rank(&mut bshape, ashape.rank());

let arank = ashape.rank();
let brank = bshape.rank();
let (a_prefix, b_prefix) = (&ashape[..arank - 2], &bshape[..brank - 2]);
let c_broadcasted_prefix =
Shape::multi_broadcast(&[a_prefix.into(), b_prefix.into()]).unwrap();

let (m, ka) = (ashape[arank - 2], ashape[arank - 1]);
let (kb, n) = (bshape[brank - 2], bshape[brank - 1]);
if ka != kb {
anyhow::bail!("Matmul broadcasting: a: {:?} b: {:?}", ashape, bshape);
}

let mut c_shape_final = c_broadcasted_prefix;
if ashape.rank() >= 2 {
c_shape_final.push(m);
}
if bshape.rank() >= 2 {
c_shape_final.push(n);
}

Ok(c_shape_final)
}
}

#[allow(clippy::too_many_arguments)]
#[derive(Debug, Clone, ShaderType)]
pub struct MatmulMeta {
@@ -239,11 +285,10 @@ impl Operation for Matmul {
}

fn infer_output(&self, srcs: &[&Tensor]) -> Result<StorageView, OperationError> {
let (_a, _b) = (srcs[0], srcs[1]);
//let c_shape = Matmul::compute_output_shape(a.clone(), b.clone()).unwrap();

//TODO: THIS IS WRONG 🚨
Ok(srcs[0].view().clone())
let (a, b) = (srcs[0], srcs[1]);
let c_shape = Matmul::compute_c_shape(a, b).unwrap();
let c_strides = Strides::from(&c_shape);
Ok(StorageView::new(c_shape, a.dt(), c_strides))
}

fn check_invariants(srcs: &[&Tensor]) -> Result<(), OperationError> {
@@ -304,6 +349,8 @@ impl Operation for Matmul {

#[cfg(test)]
mod tests {
use test_strategy::{proptest, Arbitrary};

use crate::test_util::run_py_prg;

use crate::{shape, Device, DeviceRequest, Quantization, Quantizer};
@@ -325,12 +372,33 @@ def matmul(a, b):
run_py_prg(prg.to_string(), &[a, b])
}

#[test]
fn test_sgemm() -> anyhow::Result<()> {
let (a, b) = matmul_harness()?;
#[derive(Arbitrary, Debug)]
struct SGEMMProblem {
#[strategy(1..=4usize)]
B: usize,
#[strategy(1..=512usize)]
M: usize,
#[strategy(1..=512usize)]
K: usize,
#[strategy(1..=512usize)]
N: usize,
}

#[proptest(cases = 8)]
fn test_sgemm(prob: SGEMMProblem) {
let device = Device::request_device(DeviceRequest::GPU).unwrap();
let SGEMMProblem { B, M, K, N } = prob;
println!("Running sgemm: B={} M={} K={} N={}", B, M, K, N);
run_matmul_trial(&device, prob).unwrap();
}

fn run_matmul_trial(device: &Device, prob: SGEMMProblem) -> anyhow::Result<()> {
let cpu_device = Device::request_device(DeviceRequest::CPU)?;
let SGEMMProblem { B, M, K, N } = prob;
let a = Tensor::randn::<f32>(shape![B, M, K], cpu_device.clone());
let b = Tensor::randn::<f32>(shape![B, K, N], cpu_device.clone());
let ground = ground_truth(&a, &b)?;

let device = Device::request_device(DeviceRequest::GPU)?;
let a_gpu = a.to(&device)?;
let b_gpu = b.to(&device)?;
let c_gpu = a_gpu.matmul(&b_gpu)?;
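For readers skimming the diff: the new infer_output path derives the output storage view from compute_c_shape instead of reusing the LHS view. Below is a minimal standalone sketch of that shape rule — it is not part of the commit, it uses plain Vec<usize> rather than the crate's Shape type, and the function name is illustrative only.

// Standalone sketch: output shape of [.., m, k] x [.., k, n] with NumPy-style
// broadcasting of the batch prefixes. Illustrative only; not the crate's API.
fn matmul_output_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    if a.len() < 2 || b.len() < 2 {
        return None; // compute_c_shape instead pads 1-D operands up to rank 2
    }
    let (m, ka) = (a[a.len() - 2], a[a.len() - 1]);
    let (kb, n) = (b[b.len() - 2], b[b.len() - 1]);
    if ka != kb {
        return None; // inner dimensions must agree
    }
    let (pa, pb) = (&a[..a.len() - 2], &b[..b.len() - 2]);
    let rank = pa.len().max(pb.len());
    let mut out = Vec::with_capacity(rank + 2);
    // Walk the batch prefixes right-to-left; a missing dim behaves like 1.
    for i in (0..rank).rev() {
        let da = if i < pa.len() { pa[pa.len() - 1 - i] } else { 1 };
        let db = if i < pb.len() { pb[pb.len() - 1 - i] } else { 1 };
        match (da, db) {
            (x, y) if x == y => out.push(x),
            (1, y) => out.push(y),
            (x, 1) => out.push(x),
            _ => return None, // incompatible batch dimensions
        }
    }
    out.push(m);
    out.push(n);
    Some(out)
}

fn main() {
    // e.g. [4, 1, 128, 64] x [2, 64, 256] -> [4, 2, 128, 256]
    assert_eq!(
        matmul_output_shape(&[4, 1, 128, 64], &[2, 64, 256]),
        Some(vec![4, 2, 128, 256])
    );
}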
32 changes: 31 additions & 1 deletion crates/ratchet-core/src/shape.rs
@@ -1,4 +1,4 @@
use crate::RVec;
use crate::{shape, RVec};
use encase::impl_wrapper;
use std::ops::RangeTo;

@@ -44,6 +44,10 @@ impl Shape {
self.len()
}

pub fn push(&mut self, dim: usize) {
self.0.push(dim);
}

#[inline]
pub fn left_pad_to(&mut self, scalar: usize, rank: usize) {
while self.0.len() < rank {
@@ -68,6 +72,26 @@ impl Shape {
pub fn slice(&self, range: std::ops::Range<usize>) -> Self {
Shape(self.0[range].to_vec().into())
}

pub fn multi_broadcast(shapes: &[Shape]) -> Option<Shape> {
let max_rank = shapes.iter().map(|shape| shape.rank()).max()?;
let mut shape: Shape = shape![];
for i in 0..max_rank {
let mut current_dim_size = 1;
for shape in shapes {
let len = shape.rank();
let dim = if i < len { &shape[len - i - 1] } else { &1 };
if dim != &1 {
if current_dim_size != 1 && dim != &current_dim_size {
return None;
}
current_dim_size = *dim;
}
}
shape.0.insert(0, current_dim_size)
}
Some(shape)
}
}

impl std::fmt::Debug for Shape {
@@ -119,3 +143,9 @@ impl From<Vec<u32>> for Shape {
Self(shape.into_iter().map(|x| x as usize).collect())
}
}

impl From<&[usize]> for Shape {
fn from(slice: &[usize]) -> Self {
Shape(slice.into())
}
}
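
The rule multi_broadcast implements is the usual NumPy one: dimensions are aligned from the right, a dimension of 1 (or a missing dimension) stretches to match, and any other mismatch fails. A hypothetical usage sketch follows, assuming the Shape API exactly as it appears in this diff (the shape! macro and multi_broadcast); it is not part of the commit.

// Hypothetical usage of the new helper; assumes ratchet-core's shape! macro
// and Shape::multi_broadcast as shown above.
fn broadcast_examples() {
    let c = Shape::multi_broadcast(&[shape![8, 1, 256], shape![4, 256]]);
    // c: Some with dims [8, 4, 256] — the 1 and the missing leading dim both stretch.
    assert!(c.is_some());

    let d = Shape::multi_broadcast(&[shape![3, 2], shape![4, 2]]);
    // d: None — 3 vs 4 in the same position cannot broadcast.
    assert!(d.is_none());
}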
