Feature/cpu broadcast #262

Merged · 7 commits · Nov 13, 2024

Changes from all commits
137 changes: 132 additions & 5 deletions crates/ratchet-core/src/cpu/reindex.rs
@@ -1,15 +1,137 @@
-use super::utils::cpu_store_result;
-use crate::{CPUOperation, DType, OperationError, Reindex, Slice, Strides, Tensor, TensorDType};
+use super::utils::{
+    cpu_store_result, TensorIterator,
+    TensorIterator::{Contiguous, Strided},
+};
+use crate::{
+    Broadcast, CPUOperation, DType, OperationError, Reindex, Shape, Slice, Strides, Tensor,
+    TensorDType,
+};
use half::{bf16, f16};

impl CPUOperation for Reindex {
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
        match self {
            Reindex::Slice(s) => s.apply_cpu(dst),
            Reindex::Broadcast(b) => b.apply_cpu(dst),
            _ => todo!(),
        }
    }
}
impl CPUOperation for Broadcast {
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
        match dst.dt() {
            DType::F32 => apply_broadcast::<f32>(self, dst),
            DType::BF16 => apply_broadcast::<bf16>(self, dst),
            DType::F16 => apply_broadcast::<f16>(self, dst),
            DType::I32 => apply_broadcast::<i32>(self, dst),
            DType::U32 => apply_broadcast::<u32>(self, dst),
            _ => todo!(),
        }
    }
}

fn apply_broadcast<T: TensorDType>(b: &Broadcast, dst: Tensor) -> Result<Tensor, OperationError> {
    let result = broadcast(&b.src.to_vec::<T>()?, b.src.shape(), b.to());
    cpu_store_result(&dst, &result);
    Ok(dst)
}

pub(crate) fn broadcast<T: TensorDType>(src: &[T], src_shape: &Shape, dst_shape: &Shape) -> Vec<T> {
    let mut result = vec![T::zero(); dst_shape.numel()];

    if src_shape.is_scalar() {
        // Life is simple
        result.fill(src[0]);
    } else if src_shape.is_vector() {
        // If src is a vector and the first dimension is the broadcasting dimension
        if src_shape[0] > 1 && src_shape[0] == dst_shape[0] {
            let chunk_size = result.len() / src_shape.numel();

            (0..result.len())
                .step_by(chunk_size)
                .enumerate()
                .for_each(|(i, chunk)| {
                    result[chunk..chunk + chunk_size].fill(src[i]);
                });
        } else {
            generic_broadcast(src, &mut result, src_shape, dst_shape)
        }
    } else {
        generic_broadcast(src, &mut result, src_shape, dst_shape)
    }

    result
}
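A quick sketch of what the vector fast path produces, using the `shape!` macro from this crate's tests and assuming a rank-1 `Shape` counts as a vector here (illustrative only, not part of the diff): broadcasting a length-2 vector to shape [2, 3] fills each chunk of `dst.len() / src.len()` slots with the corresponding source element.

// Illustrative use of broadcast(); values chosen to hit the vector fast path.
let src = vec![1.0f32, 2.0];
let out = broadcast(&src, &shape!(2), &shape!(2, 3));
// chunk_size = 6 / 2 = 3, so each element fills three consecutive slots.
assert_eq!(out, vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);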

// TODO: Optimize.
// This generic implementation is almost a direct copy from the gpu impl,
// and can definitely be way more performant.
fn generic_broadcast<T: TensorDType>(
    src: &[T],
    result: &mut [T],
    src_shape: &Shape,
    dst_shape: &Shape,
) {
    // We now know that these will always be len 4, same as gpu impl.
    let src_shape = &Shape::promote(src_shape.clone(), 4);
    let dst_shape = &Shape::promote(dst_shape.clone(), 4);

    let src_strides = &Strides::from(src_shape);
    let dst_strides = &Strides::from(dst_shape);

    let src_shape: [usize; 4] = src_shape.try_into().unwrap();
    let src_strides: [usize; 4] = src_strides.try_into().unwrap();
    let dst_strides: [usize; 4] = dst_strides.try_into().unwrap();

    fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] {
        let mut indices = [0; 4];
        let mut remaining = offset;

        let idx = remaining / strides[0];
        indices[0] = idx;
        remaining -= idx * strides[0];

        let idx = remaining / strides[1];
        indices[1] = idx;
        remaining -= idx * strides[1];

        let idx = remaining / strides[2];
        indices[2] = idx;
        remaining -= idx * strides[2];

        indices[3] = remaining;
        indices
    }

    fn select(a: [usize; 4], b: [usize; 4], t: [bool; 4]) -> [usize; 4] {
        let mut result = [0; 4];
        result[0] = if t[0] { a[0] } else { b[0] };
        result[1] = if t[1] { a[1] } else { b[1] };
        result[2] = if t[2] { a[2] } else { b[2] };
        result[3] = if t[3] { a[3] } else { b[3] };
        result
    }

    fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize {
        ndindex[0] * strides[0]
            + ndindex[1] * strides[1]
            + ndindex[2] * strides[2]
            + ndindex[3] * strides[3]
    }

    let shape_onedim_lookup: [bool; 4] = [
        src_shape[0] != 1,
        src_shape[1] != 1,
        src_shape[2] != 1,
        src_shape[3] != 1,
    ];
    for i in 0..result.len() {
        let dst_index = offset_to_ndindex(i, dst_strides);
        let src_index = select(dst_index, [0; 4], shape_onedim_lookup);
        let src_offset = nd_index_to_offset(src_index, src_strides);
        result[i] = src[src_offset]
    }
}
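For intuition about the nested index helpers, here is the same offset-to-nd-index round trip with concrete numbers. This standalone sketch assumes row-major strides [24, 12, 4, 1] for a promoted [1, 2, 3, 4] shape and is not part of the diff.

// offset -> nd-index: repeated division by strides, as in offset_to_ndindex.
let strides = [24usize, 12, 4, 1];
let (mut rem, mut idx) = (17usize, [0usize; 4]);
for d in 0..4 {
    idx[d] = rem / strides[d];
    rem -= idx[d] * strides[d];
}
assert_eq!(idx, [0, 1, 1, 1]); // 17 = 0*24 + 1*12 + 1*4 + 1*1
// nd-index -> offset: the dot product in nd_index_to_offset recovers it.
let back: usize = idx.iter().zip(strides).map(|(i, s)| i * s).sum();
assert_eq!(back, 17);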

impl CPUOperation for Slice {
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {

@@ -49,12 +171,17 @@ pub(crate) fn slice<T: TensorDType>(

    let mut dst = vec![T::zero(); dst_numel];

-    for i in 0..dst_numel {
+    let mut dst_dots = vec![];
+    for d in 0..dst_shape.len() {
+        dst_dots.push(dst_shape[d + 1..].iter().product::<usize>().max(1));
+    }
+
+    for i in 0..dst.len() {
        let mut src_index = 0;
        let mut tmp = i;
        for d in 0..dst_shape.len() {
-            let coord = tmp / dst_shape[d + 1..].iter().product::<usize>().max(1);
-            tmp %= dst_shape[d + 1..].iter().product::<usize>().max(1);
+            let coord = tmp / dst_dots[d];
+            tmp %= dst_dots[d];
            src_index += (coord + start[d]) * src_strides[d] as usize;
        }
        dst[i] = src[src_index];
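The hunk above hoists the trailing-product divisors into `dst_dots` so each one is computed once per dimension rather than once per element. A minimal sketch of that precomputation (shape values are illustrative, not from the PR):

// For dst_shape = [2, 3, 4], the divisor for dim d is the product of the
// dimensions after d (empty product clamped to 1), i.e. row-major strides.
let dst_shape = [2usize, 3, 4];
let mut dst_dots = vec![];
for d in 0..dst_shape.len() {
    dst_dots.push(dst_shape[d + 1..].iter().product::<usize>().max(1));
}
assert_eq!(dst_dots, vec![12, 4, 1]);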
217 changes: 215 additions & 2 deletions crates/ratchet-core/src/cpu/utils.rs
@@ -1,6 +1,219 @@
-use crate::{CPUBuffer, Storage, Tensor};
-use bytemuck::NoUninit;
+use crate::{CPUBuffer, Shape, Storage, Strides, Tensor};
+use bytemuck::{Contiguous, NoUninit};
+use std::ops::Range;

pub fn cpu_store_result<T: NoUninit>(dst: &Tensor, data: &[T]) {
    dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape())));
}

#[derive(Clone)]
pub enum TensorIterator<'a> {
    Contiguous(Range<usize>),
    Strided(StridedIterator<'a>),
}

impl<'a> TensorIterator<'a> {
    pub fn new(shape: &'a Shape, strides: &'a Strides, offset: usize) -> Self {
        let mut block_size: usize = 1;
        let mut contiguous_dims: usize = 0;
        for (&stride, &dim) in strides.iter().zip(shape.iter()).rev() {
            if stride as usize != block_size {
                break;
            }
            block_size *= dim as usize;
            contiguous_dims += 1;
        }
        let index_dims = shape.rank() - contiguous_dims;
        if index_dims == 0 {
            Self::Contiguous(offset..block_size)
        } else {
            Self::Strided(StridedIterator::new(&shape, &strides, offset, block_size))
        }
    }
}

impl<'a> Iterator for TensorIterator<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Contiguous(r) => r.next(),
            Self::Strided(s) => s.next(),
        }
    }
}
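For intuition on the contiguity check in `TensorIterator::new`: dimensions are consumed innermost-first while each stride equals the block size accumulated so far. A standalone re-derivation with concrete numbers (not part of the diff):

// Shape [2, 4, 3] with row-major strides [12, 3, 1]: every dim matches,
// so the whole tensor is one contiguous block and iteration is a Range.
let (shape, strides) = ([2usize, 4, 3], [12usize, 3, 1]);
let (mut block_size, mut contiguous_dims) = (1, 0);
for (&stride, &dim) in strides.iter().zip(shape.iter()).rev() {
    if stride != block_size {
        break;
    }
    block_size *= dim;
    contiguous_dims += 1;
}
assert_eq!((contiguous_dims, block_size), (3, 24));
// After a transpose the innermost stride is no longer 1, the loop breaks
// immediately, and TensorIterator::new falls back to the Strided variant.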

#[derive(Clone)]
pub struct StridedIterator<'a> {
    shape: &'a [usize],
    strides: &'a [isize],
    next_index: Option<usize>,
    multi_index: Vec<usize>,
    block_size: usize,
    block_step: usize,
}

impl<'a> StridedIterator<'a> {
    pub fn new(
        shape: &'a [usize],
        strides: &'a [isize],
        start_offset: usize,
        block_len: usize,
    ) -> Self {
        Self {
            shape,
            strides,
            next_index: if shape.iter().product::<usize>() == 0 {
                None
            } else {
                Some(start_offset)
            },
            multi_index: vec![0; shape.len()],
            block_size: block_len,
            block_step: 0,
        }
    }
}

impl<'a> Iterator for StridedIterator<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let storage_index = match self.next_index {
            None => return None,
            Some(storage_index) => storage_index,
        };

        if self.block_size > 1 {
            if self.block_step < self.block_size {
                self.block_step += 1;
                return Some(storage_index + self.block_step - 1);
            } else {
                self.block_step = 0;
            }
        }

        let mut updated = false;
        let mut next_storage_index = storage_index;
        for ((multi_i, max_i), stride_i) in self
            .multi_index
            .iter_mut()
            .zip(self.shape.iter())
            .zip(self.strides.iter())
            .rev()
        {
            let next_i = *multi_i + 1;
            if next_i < *max_i {
                *multi_i = next_i;
                updated = true;
                next_storage_index += *stride_i as usize;
                break;
            } else {
                next_storage_index -= *multi_i * *stride_i as usize;
                *multi_i = 0
            }
        }
        self.next_index = if updated {
            Some(next_storage_index)
        } else {
            None
        };
        Some(storage_index)
    }
}

impl<'a> From<(&'a Shape, &'a Strides)> for StridedIterator<'a> {
    fn from((shape, strides): (&'a Shape, &'a Strides)) -> Self {
        StridedIterator::new(shape.as_slice(), strides.as_slice(), 0, 1)
    }
}

impl<'a> From<(&'a Shape, &'a Strides, usize)> for StridedIterator<'a> {
    fn from((shape, strides, offset): (&'a Shape, &'a Strides, usize)) -> Self {
        StridedIterator::new(shape.as_slice(), strides.as_slice(), offset, 1)
    }
}
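A small usage sketch of the `From` conversions above, mirroring the sanity test below (illustrative only): with untransposed, contiguous shape and strides, the iterator simply walks storage offsets in order.

// Illustrative only: contiguous [2, 3] yields storage offsets 0..6 in order.
let shape = shape!(2, 3);
let strides = Strides::from(&shape);
let iter = StridedIterator::from((&shape, &strides));
assert_eq!(iter.collect::<Vec<_>>(), (0..6).collect::<Vec<_>>());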

#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use test_strategy::{proptest, Arbitrary};

    use crate::{shape, Shape, Strides};

    use super::{StridedIterator, TensorIterator};

    #[derive(Debug)]
    struct IterProblem {
        shape: Shape,
        offset: usize,
    }

    impl Arbitrary for IterProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            let ranges = vec![1..=2, 1..=4, 1..=256, 1..=256];
            Shape::arbitrary_with(ranges)
                .prop_flat_map(|shape| (Just(shape.clone()), 0..shape.numel()))
                .prop_map(|(shape, offset)| IterProblem { shape, offset })
                .boxed()
        }
    }

    #[proptest(cases = 16)]
    fn test_tensor_iter_contiguous(prob: IterProblem) {
        let shape = prob.shape;
        let strides = Strides::from(&shape);
        let offset = prob.offset;

        let iter = TensorIterator::new(&shape, &strides, offset);
        assert!(matches!(iter, TensorIterator::Contiguous(_)));

        match iter {
            TensorIterator::Contiguous(r) => assert_eq!(r, offset..shape.numel()),
            _ => unreachable!(),
        }
    }

    #[proptest(cases = 16)]
    fn test_tensor_iter_strided(prob: IterProblem) {
        let mut shape = prob.shape;
        let mut strides = Strides::from(&shape);
        strides.transpose();
        shape.transpose();
        let offset = prob.offset;

        let iter = TensorIterator::new(&shape, &strides, offset);
        assert!(matches!(iter, TensorIterator::Strided(_)));

        match iter {
            TensorIterator::Strided(strided_iter) => {
                let mut indices: Vec<usize> = strided_iter.collect();
                assert_eq!(indices.len(), shape.numel());
                let contiguous: Vec<usize> = (offset..shape.numel() + offset).collect();
                assert_ne!(indices, contiguous);
                indices.sort();
                assert_eq!(indices, contiguous);
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_tensor_iter_strided_sanity() {
        let mut shape = shape!(2, 4, 3);
        let mut strides = Strides::from(&shape);
        strides.transpose();
        shape.transpose();
        let offset = 2;

        let iter = TensorIterator::new(&shape, &strides, offset);
        let actual: Vec<usize> = iter.collect();
        let expected = vec![
            2, 5, 8, 11, 3, 6, 9, 12, 4, 7, 10, 13, 14, 17, 20, 23, 15, 18, 21, 24, 16, 19, 22, 25,
        ];
        assert_eq!(actual, expected);
    }
}