diff --git a/crates/ratchet-core/src/cpu/reindex.rs b/crates/ratchet-core/src/cpu/reindex.rs
index d3953484..4ba7edad 100644
--- a/crates/ratchet-core/src/cpu/reindex.rs
+++ b/crates/ratchet-core/src/cpu/reindex.rs
@@ -1,15 +1,137 @@
-use super::utils::cpu_store_result;
-use crate::{CPUOperation, DType, OperationError, Reindex, Slice, Strides, Tensor, TensorDType};
+use super::utils::{
+    cpu_store_result, TensorIterator,
+    TensorIterator::{Contiguous, Strided},
+};
+use crate::{
+    Broadcast, CPUOperation, DType, OperationError, Reindex, Shape, Slice, Strides, Tensor,
+    TensorDType,
+};
 use half::{bf16, f16};

 impl CPUOperation for Reindex {
     fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
         match self {
             Reindex::Slice(s) => s.apply_cpu(dst),
+            Reindex::Broadcast(b) => b.apply_cpu(dst),
             _ => todo!(),
         }
     }
 }

+impl CPUOperation for Broadcast {
+    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+        match dst.dt() {
+            DType::F32 => apply_broadcast::<f32>(self, dst),
+            DType::BF16 => apply_broadcast::<bf16>(self, dst),
+            DType::F16 => apply_broadcast::<f16>(self, dst),
+            DType::I32 => apply_broadcast::<i32>(self, dst),
+            DType::U32 => apply_broadcast::<u32>(self, dst),
+            _ => todo!(),
+        }
+    }
+}
+
+fn apply_broadcast<T: TensorDType>(b: &Broadcast, dst: Tensor) -> Result<Tensor, OperationError> {
+    let result = broadcast(&b.src.to_vec::<T>()?, b.src.shape(), b.to());
+    cpu_store_result(&dst, &result);
+    Ok(dst)
+}
+
+pub(crate) fn broadcast<T: TensorDType>(src: &[T], src_shape: &Shape, dst_shape: &Shape) -> Vec<T> {
+    let mut result = vec![T::zero(); dst_shape.numel()];
+
+    if src_shape.is_scalar() {
+        // Life is simple
+        result.fill(src[0]);
+    } else if src_shape.is_vector() {
+        // If src is a vector and the first dimension is the broadcasting dimension
+        if src_shape[0] > 1 && src_shape[0] == dst_shape[0] {
+            let chunk_size = result.len() / src_shape.numel();
+
+            (0..result.len())
+                .step_by(chunk_size)
+                .enumerate()
+                .for_each(|(i, chunk)| {
+                    result[chunk..chunk + chunk_size].fill(src[i]);
+                });
+        } else {
+            generic_broadcast(src, &mut result, src_shape, dst_shape)
+        }
+    } else {
+        generic_broadcast(src, &mut result, src_shape, dst_shape)
+    }
+
+    result
+}
+
+// TODO: Optimize.
+// This generic implementation is almost a direct copy from the gpu impl,
+// and can definitely be way more performant.
+fn generic_broadcast<T: TensorDType>(
+    src: &[T],
+    result: &mut [T],
+    src_shape: &Shape,
+    dst_shape: &Shape,
+) {
+    // We now know that these will always be len 4, same as gpu impl.
+    let src_shape = &Shape::promote(src_shape.clone(), 4);
+    let dst_shape = &Shape::promote(dst_shape.clone(), 4);
+
+    let src_strides = &Strides::from(src_shape);
+    let dst_strides = &Strides::from(dst_shape);
+
+    let src_shape: [usize; 4] = src_shape.try_into().unwrap();
+    let src_strides: [usize; 4] = src_strides.try_into().unwrap();
+    let dst_strides: [usize; 4] = dst_strides.try_into().unwrap();
+
+    fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] {
+        let mut indices = [0; 4];
+        let mut remaining = offset;
+
+        let idx = remaining / strides[0];
+        indices[0] = idx;
+        remaining -= idx * strides[0];
+
+        let idx = remaining / strides[1];
+        indices[1] = idx;
+        remaining -= idx * strides[1];
+
+        let idx = remaining / strides[2];
+        indices[2] = idx;
+        remaining -= idx * strides[2];
+
+        indices[3] = remaining;
+        indices
+    }
+
+    fn select(a: [usize; 4], b: [usize; 4], t: [bool; 4]) -> [usize; 4] {
+        let mut result = [0; 4];
+        result[0] = if t[0] { a[0] } else { b[0] };
+        result[1] = if t[1] { a[1] } else { b[1] };
+        result[2] = if t[2] { a[2] } else { b[2] };
+        result[3] = if t[3] { a[3] } else { b[3] };
+        result
+    }
+
+    fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize {
+        ndindex[0] * strides[0]
+            + ndindex[1] * strides[1]
+            + ndindex[2] * strides[2]
+            + ndindex[3] * strides[3]
+    }
+
+    let shape_onedim_lookup: [bool; 4] = [
+        src_shape[0] != 1,
+        src_shape[1] != 1,
+        src_shape[2] != 1,
+        src_shape[3] != 1,
+    ];
+    for i in 0..result.len() {
+        let dst_index = offset_to_ndindex(i, dst_strides);
+        let src_index = select(dst_index, [0; 4], shape_onedim_lookup);
+        let src_offset = nd_index_to_offset(src_index, src_strides);
+        result[i] = src[src_offset];
+    }
+}

 impl CPUOperation for Slice {
     fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
@@ -49,12 +171,17 @@ pub(crate) fn slice(

     let mut dst = vec![T::zero(); dst_numel];

-    for i in 0..dst_numel {
+    let mut dst_dots = vec![];
+    for d in 0..dst_shape.len() {
+        dst_dots.push(dst_shape[d + 1..].iter().product::<usize>().max(1));
+    }
+
+    for i in 0..dst.len() {
         let mut src_index = 0;
         let mut tmp = i;
         for d in 0..dst_shape.len() {
-            let coord = tmp / dst_shape[d + 1..].iter().product::<usize>().max(1);
-            tmp %= dst_shape[d + 1..].iter().product::<usize>().max(1);
+            let coord = tmp / dst_dots[d];
+            tmp %= dst_dots[d];
             src_index += (coord + start[d]) * src_strides[d] as usize;
         }
         dst[i] = src[src_index];
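Note (reviewer sketch): the index arithmetic in generic_broadcast is easiest to see on one concrete element. Below is a minimal standalone mirror in plain std Rust; the shapes, strides, and offset are invented for illustration and none of this is the crate's API.

    // Broadcast a src of shape [1, 1, 2, 1] into a dst of shape [1, 3, 2, 4].
    fn main() {
        let src_shape: [usize; 4] = [1, 1, 2, 1];
        let src_strides: [usize; 4] = [2, 2, 1, 1]; // row-major strides of src_shape
        let dst_strides: [usize; 4] = [24, 8, 4, 1]; // row-major strides of [1, 3, 2, 4]

        // offset_to_ndindex: destination offset 13 -> nd-index [0, 1, 1, 1].
        let mut rem = 13usize;
        let mut dst_index = [0usize; 4];
        for d in 0..3 {
            dst_index[d] = rem / dst_strides[d];
            rem %= dst_strides[d];
        }
        dst_index[3] = rem; // innermost stride is 1
        assert_eq!(dst_index, [0, 1, 1, 1]);

        // select: collapse every size-1 source dimension to coordinate 0.
        let mut src_index = [0usize; 4];
        for d in 0..4 {
            src_index[d] = if src_shape[d] != 1 { dst_index[d] } else { 0 };
        }
        assert_eq!(src_index, [0, 0, 1, 0]);

        // nd_index_to_offset: re-linearise against the source strides.
        let src_offset: usize = (0..4).map(|d| src_index[d] * src_strides[d]).sum();
        assert_eq!(src_offset, 1);
    }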
diff --git a/crates/ratchet-core/src/cpu/utils.rs b/crates/ratchet-core/src/cpu/utils.rs
index e61facde..f1f565ff 100644
--- a/crates/ratchet-core/src/cpu/utils.rs
+++ b/crates/ratchet-core/src/cpu/utils.rs
@@ -1,6 +1,219 @@
-use crate::{CPUBuffer, Storage, Tensor};
-use bytemuck::NoUninit;
+use crate::{CPUBuffer, Shape, Storage, Strides, Tensor};
+use bytemuck::{Contiguous, NoUninit};
+use std::ops::Range;

 pub fn cpu_store_result<T: NoUninit>(dst: &Tensor, data: &[T]) {
     dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape())));
 }
+
+#[derive(Clone)]
+pub enum TensorIterator<'a> {
+    Contiguous(Range<usize>),
+    Strided(StridedIterator<'a>),
+}
+
+impl<'a> TensorIterator<'a> {
+    pub fn new(shape: &'a Shape, strides: &'a Strides, offset: usize) -> Self {
+        let mut block_size: usize = 1;
+        let mut contiguous_dims: usize = 0;
+        for (&stride, &dim) in strides.iter().zip(shape.iter()).rev() {
+            if stride as usize != block_size {
+                break;
+            }
+            block_size *= dim as usize;
+            contiguous_dims += 1;
+        }
+        let index_dims = shape.rank() - contiguous_dims;
+        if index_dims == 0 {
+            Self::Contiguous(offset..block_size)
+        } else {
+            Self::Strided(StridedIterator::new(&shape, &strides, offset, block_size))
+        }
+    }
+}
+
+impl<'a> Iterator for TensorIterator<'a> {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self {
+            Self::Contiguous(r) => r.next(),
+            Self::Strided(s) => s.next(),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct StridedIterator<'a> {
+    shape: &'a [usize],
+    strides: &'a [isize],
+    next_index: Option<usize>,
+    multi_index: Vec<usize>,
+    block_size: usize,
+    block_step: usize,
+}
+
+impl<'a> StridedIterator<'a> {
+    pub fn new(
+        shape: &'a [usize],
+        strides: &'a [isize],
+        start_offset: usize,
+        block_len: usize,
+    ) -> Self {
+        Self {
+            shape,
+            strides,
+            next_index: if shape.iter().product::<usize>() == 0 {
+                None
+            } else {
+                Some(start_offset)
+            },
+            multi_index: vec![0; shape.len()],
+            block_size: block_len,
+            block_step: 0,
+        }
+    }
+}
+
+impl<'a> Iterator for StridedIterator<'a> {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let storage_index = match self.next_index {
+            None => return None,
+            Some(storage_index) => storage_index,
+        };
+
+        if self.block_size > 1 {
+            if self.block_step < self.block_size {
+                self.block_step += 1;
+                return Some(storage_index + self.block_step - 1);
+            } else {
+                self.block_step = 0;
+            }
+        }
+
+        let mut updated = false;
+        let mut next_storage_index = storage_index;
+        for ((multi_i, max_i), stride_i) in self
+            .multi_index
+            .iter_mut()
+            .zip(self.shape.iter())
+            .zip(self.strides.iter())
+            .rev()
+        {
+            let next_i = *multi_i + 1;
+            if next_i < *max_i {
+                *multi_i = next_i;
+                updated = true;
+                next_storage_index += *stride_i as usize;
+                break;
+            } else {
+                next_storage_index -= *multi_i * *stride_i as usize;
+                *multi_i = 0
+            }
+        }
+        self.next_index = if updated {
+            Some(next_storage_index)
+        } else {
+            None
+        };
+        Some(storage_index)
+    }
+}
+
+impl<'a> From<(&'a Shape, &'a Strides)> for StridedIterator<'a> {
+    fn from((shape, strides): (&'a Shape, &'a Strides)) -> Self {
+        StridedIterator::new(shape.as_slice(), strides.as_slice(), 0, 1)
+    }
+}
+
+impl<'a> From<(&'a Shape, &'a Strides, usize)> for StridedIterator<'a> {
+    fn from((shape, strides, offset): (&'a Shape, &'a Strides, usize)) -> Self {
+        StridedIterator::new(shape.as_slice(), strides.as_slice(), offset, 1)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use proptest::prelude::*;
+    use test_strategy::{proptest, Arbitrary};
+
+    use crate::{shape, Shape, Strides};
+
+    use super::{StridedIterator, TensorIterator};
+
+    #[derive(Debug)]
+    struct IterProblem {
+        shape: Shape,
+        offset: usize,
+    }
+
+    impl Arbitrary for IterProblem {
+        type Parameters = ();
+        type Strategy = BoxedStrategy<Self>;
+
+        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
+            let ranges = vec![1..=2, 1..=4, 1..=256, 1..=256];
+            Shape::arbitrary_with(ranges)
+                .prop_flat_map(|shape| (Just(shape.clone()), 0..shape.numel()))
+                .prop_map(|(shape, offset)| IterProblem { shape, offset })
+                .boxed()
+        }
+    }
+
+    #[proptest(cases = 16)]
+    fn test_tensor_iter_contiguous(prob: IterProblem) {
+        let shape = prob.shape;
+        let strides = Strides::from(&shape);
+        let offset = prob.offset;
+
+        let iter = TensorIterator::new(&shape, &strides, offset);
+        assert!(matches!(iter, TensorIterator::Contiguous(_)));
+
+        match iter {
+            TensorIterator::Contiguous(r) => assert_eq!(r, offset..shape.numel()),
+            _ => unreachable!(),
+        }
+    }
+
+    #[proptest(cases = 16)]
+    fn test_tensor_iter_strided(prob: IterProblem) {
+        let mut shape = prob.shape;
+        let mut strides = Strides::from(&shape);
+        strides.transpose();
+        shape.transpose();
+        let offset = prob.offset;
+
+        let iter = TensorIterator::new(&shape, &strides, offset);
+        assert!(matches!(iter, TensorIterator::Strided(_)));
+
+        match iter {
+            TensorIterator::Strided(strided_iter) => {
+                let mut indices: Vec<usize> = strided_iter.collect();
+                assert_eq!(indices.len(), shape.numel());
+                let contiguous: Vec<usize> = (offset..shape.numel() + offset).collect();
+                assert_ne!(indices, contiguous);
+                indices.sort();
+                assert_eq!(indices, contiguous);
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    #[test]
+    fn test_tensor_iter_strided_sanity() {
+        let mut shape = shape!(2, 4, 3);
+        let mut strides = Strides::from(&shape);
+        strides.transpose();
+        shape.transpose();
+        let offset = 2;
+
+        let iter = TensorIterator::new(&shape, &strides, offset);
+        let actual: Vec<usize> = iter.collect();
+        let expected = vec![
+            2, 5, 8, 11, 3, 6, 9, 12, 4, 7, 10, 13, 14, 17, 20, 23, 15, 18, 21, 24, 16, 19, 22, 25,
+        ];
+        assert_eq!(actual, expected);
+    }
+}
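Note (reviewer sketch): assuming the StridedIterator added above is in scope, here is the walk it produces for a 2x3 row-major buffer viewed as its 3x2 transpose (shape [3, 2], strides [1, 3]) — the same idea as test_tensor_iter_strided_sanity, shrunk to a hand-checkable size.

    #[test]
    fn transposed_walk() {
        // Storage holds [[a, b, c], [d, e, f]]; the transpose reads column-wise,
        // so logical order a, d, b, e, c, f means storage offsets 0, 3, 1, 4, 2, 5.
        let order: Vec<usize> = StridedIterator::new(&[3, 2], &[1, 3], 0, 1).collect();
        assert_eq!(order, vec![0, 3, 1, 4, 2, 5]);
    }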
diff --git a/crates/ratchet-core/src/ops/reindex/broadcast.rs b/crates/ratchet-core/src/ops/reindex/broadcast.rs
index 0055753f..78ce145a 100644
--- a/crates/ratchet-core/src/ops/reindex/broadcast.rs
+++ b/crates/ratchet-core/src/ops/reindex/broadcast.rs
@@ -64,7 +64,7 @@ impl Operation for Broadcast {
         "Broadcast"
     }

-    //For rules, see https://numpy.org/doc/stable/user/basics.broadcasting.html
+    // For rules, see https://numpy.org/doc/stable/user/basics.broadcasting.html
     fn compute_view(&self) -> Result<StorageView, OperationError> {
         let src_shape = self.src.shape();
@@ -140,11 +140,10 @@ def slice(a):
         run_py_prg(prg.to_string(), &[a], &[], a.dt())
     }

-    fn run_reindex_trial(prob: BroadcastProblem) -> anyhow::Result<()> {
+    fn run_reindex_trial(prob: BroadcastProblem, device: Device) -> anyhow::Result<()> {
         println!("\n\nBroadcast problem: {:?}", prob);
         let BroadcastProblem { op } = prob;
         let a = op.src.clone();
-        let device = Device::request_device(DeviceRequest::GPU).unwrap();
         let a_gpu = a.to(&device)?;
         let ground = ground_truth(&a, &op.to.as_torch())?;
@@ -155,18 +154,26 @@ def slice(a):
     }

     #[proptest(cases = 16)]
-    fn test_broadcast(prob: BroadcastProblem) {
-        run_reindex_trial(prob).unwrap();
+    fn test_broadcast_gpu(prob: BroadcastProblem) {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
+        run_reindex_trial(prob, device).unwrap();
+    }
+
+    #[proptest(cases = 16)]
+    fn test_broadcast_cpu(prob: BroadcastProblem) {
+        let device = Device::request_device(DeviceRequest::CPU).unwrap();
+        run_reindex_trial(prob, device).unwrap();
     }

     #[test]
     fn debug_broadcast() {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
         let prob = BroadcastProblem {
             op: Broadcast::new(
                 Tensor::randn::<f32>(shape![1], Device::CPU),
                 shape![4, 32, 128, 128],
             ),
         };
-        run_reindex_trial(prob).unwrap();
+        run_reindex_trial(prob, device).unwrap();
     }
 }
diff --git a/crates/ratchet-core/src/ops/reindex/slice.rs b/crates/ratchet-core/src/ops/reindex/slice.rs
index fce43db6..508f2efd 100644
--- a/crates/ratchet-core/src/ops/reindex/slice.rs
+++ b/crates/ratchet-core/src/ops/reindex/slice.rs
@@ -178,7 +178,7 @@ def slice(a):
         run_reindex_trial(prob, device).unwrap();
     }

-    #[proptest(cases = 16)]
+    #[proptest(cases = 1)]
     fn test_slice_cpu(prob: SliceProblem) {
         let _ = env_logger::builder().is_test(true).try_init();
         let device = Device::request_device(DeviceRequest::CPU).unwrap();
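Note (reviewer sketch): the NumPy rules linked above boil down to aligning shapes from the trailing dimension, where two dims are compatible if they are equal or either is 1. A standalone sketch in plain Rust; broadcast_shapes is an invented helper, not the crate's compute_view.

    fn broadcast_shapes(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
        let rank = a.len().max(b.len());
        let mut out = vec![0; rank];
        for i in 0..rank {
            // Read dims right-to-left, treating missing leading dims as 1.
            let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
            let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
            out[rank - 1 - i] = match (da, db) {
                (x, y) if x == y => x,
                (1, y) => y,
                (x, 1) => x,
                _ => return None, // incompatible, e.g. [3] vs [4]
            };
        }
        Some(out)
    }

    fn main() {
        // The debug_broadcast case: a one-element src stretches to the full shape.
        assert_eq!(
            broadcast_shapes(&[1], &[4, 32, 128, 128]),
            Some(vec![4, 32, 128, 128])
        );
        assert_eq!(broadcast_shapes(&[3], &[4]), None);
    }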
diff --git a/crates/ratchet-core/src/shape.rs b/crates/ratchet-core/src/shape.rs
index c349cbb5..895e762e 100644
--- a/crates/ratchet-core/src/shape.rs
+++ b/crates/ratchet-core/src/shape.rs
@@ -144,6 +144,13 @@ impl Shape {
     }
 }

+impl core::ops::Deref for Shape {
+    type Target = [usize];
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
 impl std::fmt::Debug for Shape {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let mut shape = format!("[{}", self.0.first().unwrap_or(&0));
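Note (reviewer sketch): the payoff of Deref<Target = [usize]> is that every slice method and indexing form becomes available on Shape directly (e.g. shape.iter() in TensorIterator::new above). A standalone mirror, using Vec<usize> in place of the crate's inner storage:

    struct Shape(Vec<usize>);

    impl core::ops::Deref for Shape {
        type Target = [usize];
        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    fn main() {
        let shape = Shape(vec![4, 32, 128]);
        assert_eq!(shape.len(), 3); // slice method via auto-deref
        assert_eq!(shape.iter().product::<usize>(), 4 * 32 * 128);
        assert_eq!(&shape[1..], &[32, 128]); // range indexing via auto-deref
    }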
diff --git a/crates/ratchet-core/src/strides.rs b/crates/ratchet-core/src/strides.rs
index 11920ae0..2ca653f9 100644
--- a/crates/ratchet-core/src/strides.rs
+++ b/crates/ratchet-core/src/strides.rs
@@ -1,4 +1,4 @@
-use std::ops::Index;
+use std::ops::{Index, IndexMut, RangeFrom, RangeTo};
 use std::slice::Iter;

 use crate::{rvec, RVec, Shape};
@@ -29,6 +29,10 @@ impl Strides {
     pub fn rank(&self) -> usize {
         self.0.len()
     }
+
+    pub fn as_slice(&self) -> &[isize] {
+        &self.0
+    }
 }

 impl std::fmt::Debug for Strides {
@@ -41,6 +45,13 @@ impl std::fmt::Debug for Strides {
     }
 }

+impl core::ops::Deref for Strides {
+    type Target = [isize];
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
 impl Index<usize> for Strides {
     type Output = isize;

@@ -49,6 +60,28 @@ impl Index<usize> for Strides {
     }
 }

+impl IndexMut<usize> for Strides {
+    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
+        &mut self.0[index]
+    }
+}
+
+impl Index<RangeFrom<usize>> for Strides {
+    type Output = [isize];
+
+    fn index(&self, index: RangeFrom<usize>) -> &Self::Output {
+        &self.0[index]
+    }
+}
+
+impl Index<RangeTo<usize>> for Strides {
+    type Output = [isize];
+
+    fn index(&self, index: RangeTo<usize>) -> &Self::Output {
+        &self.0[index]
+    }
+}
+
 impl From<&Shape> for Strides {
     fn from(shape: &Shape) -> Self {
         let mut strides = rvec![];
@@ -68,6 +101,12 @@ impl From<Vec<isize>> for Strides {
     fn from(strides: Vec<isize>) -> Self {
         Self(strides.into())
     }
 }

+impl From<&[isize]> for Strides {
+    fn from(strides: &[isize]) -> Self {
+        Self(strides.into())
+    }
+}
+
 impl From<&Strides> for [u32; 3] {
     fn from(strides: &Strides) -> Self {
         assert!(strides.0.len() <= 3);
@@ -97,6 +136,17 @@ impl From<&Strides> for [u32; 4] {
     }
 }

+impl From<&Strides> for [usize; 4] {
+    fn from(strides: &Strides) -> Self {
+        assert!(strides.0.len() <= 4);
+        let mut array = [0; 4];
+        for (i, &stride) in strides.0.iter().enumerate() {
+            array[i] = stride as usize;
+        }
+        array
+    }
+}
+
 impl From<&Strides> for glam::UVec4 {
     fn from(strides: &Strides) -> Self {
         let array: [u32; 4] = strides.into();
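Note (reviewer sketch): generic_broadcast pairs this [usize; 4] conversion with Strides::from(&Shape) after Shape::promote(_, 4). Assuming promote left-pads with ones (which the "always len 4" comment suggests), the strides being converted are ordinary row-major products; sketched here with a plain Vec rather than the crate's RVec:

    fn row_major_strides(shape: &[usize]) -> Vec<isize> {
        // stride[i] is the product of all dimensions to the right of i.
        let mut strides = vec![1isize; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1] as isize;
        }
        strides
    }

    fn main() {
        // promote([3, 4], 4) would give [1, 1, 3, 4]; its strides fit [usize; 4].
        assert_eq!(row_major_strides(&[1, 1, 3, 4]), vec![12, 12, 4, 1]);
    }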