Merge pull request #262 from huggingface/feature/cpu-broadcast
Feature/cpu broadcast
FL33TW00D authored Nov 13, 2024
2 parents b9d106f + 74a3565 commit 5979334
Showing 6 changed files with 419 additions and 15 deletions.
137 changes: 132 additions & 5 deletions crates/ratchet-core/src/cpu/reindex.rs
@@ -1,15 +1,137 @@
use super::utils::{
cpu_store_result, TensorIterator,
TensorIterator::{Contiguous, Strided},
};
use crate::{
Broadcast, CPUOperation, DType, OperationError, Reindex, Shape, Slice, Strides, Tensor,
TensorDType,
};
use half::{bf16, f16};

impl CPUOperation for Reindex {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match self {
Reindex::Slice(s) => s.apply_cpu(dst),
Reindex::Broadcast(b) => b.apply_cpu(dst),
_ => todo!(),
}
}
}
impl CPUOperation for Broadcast {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
DType::F32 => apply_broadcast::<f32>(self, dst),
DType::BF16 => apply_broadcast::<bf16>(self, dst),
DType::F16 => apply_broadcast::<f16>(self, dst),
DType::I32 => apply_broadcast::<i32>(self, dst),
DType::U32 => apply_broadcast::<u32>(self, dst),
_ => todo!(),
}
}
}

fn apply_broadcast<T: TensorDType>(b: &Broadcast, dst: Tensor) -> Result<Tensor, OperationError> {
let result = broadcast(&b.src.to_vec::<T>()?, b.src.shape(), b.to());
cpu_store_result(&dst, &result);
Ok(dst)
}

pub(crate) fn broadcast<T: TensorDType>(src: &[T], src_shape: &Shape, dst_shape: &Shape) -> Vec<T> {
let mut result = vec![T::zero(); dst_shape.numel()];

if src_shape.is_scalar() {
// Life is simple
result.fill(src[0]);
} else if src_shape.is_vector() {
        // If src is a vector and its leading dimension is the broadcast dimension
if src_shape[0] > 1 && src_shape[0] == dst_shape[0] {
let chunk_size = result.len() / src_shape.numel();

(0..result.len())
.step_by(chunk_size)
.enumerate()
.for_each(|(i, chunk)| {
result[chunk..chunk + chunk_size].fill(src[i]);
});
} else {
generic_broadcast(src, &mut result, src_shape, dst_shape)
}
} else {
generic_broadcast(src, &mut result, src_shape, dst_shape)
}

result
}
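
// Worked example (illustrative only, not part of the commit): broadcasting
// the vector [1.0, 2.0] from shape (2) into shape (2, 3) takes the fast
// vector path above, since src_shape[0] == dst_shape[0]. chunk_size is
// 6 / 2 = 3, so the result is [1.0, 1.0, 1.0, 2.0, 2.0, 2.0]: each source
// element fills one contiguous chunk of the destination.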

// TODO: Optimize.
// This generic implementation is a near-direct port of the GPU impl and
// could be made significantly more performant.
fn generic_broadcast<T: TensorDType>(
src: &[T],
result: &mut [T],
src_shape: &Shape,
dst_shape: &Shape,
) {
    // Promote both shapes to rank 4 so that, as in the GPU impl, all of
    // the indexing below can work over fixed-size [usize; 4] arrays.
let src_shape = &Shape::promote(src_shape.clone(), 4);
let dst_shape = &Shape::promote(dst_shape.clone(), 4);

let src_strides = &Strides::from(src_shape);
let dst_strides = &Strides::from(dst_shape);

let src_shape: [usize; 4] = src_shape.try_into().unwrap();
let src_strides: [usize; 4] = src_strides.try_into().unwrap();
let dst_strides: [usize; 4] = dst_strides.try_into().unwrap();

fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] {
let mut indices = [0; 4];
let mut remaining = offset;

let idx = remaining / strides[0];
indices[0] = idx;
remaining -= idx * strides[0];

let idx = remaining / strides[1];
indices[1] = idx;
remaining -= idx * strides[1];

let idx = remaining / strides[2];
indices[2] = idx;
remaining -= idx * strides[2];

indices[3] = remaining;
indices
}

fn select(a: [usize; 4], b: [usize; 4], t: [bool; 4]) -> [usize; 4] {
let mut result = [0; 4];
result[0] = if t[0] { a[0] } else { b[0] };
result[1] = if t[1] { a[1] } else { b[1] };
result[2] = if t[2] { a[2] } else { b[2] };
result[3] = if t[3] { a[3] } else { b[3] };
result
}

fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize {
ndindex[0] * strides[0]
+ ndindex[1] * strides[1]
+ ndindex[2] * strides[2]
+ ndindex[3] * strides[3]
}

let shape_onedim_lookup: [bool; 4] = [
src_shape[0] != 1,
src_shape[1] != 1,
src_shape[2] != 1,
src_shape[3] != 1,
];
for i in 0..result.len() {
let dst_index = offset_to_ndindex(i, dst_strides);
let src_index = select(dst_index, [0; 4], shape_onedim_lookup);
let src_offset = nd_index_to_offset(src_index, src_strides);
result[i] = src[src_offset]
}
}
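
// Worked example (illustrative only): broadcasting src shape (1, 3) into
// dst shape (2, 3). After promotion to rank 4, src shape is [1, 1, 1, 3]
// with strides [3, 3, 3, 1], dst strides are [6, 6, 3, 1], and the lookup
// is [false, false, false, true]. For dst offset i = 4, offset_to_ndindex
// gives [0, 0, 1, 1]; select() zeroes the broadcast dims, leaving
// [0, 0, 0, 1]; nd_index_to_offset maps that to src offset 1. Both dst
// rows therefore read from the same source row.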

impl CPUOperation for Slice {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
@@ -49,12 +171,17 @@ pub(crate) fn slice<T: TensorDType>(

let mut dst = vec![T::zero(); dst_numel];

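    // Precompute, per dim, the product of the trailing dims (the divisor
    // used to decompose a flat dst offset into coordinates) so it is not
    // recomputed on every iteration of the hot loop below.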
let mut dst_dots = vec![];
for d in 0..dst_shape.len() {
dst_dots.push(dst_shape[d + 1..].iter().product::<usize>().max(1));
}

for i in 0..dst.len() {
let mut src_index = 0;
let mut tmp = i;
for d in 0..dst_shape.len() {
let coord = tmp / dst_dots[d];
tmp %= dst_dots[d];
src_index += (coord + start[d]) * src_strides[d] as usize;
}
dst[i] = src[src_index];
217 changes: 215 additions & 2 deletions crates/ratchet-core/src/cpu/utils.rs
@@ -1,6 +1,219 @@
use crate::{CPUBuffer, Shape, Storage, Strides, Tensor};
use bytemuck::{Contiguous, NoUninit};
use std::ops::Range;

pub fn cpu_store_result<T: NoUninit>(dst: &Tensor, data: &[T]) {
dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape())));
}

#[derive(Clone)]
pub enum TensorIterator<'a> {
Contiguous(Range<usize>),
Strided(StridedIterator<'a>),
}

impl<'a> TensorIterator<'a> {
pub fn new(shape: &'a Shape, strides: &'a Strides, offset: usize) -> Self {
let mut block_size: usize = 1;
let mut contiguous_dims: usize = 0;
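        // Walk the dims innermost-first: while each stride matches the
        // running block size, that suffix of the tensor is contiguous in
        // memory and can be folded into a single block.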
for (&stride, &dim) in strides.iter().zip(shape.iter()).rev() {
if stride as usize != block_size {
break;
}
block_size *= dim as usize;
contiguous_dims += 1;
}
let index_dims = shape.rank() - contiguous_dims;
        if index_dims == 0 {
            // Fully contiguous: `block_size` (== numel) consecutive storage
            // offsets starting at `offset`.
            Self::Contiguous(offset..offset + block_size)
        } else {
            // Index only the non-contiguous outer dims; the contiguous
            // suffix is emitted as blocks of `block_size` offsets.
            Self::Strided(StridedIterator::new(
                &shape.as_slice()[..index_dims],
                &strides.as_slice()[..index_dims],
                offset,
                block_size,
            ))
        }
}
}

impl<'a> Iterator for TensorIterator<'a> {
type Item = usize;

fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Contiguous(r) => r.next(),
Self::Strided(s) => s.next(),
}
}
}
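
// Behaviour sketch (illustrative only): shape (2, 3) with standard strides
// (3, 1) folds entirely into one block, so `TensorIterator::new` returns
// Contiguous(offset..offset + 6). After a transpose to shape (3, 2) with
// strides (1, 3), the innermost stride is 3 != 1, so construction falls
// back to the Strided variant.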

#[derive(Clone)]
pub struct StridedIterator<'a> {
shape: &'a [usize],
strides: &'a [isize],
next_index: Option<usize>,
multi_index: Vec<usize>,
block_size: usize,
block_step: usize,
}

impl<'a> StridedIterator<'a> {
pub fn new(
shape: &'a [usize],
strides: &'a [isize],
start_offset: usize,
block_len: usize,
) -> Self {
Self {
shape,
strides,
next_index: if shape.iter().product::<usize>() == 0 {
None
} else {
Some(start_offset)
},
multi_index: vec![0; shape.len()],
block_size: block_len,
block_step: 0,
}
}
}

impl<'a> Iterator for StridedIterator<'a> {
type Item = usize;

fn next(&mut self) -> Option<Self::Item> {
let storage_index = match self.next_index {
None => return None,
Some(storage_index) => storage_index,
};

        // Emit the current contiguous block one offset at a time; the
        // multi-index over the indexed dims only advances (below) together
        // with the block's final offset.
        if self.block_size > 1 && self.block_step + 1 < self.block_size {
            self.block_step += 1;
            return Some(storage_index + self.block_step - 1);
        }
        let out = storage_index + self.block_step;
        self.block_step = 0;

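        // Odometer-style advance: bump the innermost indexed dim that still
        // has room, zeroing the dims behind it and adjusting the storage
        // offset by the corresponding strides.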
let mut updated = false;
let mut next_storage_index = storage_index;
for ((multi_i, max_i), stride_i) in self
.multi_index
.iter_mut()
.zip(self.shape.iter())
.zip(self.strides.iter())
.rev()
{
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
updated = true;
next_storage_index += *stride_i as usize;
break;
} else {
next_storage_index -= *multi_i * *stride_i as usize;
*multi_i = 0
}
}
self.next_index = if updated {
Some(next_storage_index)
} else {
None
};
        Some(out)
}
}
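
// Worked example (illustrative only): a 2x3 row-major tensor transposed to
// shape [3, 2] with strides [1, 3] and offset 0 yields the storage offsets
// 0, 3, 1, 4, 2, 5, i.e. the underlying data visited in column order.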

impl<'a> From<(&'a Shape, &'a Strides)> for StridedIterator<'a> {
fn from((shape, strides): (&'a Shape, &'a Strides)) -> Self {
StridedIterator::new(shape.as_slice(), strides.as_slice(), 0, 1)
}
}

impl<'a> From<(&'a Shape, &'a Strides, usize)> for StridedIterator<'a> {
fn from((shape, strides, offset): (&'a Shape, &'a Strides, usize)) -> Self {
StridedIterator::new(shape.as_slice(), strides.as_slice(), offset, 1)
}
}
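
// Usage sketch (hypothetical, not part of this commit): materializing a
// contiguous buffer from an arbitrary layout. The `gather` helper and its
// signature are assumptions for illustration only.
//
//     fn gather<T: Copy>(src: &[T], shape: &Shape, strides: &Strides) -> Vec<T> {
//         TensorIterator::new(shape, strides, 0).map(|i| src[i]).collect()
//     }
//
// Because the iterator yields raw storage offsets, the same loop serves
// both the Contiguous and Strided variants.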

#[cfg(test)]
mod tests {
use proptest::prelude::*;
use test_strategy::{proptest, Arbitrary};

use crate::{shape, Shape, Strides};

use super::{StridedIterator, TensorIterator};

#[derive(Debug)]
struct IterProblem {
shape: Shape,
offset: usize,
}

impl Arbitrary for IterProblem {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;

fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
        // Keep the two innermost dims >= 2 so that, after the transpose in
        // the strided test below, the visit order provably differs from the
        // contiguous order (assert_ne! would otherwise be flaky).
        let ranges = vec![1..=2, 1..=4, 2..=256, 2..=256];
Shape::arbitrary_with(ranges)
.prop_flat_map(|shape| (Just(shape.clone()), 0..shape.numel()))
.prop_map(|(shape, offset)| IterProblem { shape, offset })
.boxed()
}
}

#[proptest(cases = 16)]
fn test_tensor_iter_contiguous(prob: IterProblem) {
let shape = prob.shape;
let strides = Strides::from(&shape);
let offset = prob.offset;

let iter = TensorIterator::new(&shape, &strides, offset);
assert!(matches!(iter, TensorIterator::Contiguous(_)));

match iter {
            TensorIterator::Contiguous(r) => assert_eq!(r, offset..shape.numel() + offset),
_ => unreachable!(),
}
}

#[proptest(cases = 16)]
fn test_tensor_iter_strided(prob: IterProblem) {
let mut shape = prob.shape;
let mut strides = Strides::from(&shape);
strides.transpose();
shape.transpose();
let offset = prob.offset;

let iter = TensorIterator::new(&shape, &strides, offset);
assert!(matches!(iter, TensorIterator::Strided(_)));

match iter {
TensorIterator::Strided(strided_iter) => {
let mut indices: Vec<usize> = strided_iter.collect();
assert_eq!(indices.len(), shape.numel());
let contiguous: Vec<usize> = (offset..shape.numel() + offset).collect();
assert_ne!(indices, contiguous);
indices.sort();
assert_eq!(indices, contiguous);
}
_ => unreachable!(),
}
}

#[test]
fn test_tensor_iter_strided_sanity() {
let mut shape = shape!(2, 4, 3);
let mut strides = Strides::from(&shape);
strides.transpose();
shape.transpose();
let offset = 2;

let iter = TensorIterator::new(&shape, &strides, offset);
let actual: Vec<usize> = iter.collect();
let expected = vec![
2, 5, 8, 11, 3, 6, 9, 12, 4, 7, 10, 13, 14, 17, 20, 23, 15, 18, 21, 24, 16, 19, 22, 25,
];
assert_eq!(actual, expected);
}
}