From a899148ff736b5895679690dd3e5a913ce92e5ce Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:13:24 -0400 Subject: [PATCH 1/6] feature: TensorIterator for both strided and contiguous iter on cpu --- crates/ratchet-core/src/cpu/reindex.rs | 17 ++- crates/ratchet-core/src/cpu/utils.rs | 199 ++++++++++++++++++++++++- crates/ratchet-core/src/shape.rs | 7 + crates/ratchet-core/src/strides.rs | 41 ++++- 4 files changed, 260 insertions(+), 4 deletions(-) diff --git a/crates/ratchet-core/src/cpu/reindex.rs b/crates/ratchet-core/src/cpu/reindex.rs index d3953484..1a68feef 100644 --- a/crates/ratchet-core/src/cpu/reindex.rs +++ b/crates/ratchet-core/src/cpu/reindex.rs @@ -1,15 +1,30 @@ use super::utils::cpu_store_result; -use crate::{CPUOperation, DType, OperationError, Reindex, Slice, Strides, Tensor, TensorDType}; +use crate::{ + Broadcast, CPUOperation, DType, OperationError, Reindex, Slice, Strides, Tensor, TensorDType, +}; use half::{bf16, f16}; impl CPUOperation for Reindex { fn apply_cpu(&self, dst: Tensor) -> Result { match self { Reindex::Slice(s) => s.apply_cpu(dst), + Reindex::Broadcast(b) => b.apply_cpu(dst), _ => todo!(), } } } +impl CPUOperation for Broadcast { + fn apply_cpu(&self, dst: Tensor) -> Result { + match dst.dt() { + DType::F32 => apply_broadcast::(self, dst), + _ => todo!(), + } + } +} + +fn apply_broadcast(b: &Broadcast, dst: Tensor) -> Result { + Ok(dst) +} impl CPUOperation for Slice { fn apply_cpu(&self, dst: Tensor) -> Result { diff --git a/crates/ratchet-core/src/cpu/utils.rs b/crates/ratchet-core/src/cpu/utils.rs index e61facde..99ba5d5b 100644 --- a/crates/ratchet-core/src/cpu/utils.rs +++ b/crates/ratchet-core/src/cpu/utils.rs @@ -1,6 +1,201 @@ -use crate::{CPUBuffer, Storage, Tensor}; -use bytemuck::NoUninit; +use crate::{CPUBuffer, Shape, Storage, Strides, Tensor}; +use bytemuck::{Contiguous, NoUninit}; +use std::ops::Range; pub fn cpu_store_result(dst: &Tensor, data: &[T]) { dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape()))); } + +pub enum TensorIterator<'a> { + Contiguous(Range), + Strided(StridedIterator<'a>), +} + +impl<'a> TensorIterator<'a> { + fn new(shape: &'a Shape, strides: &'a Strides, offset: usize) -> Self { + let mut block_size: usize = 1; + let mut contiguous_dims: usize = 0; + for (&stride, &dim) in strides.iter().zip(shape.iter()).rev() { + if stride as usize != block_size { + break; + } + block_size *= dim as usize; + contiguous_dims += 1; + } + let index_dims = shape.rank() - contiguous_dims; + if index_dims == 0 { + Self::Contiguous(offset..block_size) + } else { + Self::Strided(StridedIterator::new(&shape, &strides, offset, block_size)) + } + } +} + +impl<'a> Iterator for TensorIterator<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + match self { + Self::Contiguous(r) => r.next(), + Self::Strided(s) => s.next(), + } + } +} + +pub struct StridedIterator<'a> { + shape: &'a [usize], + strides: &'a [isize], + next_index: Option, + multi_index: Vec, + block_size: usize, + block_step: usize, +} + +impl<'a> StridedIterator<'a> { + pub fn new( + shape: &'a [usize], + strides: &'a [isize], + start_offset: usize, + block_len: usize, + ) -> Self { + Self { + shape, + strides, + next_index: if shape.iter().product::() == 0 { + None + } else { + Some(start_offset) + }, + multi_index: vec![0; shape.len()], + block_size: block_len, + block_step: 0, + } + } +} + +impl<'a> Iterator for StridedIterator<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + let storage_index = match self.next_index { + None => return None, + Some(storage_index) => storage_index, + }; + + if self.block_size > 1 { + if self.block_step < self.block_size { + self.block_step += 1; + return Some(storage_index + self.block_step - 1); + } else { + self.block_step = 0; + } + } + + let mut updated = false; + let mut next_storage_index = storage_index; + for ((multi_i, max_i), stride_i) in self + .multi_index + .iter_mut() + .zip(self.shape.iter()) + .zip(self.strides.iter()) + .rev() + { + let next_i = *multi_i + 1; + if next_i < *max_i { + *multi_i = next_i; + updated = true; + next_storage_index += *stride_i as usize; + break; + } else { + next_storage_index -= *multi_i * *stride_i as usize; + *multi_i = 0 + } + } + self.next_index = if updated { + Some(next_storage_index) + } else { + None + }; + Some(storage_index) + } +} + +impl<'a> From<(&'a Shape, &'a Strides)> for StridedIterator<'a> { + fn from((shape, strides): (&'a Shape, &'a Strides)) -> Self { + StridedIterator::new(shape.as_slice(), strides.as_slice(), 0, 1) + } +} + +impl<'a> From<(&'a Shape, &'a Strides, usize)> for StridedIterator<'a> { + fn from((shape, strides, offset): (&'a Shape, &'a Strides, usize)) -> Self { + StridedIterator::new(shape.as_slice(), strides.as_slice(), offset, 1) + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + use test_strategy::{proptest, Arbitrary}; + + use crate::{shape, Shape, Strides}; + + use super::{StridedIterator, TensorIterator}; + + #[derive(Debug)] + struct IterProblem { + shape: Shape, + offset: usize, + } + + impl Arbitrary for IterProblem { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { + let ranges = vec![1..=2, 1..=4, 1..=256, 1..=256]; + Shape::arbitrary_with(ranges) + .prop_flat_map(|shape| (Just(shape.clone()), 0..shape.numel())) + .prop_map(|(shape, offset)| IterProblem { shape, offset }) + .boxed() + } + } + + #[proptest(cases = 16)] + fn test_tensor_iter_contiguous(prob: IterProblem) { + let shape = prob.shape; + let strides = Strides::from(&shape); + let offset = prob.offset; + + let iter = TensorIterator::new(&shape, &strides, offset); + assert!(matches!(iter, TensorIterator::Contiguous(_))); + + match iter { + TensorIterator::Contiguous(r) => assert_eq!(r, offset..shape.numel()), + _ => unreachable!(), + } + } + + #[proptest(cases = 16)] + fn test_tensor_iter_strided(prob: IterProblem) { + let mut shape = prob.shape; + let mut strides = Strides::from(&shape); + strides.transpose(); + shape.transpose(); + let offset = prob.offset; + + let iter = TensorIterator::new(&shape, &strides, offset); + assert!(matches!(iter, TensorIterator::Strided(_))); + + match iter { + TensorIterator::Strided(strided_iter) => { + let mut indices: Vec = strided_iter.collect(); + assert_eq!(indices.len(), shape.numel()); + let contiguous: Vec = (offset..shape.numel() + offset).collect(); + assert_ne!(indices, contiguous); + indices.sort(); + assert_eq!(indices, contiguous); + } + _ => unreachable!(), + } + } +} diff --git a/crates/ratchet-core/src/shape.rs b/crates/ratchet-core/src/shape.rs index c349cbb5..895e762e 100644 --- a/crates/ratchet-core/src/shape.rs +++ b/crates/ratchet-core/src/shape.rs @@ -144,6 +144,13 @@ impl Shape { } } +impl core::ops::Deref for Shape { + type Target = [usize]; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl std::fmt::Debug for Shape { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut shape = format!("[{}", self.0.first().unwrap_or(&0)); diff --git a/crates/ratchet-core/src/strides.rs b/crates/ratchet-core/src/strides.rs index 11920ae0..2941a8fa 100644 --- a/crates/ratchet-core/src/strides.rs +++ b/crates/ratchet-core/src/strides.rs @@ -1,4 +1,4 @@ -use std::ops::Index; +use std::ops::{Index, IndexMut, RangeFrom, RangeTo}; use std::slice::Iter; use crate::{rvec, RVec, Shape}; @@ -29,6 +29,10 @@ impl Strides { pub fn rank(&self) -> usize { self.0.len() } + + pub fn as_slice(&self) -> &[isize] { + &self.0 + } } impl std::fmt::Debug for Strides { @@ -41,6 +45,13 @@ impl std::fmt::Debug for Strides { } } +impl core::ops::Deref for Strides { + type Target = [isize]; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl Index for Strides { type Output = isize; @@ -49,6 +60,28 @@ impl Index for Strides { } } +impl IndexMut for Strides { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index] + } +} + +impl Index> for Strides { + type Output = [isize]; + + fn index(&self, index: RangeFrom) -> &Self::Output { + &self.0[index] + } +} + +impl Index> for Strides { + type Output = [isize]; + + fn index(&self, index: RangeTo) -> &Self::Output { + &self.0[index] + } +} + impl From<&Shape> for Strides { fn from(shape: &Shape) -> Self { let mut strides = rvec![]; @@ -68,6 +101,12 @@ impl From> for Strides { } } +impl From<&[isize]> for Strides { + fn from(strides: &[isize]) -> Self { + Self(strides.into()) + } +} + impl From<&Strides> for [u32; 3] { fn from(strides: &Strides) -> Self { assert!(strides.0.len() <= 3); From 0e0a3014b7d18680db974217df4e72e4ffa39d79 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sun, 10 Nov 2024 21:07:08 -0400 Subject: [PATCH 2/6] chore: bonus strided iter test --- crates/ratchet-core/src/cpu/utils.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/crates/ratchet-core/src/cpu/utils.rs b/crates/ratchet-core/src/cpu/utils.rs index 99ba5d5b..f1f565ff 100644 --- a/crates/ratchet-core/src/cpu/utils.rs +++ b/crates/ratchet-core/src/cpu/utils.rs @@ -6,13 +6,14 @@ pub fn cpu_store_result(dst: &Tensor, data: &[T]) { dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape()))); } +#[derive(Clone)] pub enum TensorIterator<'a> { Contiguous(Range), Strided(StridedIterator<'a>), } impl<'a> TensorIterator<'a> { - fn new(shape: &'a Shape, strides: &'a Strides, offset: usize) -> Self { + pub fn new(shape: &'a Shape, strides: &'a Strides, offset: usize) -> Self { let mut block_size: usize = 1; let mut contiguous_dims: usize = 0; for (&stride, &dim) in strides.iter().zip(shape.iter()).rev() { @@ -42,6 +43,7 @@ impl<'a> Iterator for TensorIterator<'a> { } } +#[derive(Clone)] pub struct StridedIterator<'a> { shape: &'a [usize], strides: &'a [isize], @@ -198,4 +200,20 @@ mod tests { _ => unreachable!(), } } + + #[test] + fn test_tensor_iter_strided_sanity() { + let mut shape = shape!(2, 4, 3); + let mut strides = Strides::from(&shape); + strides.transpose(); + shape.transpose(); + let offset = 2; + + let iter = TensorIterator::new(&shape, &strides, offset); + let actual: Vec = iter.collect(); + let expected = vec![ + 2, 5, 8, 11, 3, 6, 9, 12, 4, 7, 10, 13, 14, 17, 20, 23, 15, 18, 21, 24, 16, 19, 22, 25, + ]; + assert_eq!(actual, expected); + } } From 4199ca47aef1533578dced59702a3b7fbe21d51a Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 12 Nov 2024 18:25:35 -0400 Subject: [PATCH 3/6] feature: cpu broadcast works! --- crates/ratchet-core/src/cpu/reindex.rs | 132 +++++++++++++++++- .../ratchet-core/src/ops/reindex/broadcast.rs | 19 ++- crates/ratchet-core/src/ops/reindex/slice.rs | 2 +- 3 files changed, 141 insertions(+), 12 deletions(-) diff --git a/crates/ratchet-core/src/cpu/reindex.rs b/crates/ratchet-core/src/cpu/reindex.rs index 1a68feef..332f67e0 100644 --- a/crates/ratchet-core/src/cpu/reindex.rs +++ b/crates/ratchet-core/src/cpu/reindex.rs @@ -1,6 +1,10 @@ -use super::utils::cpu_store_result; +use super::utils::{ + cpu_store_result, TensorIterator, + TensorIterator::{Contiguous, Strided}, +}; use crate::{ - Broadcast, CPUOperation, DType, OperationError, Reindex, Slice, Strides, Tensor, TensorDType, + Broadcast, CPUOperation, DType, OperationError, Reindex, Shape, Slice, Strides, Tensor, + TensorDType, }; use half::{bf16, f16}; @@ -17,15 +21,128 @@ impl CPUOperation for Broadcast { fn apply_cpu(&self, dst: Tensor) -> Result { match dst.dt() { DType::F32 => apply_broadcast::(self, dst), + DType::BF16 => apply_broadcast::(self, dst), + DType::F16 => apply_broadcast::(self, dst), + DType::I32 => apply_broadcast::(self, dst), + DType::U32 => apply_broadcast::(self, dst), _ => todo!(), } } } fn apply_broadcast(b: &Broadcast, dst: Tensor) -> Result { + let result = broadcast(&b.src.to_vec::()?, b.src.shape(), b.to()); + cpu_store_result(&dst, &result); Ok(dst) } +fn get_contiguous_offsets( + shape: &Shape, + strides: &Strides, +) -> Option<(usize, usize, usize, usize)> { + let mut left_broadcast = 1; + let mut right_broadcast = 1; + let dims = shape.to_vec(); + let strides = strides.to_vec(); + let mut start_cont = 0; + let mut end_cont = dims.len(); + for (&s, &d) in strides.iter().zip(dims.iter()) { + if s != 0 { + break; + } + start_cont += 1; + left_broadcast *= d; + } + if start_cont == dims.len() { + return Some((0, 1, left_broadcast, 1)); + } + for (&s, &d) in strides.iter().zip(dims.iter()).rev() { + if s != 0 { + break; + } + end_cont -= 1; + right_broadcast *= d; + } + // Check that the inner dims are contiguous + let strides = &strides[start_cont..end_cont]; + let dims = &dims[start_cont..end_cont]; + let mut len = 1; + for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() { + if stride as usize != len { + return None; + } + len *= dim; + } + + Some((0, len, left_broadcast, right_broadcast)) +} + +fn offset_to_ndindex(offset: usize, strides: &[usize]) -> Vec { + let mut indices = vec![0; strides.len()]; + let mut remaining = offset; + + for i in 0..strides.len() - 1 { + let stride = strides[i]; + let idx = remaining / stride; + indices[i] = idx; + remaining -= idx * stride; + } + indices[strides.len() - 1] = remaining; + indices +} + +fn nd_index_to_offset(ndindex: &[usize], strides: &[usize]) -> usize { + ndindex.iter().zip(strides.iter()).map(|(x, y)| x * y).sum() +} + +pub(crate) fn broadcast(src: &[T], src_shape: &Shape, dst_shape: &Shape) -> Vec { + let src_strides = Strides::from(src_shape); + let dst_strides = Strides::from(dst_shape); + let mut result = vec![T::zero(); dst_shape.numel()]; + + let dst_shape = dst_shape.to_vec(); + + let src_strides: Vec = src_strides.as_slice().iter().map(|x| *x as usize).collect(); + let dst_strides: Vec = dst_strides.as_slice().iter().map(|x| *x as usize).collect(); + + if src_shape.is_scalar() { + // Life is simple + result.fill(src[0]); + } else if src_shape.is_vector() { + // If from is a vector and the first dimension is the broadcasting dimension + if src_shape[0] > 1 && src_shape[0] == dst_shape[0] { + let chunk_size = result.len() / src_shape.numel(); + + (0..result.len()) + .step_by(chunk_size) + .enumerate() + .for_each(|(i, chunk)| { + result[chunk..chunk + chunk_size].fill(src[i]); + }); + } else { + for i in 0..result.len() { + let dst_index = offset_to_ndindex(i, &dst_strides); + let src_index: Vec = (0..src_shape.len()) + .map(|x| if src_shape[x] == 1 { 0 } else { dst_index[x] }) + .collect(); + let src_offset = nd_index_to_offset(&src_index, &src_strides); + result[i] = src[src_offset] + } + } + } else { + for i in 0..result.len() { + let dst_index = offset_to_ndindex(i, &dst_strides); + let src_index: Vec = (0..src_shape.len()) + .map(|x| if src_shape[x] == 1 { 0 } else { dst_index[x] }) + .collect(); + let src_offset = nd_index_to_offset(&src_index, &src_strides); + result[i] = src[src_offset] + } + } + + result +} + impl CPUOperation for Slice { fn apply_cpu(&self, dst: Tensor) -> Result { match dst.dt() { @@ -64,12 +181,17 @@ pub(crate) fn slice( let mut dst = vec![T::zero(); dst_numel]; - for i in 0..dst_numel { + let mut dst_dots = vec![]; + for d in 0..dst_shape.len() { + dst_dots.push(dst_shape[d + 1..].iter().product::().max(1)); + } + + for i in 0..dst.len() { let mut src_index = 0; let mut tmp = i; for d in 0..dst_shape.len() { - let coord = tmp / dst_shape[d + 1..].iter().product::().max(1); - tmp %= dst_shape[d + 1..].iter().product::().max(1); + let coord = tmp / dst_dots[d]; + tmp %= dst_dots[d]; src_index += (coord + start[d]) * src_strides[d] as usize; } dst[i] = src[src_index]; diff --git a/crates/ratchet-core/src/ops/reindex/broadcast.rs b/crates/ratchet-core/src/ops/reindex/broadcast.rs index 0055753f..78ce145a 100644 --- a/crates/ratchet-core/src/ops/reindex/broadcast.rs +++ b/crates/ratchet-core/src/ops/reindex/broadcast.rs @@ -64,7 +64,7 @@ impl Operation for Broadcast { "Broadcast" } - //For rules, see https://numpy.org/doc/stable/user/basics.broadcasting.html + // For rules, see https://numpy.org/doc/stable/user/basics.broadcasting.html fn compute_view(&self) -> Result { let src_shape = self.src.shape(); @@ -140,11 +140,10 @@ def slice(a): run_py_prg(prg.to_string(), &[a], &[], a.dt()) } - fn run_reindex_trial(prob: BroadcastProblem) -> anyhow::Result<()> { + fn run_reindex_trial(prob: BroadcastProblem, device: Device) -> anyhow::Result<()> { println!("\n\nBroadcast problem: {:?}", prob); let BroadcastProblem { op } = prob; let a = op.src.clone(); - let device = Device::request_device(DeviceRequest::GPU).unwrap(); let a_gpu = a.to(&device)?; let ground = ground_truth(&a, &op.to.as_torch())?; @@ -155,18 +154,26 @@ def slice(a): } #[proptest(cases = 16)] - fn test_broadcast(prob: BroadcastProblem) { - run_reindex_trial(prob).unwrap(); + fn test_broadcast_gpu(prob: BroadcastProblem) { + let device = Device::request_device(DeviceRequest::GPU).unwrap(); + run_reindex_trial(prob, device).unwrap(); + } + + #[proptest(cases = 16)] + fn test_broadcast_cpu(prob: BroadcastProblem) { + let device = Device::request_device(DeviceRequest::CPU).unwrap(); + run_reindex_trial(prob, device).unwrap(); } #[test] fn debug_broadcast() { + let device = Device::request_device(DeviceRequest::GPU).unwrap(); let prob = BroadcastProblem { op: Broadcast::new( Tensor::randn::(shape![1], Device::CPU), shape![4, 32, 128, 128], ), }; - run_reindex_trial(prob).unwrap(); + run_reindex_trial(prob, device).unwrap(); } } diff --git a/crates/ratchet-core/src/ops/reindex/slice.rs b/crates/ratchet-core/src/ops/reindex/slice.rs index fce43db6..508f2efd 100644 --- a/crates/ratchet-core/src/ops/reindex/slice.rs +++ b/crates/ratchet-core/src/ops/reindex/slice.rs @@ -178,7 +178,7 @@ def slice(a): run_reindex_trial(prob, device).unwrap(); } - #[proptest(cases = 16)] + #[proptest(cases = 1)] fn test_slice_cpu(prob: SliceProblem) { let _ = env_logger::builder().is_test(true).try_init(); let device = Device::request_device(DeviceRequest::CPU).unwrap(); From 72975b48a95bf25b553012f59ec4c4a128fd2f7a Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 13 Nov 2024 00:56:55 -0400 Subject: [PATCH 4/6] chore: remove unused contiguous offsets fn --- crates/ratchet-core/src/cpu/reindex.rs | 41 -------------------------- 1 file changed, 41 deletions(-) diff --git a/crates/ratchet-core/src/cpu/reindex.rs b/crates/ratchet-core/src/cpu/reindex.rs index 332f67e0..433dc8aa 100644 --- a/crates/ratchet-core/src/cpu/reindex.rs +++ b/crates/ratchet-core/src/cpu/reindex.rs @@ -36,47 +36,6 @@ fn apply_broadcast(b: &Broadcast, dst: Tensor) -> Result Option<(usize, usize, usize, usize)> { - let mut left_broadcast = 1; - let mut right_broadcast = 1; - let dims = shape.to_vec(); - let strides = strides.to_vec(); - let mut start_cont = 0; - let mut end_cont = dims.len(); - for (&s, &d) in strides.iter().zip(dims.iter()) { - if s != 0 { - break; - } - start_cont += 1; - left_broadcast *= d; - } - if start_cont == dims.len() { - return Some((0, 1, left_broadcast, 1)); - } - for (&s, &d) in strides.iter().zip(dims.iter()).rev() { - if s != 0 { - break; - } - end_cont -= 1; - right_broadcast *= d; - } - // Check that the inner dims are contiguous - let strides = &strides[start_cont..end_cont]; - let dims = &dims[start_cont..end_cont]; - let mut len = 1; - for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() { - if stride as usize != len { - return None; - } - len *= dim; - } - - Some((0, len, left_broadcast, right_broadcast)) -} - fn offset_to_ndindex(offset: usize, strides: &[usize]) -> Vec { let mut indices = vec![0; strides.len()]; let mut remaining = offset; From dad330014df2c1072878f5fe57f6141999ddbc63 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 13 Nov 2024 01:05:35 -0400 Subject: [PATCH 5/6] chore: tidy up broadcast --- crates/ratchet-core/src/cpu/reindex.rs | 74 ++++++++++++++------------ 1 file changed, 40 insertions(+), 34 deletions(-) diff --git a/crates/ratchet-core/src/cpu/reindex.rs b/crates/ratchet-core/src/cpu/reindex.rs index 433dc8aa..ced2baa4 100644 --- a/crates/ratchet-core/src/cpu/reindex.rs +++ b/crates/ratchet-core/src/cpu/reindex.rs @@ -36,24 +36,6 @@ fn apply_broadcast(b: &Broadcast, dst: Tensor) -> Result Vec { - let mut indices = vec![0; strides.len()]; - let mut remaining = offset; - - for i in 0..strides.len() - 1 { - let stride = strides[i]; - let idx = remaining / stride; - indices[i] = idx; - remaining -= idx * stride; - } - indices[strides.len() - 1] = remaining; - indices -} - -fn nd_index_to_offset(ndindex: &[usize], strides: &[usize]) -> usize { - ndindex.iter().zip(strides.iter()).map(|(x, y)| x * y).sum() -} - pub(crate) fn broadcast(src: &[T], src_shape: &Shape, dst_shape: &Shape) -> Vec { let src_strides = Strides::from(src_shape); let dst_strides = Strides::from(dst_shape); @@ -79,29 +61,53 @@ pub(crate) fn broadcast(src: &[T], src_shape: &Shape, dst_shape: result[chunk..chunk + chunk_size].fill(src[i]); }); } else { - for i in 0..result.len() { - let dst_index = offset_to_ndindex(i, &dst_strides); - let src_index: Vec = (0..src_shape.len()) - .map(|x| if src_shape[x] == 1 { 0 } else { dst_index[x] }) - .collect(); - let src_offset = nd_index_to_offset(&src_index, &src_strides); - result[i] = src[src_offset] - } + generic_broadcast(src, &mut result, src_shape, &src_strides, &dst_strides) } } else { - for i in 0..result.len() { - let dst_index = offset_to_ndindex(i, &dst_strides); - let src_index: Vec = (0..src_shape.len()) - .map(|x| if src_shape[x] == 1 { 0 } else { dst_index[x] }) - .collect(); - let src_offset = nd_index_to_offset(&src_index, &src_strides); - result[i] = src[src_offset] - } + generic_broadcast(src, &mut result, src_shape, &src_strides, &dst_strides) } result } +// TODO: Optimize. +// This generic implementation is almost a direct copy from the gpu impl, +// and can definitely be way more performant. +fn generic_broadcast( + src: &[T], + result: &mut [T], + src_shape: &[usize], + src_strides: &[usize], + dst_strides: &[usize], +) { + fn offset_to_ndindex(offset: usize, strides: &[usize]) -> Vec { + let mut indices = vec![0; strides.len()]; + let mut remaining = offset; + + for i in 0..strides.len() - 1 { + let stride = strides[i]; + let idx = remaining / stride; + indices[i] = idx; + remaining -= idx * stride; + } + indices[strides.len() - 1] = remaining; + indices + } + + fn nd_index_to_offset(ndindex: &[usize], strides: &[usize]) -> usize { + ndindex.iter().zip(strides.iter()).map(|(x, y)| x * y).sum() + } + + for i in 0..result.len() { + let dst_index = offset_to_ndindex(i, &dst_strides); + let src_index: Vec = (0..src_shape.len()) + .map(|x| if src_shape[x] == 1 { 0 } else { dst_index[x] }) + .collect(); + let src_offset = nd_index_to_offset(&src_index, &src_strides); + result[i] = src[src_offset] + } +} + impl CPUOperation for Slice { fn apply_cpu(&self, dst: Tensor) -> Result { match dst.dt() { From 74a35653122b58d507b0a55fcddbccd5e58b50d2 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 13 Nov 2024 01:39:54 -0400 Subject: [PATCH 6/6] chore: unroll generic broadcast --- crates/ratchet-core/src/cpu/reindex.rs | 81 +++++++++++++++++--------- crates/ratchet-core/src/strides.rs | 11 ++++ 2 files changed, 64 insertions(+), 28 deletions(-) diff --git a/crates/ratchet-core/src/cpu/reindex.rs b/crates/ratchet-core/src/cpu/reindex.rs index ced2baa4..4ba7edad 100644 --- a/crates/ratchet-core/src/cpu/reindex.rs +++ b/crates/ratchet-core/src/cpu/reindex.rs @@ -37,15 +37,8 @@ fn apply_broadcast(b: &Broadcast, dst: Tensor) -> Result(src: &[T], src_shape: &Shape, dst_shape: &Shape) -> Vec { - let src_strides = Strides::from(src_shape); - let dst_strides = Strides::from(dst_shape); let mut result = vec![T::zero(); dst_shape.numel()]; - let dst_shape = dst_shape.to_vec(); - - let src_strides: Vec = src_strides.as_slice().iter().map(|x| *x as usize).collect(); - let dst_strides: Vec = dst_strides.as_slice().iter().map(|x| *x as usize).collect(); - if src_shape.is_scalar() { // Life is simple result.fill(src[0]); @@ -61,10 +54,10 @@ pub(crate) fn broadcast(src: &[T], src_shape: &Shape, dst_shape: result[chunk..chunk + chunk_size].fill(src[i]); }); } else { - generic_broadcast(src, &mut result, src_shape, &src_strides, &dst_strides) + generic_broadcast(src, &mut result, src_shape, dst_shape) } } else { - generic_broadcast(src, &mut result, src_shape, &src_strides, &dst_strides) + generic_broadcast(src, &mut result, src_shape, dst_shape) } result @@ -76,34 +69,66 @@ pub(crate) fn broadcast(src: &[T], src_shape: &Shape, dst_shape: fn generic_broadcast( src: &[T], result: &mut [T], - src_shape: &[usize], - src_strides: &[usize], - dst_strides: &[usize], + src_shape: &Shape, + dst_shape: &Shape, ) { - fn offset_to_ndindex(offset: usize, strides: &[usize]) -> Vec { - let mut indices = vec![0; strides.len()]; + // We now know that these will always be len 4, same as gpu impl. + let src_shape = &Shape::promote(src_shape.clone(), 4); + let dst_shape = &Shape::promote(dst_shape.clone(), 4); + + let src_strides = &Strides::from(src_shape); + let dst_strides = &Strides::from(dst_shape); + + let src_shape: [usize; 4] = src_shape.try_into().unwrap(); + let src_strides: [usize; 4] = src_strides.try_into().unwrap(); + let dst_strides: [usize; 4] = dst_strides.try_into().unwrap(); + + fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] { + let mut indices = [0; 4]; let mut remaining = offset; - for i in 0..strides.len() - 1 { - let stride = strides[i]; - let idx = remaining / stride; - indices[i] = idx; - remaining -= idx * stride; - } - indices[strides.len() - 1] = remaining; + let idx = remaining / strides[0]; + indices[0] = idx; + remaining -= idx * strides[0]; + + let idx = remaining / strides[1]; + indices[1] = idx; + remaining -= idx * strides[1]; + + let idx = remaining / strides[2]; + indices[2] = idx; + remaining -= idx * strides[2]; + + indices[3] = remaining; indices } - fn nd_index_to_offset(ndindex: &[usize], strides: &[usize]) -> usize { - ndindex.iter().zip(strides.iter()).map(|(x, y)| x * y).sum() + fn select(a: [usize; 4], b: [usize; 4], t: [bool; 4]) -> [usize; 4] { + let mut result = [0; 4]; + result[0] = if t[0] { a[0] } else { b[0] }; + result[1] = if t[1] { a[1] } else { b[1] }; + result[2] = if t[2] { a[2] } else { b[2] }; + result[3] = if t[3] { a[3] } else { b[3] }; + result + } + + fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize { + ndindex[0] * strides[0] + + ndindex[1] * strides[1] + + ndindex[2] * strides[2] + + ndindex[3] * strides[3] } + let shape_onedim_lookup: [bool; 4] = [ + src_shape[0] != 1, + src_shape[1] != 1, + src_shape[2] != 1, + src_shape[3] != 1, + ]; for i in 0..result.len() { - let dst_index = offset_to_ndindex(i, &dst_strides); - let src_index: Vec = (0..src_shape.len()) - .map(|x| if src_shape[x] == 1 { 0 } else { dst_index[x] }) - .collect(); - let src_offset = nd_index_to_offset(&src_index, &src_strides); + let dst_index = offset_to_ndindex(i, dst_strides); + let src_index = select(dst_index, [0; 4], shape_onedim_lookup); + let src_offset = nd_index_to_offset(src_index, src_strides); result[i] = src[src_offset] } } diff --git a/crates/ratchet-core/src/strides.rs b/crates/ratchet-core/src/strides.rs index 2941a8fa..2ca653f9 100644 --- a/crates/ratchet-core/src/strides.rs +++ b/crates/ratchet-core/src/strides.rs @@ -136,6 +136,17 @@ impl From<&Strides> for [u32; 4] { } } +impl From<&Strides> for [usize; 4] { + fn from(strides: &Strides) -> Self { + assert!(strides.0.len() <= 4); + let mut array = [0; 4]; + for (i, &stride) in strides.0.iter().enumerate() { + array[i] = stride as usize; + } + array + } +} + impl From<&Strides> for glam::UVec4 { fn from(strides: &Strides) -> Self { let array: [u32; 4] = strides.into();