diff --git a/crates/ratchet-core/src/cpu/reindex.rs b/crates/ratchet-core/src/cpu/reindex.rs index 9cb79d3f..efa0f912 100644 --- a/crates/ratchet-core/src/cpu/reindex.rs +++ b/crates/ratchet-core/src/cpu/reindex.rs @@ -1,20 +1,133 @@ use super::utils::cpu_store_result; use crate::{ - Broadcast, CPUOperation, DType, OperationError, Reindex, Shape, Slice, Strides, Tensor, - TensorDType, + Broadcast, CPUOperation, DType, OperationError, Permute, Reindex, Shape, Slice, Strides, + Tensor, TensorDType, }; use half::{bf16, f16}; impl CPUOperation for Reindex { fn apply_cpu(&self, dst: Tensor) -> Result { match self { + Reindex::Permute(p) => p.apply_cpu(dst), Reindex::Slice(s) => s.apply_cpu(dst), Reindex::Broadcast(b) => b.apply_cpu(dst), + } + } +} + +impl CPUOperation for Permute { + fn apply_cpu(&self, dst: Tensor) -> Result { + match dst.dt() { + DType::F32 => apply_permute::(self, dst), + DType::BF16 => apply_permute::(self, dst), + DType::F16 => apply_permute::(self, dst), + DType::I32 => apply_permute::(self, dst), + DType::U32 => apply_permute::(self, dst), _ => todo!(), } } } +fn apply_permute(p: &Permute, dst: Tensor) -> Result { + let perm: [usize; 4] = p.promote().try_into().unwrap(); + let Permute { src, dims: _ } = p; + let result = permute(&src.to_vec::()?, src.shape(), dst.shape(), perm); + cpu_store_result(&dst, &result); + Ok(dst) +} + +// TODO: Optimize. +// This generic implementation is almost a direct copy from the gpu impl, +// and can definitely be way more performant. +fn permute( + src: &[T], + src_shape: &Shape, + dst_shape: &Shape, + perm: [usize; 4], +) -> Vec { + let mut result = vec![T::zero(); src_shape.numel()]; + + // We now know that these will always be len 4, same as gpu impl. + let src_shape = &Shape::promote(src_shape.clone(), 4); + let dst_shape = &Shape::promote(dst_shape.clone(), 4); + + let src_strides = &Strides::from(src_shape); + let dst_strides = &Strides::from(dst_shape); + + let src_shape: [usize; 4] = src_shape.try_into().unwrap(); + let src_strides: [usize; 4] = src_strides.try_into().unwrap(); + let dst_strides: [usize; 4] = dst_strides.try_into().unwrap(); + + for i in 0..result.len() { + let dst_index = offset_to_ndindex(i, dst_strides); + let mut src_index = [0; 4]; + src_index[perm[0]] = dst_index[0]; + src_index[perm[1]] = dst_index[1]; + src_index[perm[2]] = dst_index[2]; + src_index[perm[3]] = dst_index[3]; + let src_offset = nd_index_to_offset(src_index, src_strides); + result[i] = src[src_offset] + } + result +} + +impl CPUOperation for Slice { + fn apply_cpu(&self, dst: Tensor) -> Result { + match dst.dt() { + DType::F32 => apply_slice::(self, dst), + DType::BF16 => apply_slice::(self, dst), + DType::F16 => apply_slice::(self, dst), + DType::I32 => apply_slice::(self, dst), + DType::U32 => apply_slice::(self, dst), + _ => todo!(), + } + } +} + +fn apply_slice(s: &Slice, dst: Tensor) -> Result { + let (start, stop): (Vec<_>, Vec<_>) = s.indices().iter().map(|r| (r.start, r.end)).unzip(); + let result = slice(&s.src.to_vec::()?, s.src.strides(), &start, &stop); + + cpu_store_result(&dst, &result); + Ok(dst) +} + +pub(crate) fn slice( + src: &[T], + src_strides: &Strides, + start: &[usize], + stop: &[usize], +) -> Vec { + assert!(start.len() == stop.len()); + assert!(start.len() == src_strides.rank()); + start.iter().zip(stop.iter()).for_each(|(s, t)| { + assert!(s < t); + }); + + let dst_shape: Vec = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect(); + let dst_numel: usize = dst_shape.iter().product(); + + let mut dst = vec![T::zero(); dst_numel]; + + let mut dst_dots = vec![]; + for d in 0..dst_shape.len() { + dst_dots.push(dst_shape[d + 1..].iter().product::().max(1)); + } + + for i in 0..dst.len() { + let mut src_index = 0; + let mut tmp = i; + for d in 0..dst_shape.len() { + let coord = tmp / dst_dots[d]; + tmp %= dst_dots[d]; + src_index += (coord + start[d]) * src_strides[d] as usize; + } + dst[i] = src[src_index]; + } + + dst +} + impl CPUOperation for Broadcast { fn apply_cpu(&self, dst: Tensor) -> Result { match dst.dt() { @@ -85,26 +198,6 @@ fn generic_broadcast( let src_strides: [usize; 4] = src_strides.try_into().unwrap(); let dst_strides: [usize; 4] = dst_strides.try_into().unwrap(); - fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] { - let mut indices = [0; 4]; - let mut remaining = offset; - - let idx = remaining / strides[0]; - indices[0] = idx; - remaining -= idx * strides[0]; - - let idx = remaining / strides[1]; - indices[1] = idx; - remaining -= idx * strides[1]; - - let idx = remaining / strides[2]; - indices[2] = idx; - remaining -= idx * strides[2]; - - indices[3] = remaining; - indices - } - fn select(a: [usize; 4], b: [usize; 4], t: [bool; 4]) -> [usize; 4] { let mut result = [0; 4]; result[0] = if t[0] { a[0] } else { b[0] }; @@ -114,13 +207,6 @@ fn generic_broadcast( result } - fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize { - ndindex[0] * strides[0] - + ndindex[1] * strides[1] - + ndindex[2] * strides[2] - + ndindex[3] * strides[3] - } - let shape_onedim_lookup: [bool; 4] = [ src_shape[0] != 1, src_shape[1] != 1, @@ -135,59 +221,31 @@ fn generic_broadcast( } } -impl CPUOperation for Slice { - fn apply_cpu(&self, dst: Tensor) -> Result { - match dst.dt() { - DType::F32 => apply_slice::(self, dst), - DType::BF16 => apply_slice::(self, dst), - DType::F16 => apply_slice::(self, dst), - DType::I32 => apply_slice::(self, dst), - DType::U32 => apply_slice::(self, dst), - _ => todo!(), - } - } -} - -fn apply_slice(s: &Slice, dst: Tensor) -> Result { - let (start, stop): (Vec<_>, Vec<_>) = s.indices().iter().map(|r| (r.start, r.end)).unzip(); - let result = slice(&s.src.to_vec::()?, s.src.strides(), &start, &stop); - - cpu_store_result(&dst, &result); - Ok(dst) -} - -pub(crate) fn slice( - src: &[T], - src_strides: &Strides, - start: &[usize], - stop: &[usize], -) -> Vec { - assert!(start.len() == stop.len()); - assert!(start.len() == src_strides.rank()); - start.iter().zip(stop.iter()).for_each(|(s, t)| { - assert!(s < t); - }); +#[inline] +fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] { + let mut indices = [0; 4]; + let mut remaining = offset; - let dst_shape: Vec = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect(); - let dst_numel: usize = dst_shape.iter().product(); + let idx = remaining / strides[0]; + indices[0] = idx; + remaining -= idx * strides[0]; - let mut dst = vec![T::zero(); dst_numel]; + let idx = remaining / strides[1]; + indices[1] = idx; + remaining -= idx * strides[1]; - let mut dst_dots = vec![]; - for d in 0..dst_shape.len() { - dst_dots.push(dst_shape[d + 1..].iter().product::().max(1)); - } + let idx = remaining / strides[2]; + indices[2] = idx; + remaining -= idx * strides[2]; - for i in 0..dst.len() { - let mut src_index = 0; - let mut tmp = i; - for d in 0..dst_shape.len() { - let coord = tmp / dst_dots[d]; - tmp %= dst_dots[d]; - src_index += (coord + start[d]) * src_strides[d] as usize; - } - dst[i] = src[src_index]; - } + indices[3] = remaining; + indices +} - dst +#[inline] +fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize { + ndindex[0] * strides[0] + + ndindex[1] * strides[1] + + ndindex[2] * strides[2] + + ndindex[3] * strides[3] } diff --git a/crates/ratchet-core/src/ops/reindex/permute.rs b/crates/ratchet-core/src/ops/reindex/permute.rs index b032eb8c..01f0aeee 100644 --- a/crates/ratchet-core/src/ops/reindex/permute.rs +++ b/crates/ratchet-core/src/ops/reindex/permute.rs @@ -111,9 +111,8 @@ def permute(a): run_py_prg(prg.to_string(), &[a], &[], a.dt()) } - fn run_reindex_trial(prob: PermuteProblem) -> anyhow::Result<()> { + fn run_reindex_trial(prob: PermuteProblem, device: Device) -> anyhow::Result<()> { let PermuteProblem { op } = prob; - let device = Device::request_device(DeviceRequest::GPU).unwrap(); let a = op.src.clone(); let a_gpu = a.to(&device)?; @@ -125,7 +124,14 @@ def permute(a): } #[proptest(cases = 16)] - fn test_permute(prob: PermuteProblem) { - run_reindex_trial(prob).unwrap(); + fn test_permute_gpu(prob: PermuteProblem) { + let device = Device::request_device(DeviceRequest::GPU).unwrap(); + run_reindex_trial(prob, device).unwrap(); + } + + #[proptest(cases = 16)] + fn test_permute_cpu(prob: PermuteProblem) { + let device = Device::request_device(DeviceRequest::CPU).unwrap(); + run_reindex_trial(prob, device).unwrap(); } }