chore: restructure reindex ops
ivarflakstad committed Nov 13, 2024
1 parent 7d6311b commit a4932d5
Showing 1 changed file with 122 additions and 120 deletions.
crates/ratchet-core/src/cpu/reindex.rs: 122 additions & 120 deletions (242 changes)
@@ -21,124 +21,6 @@ impl CPUOperation for Reindex {
    }
}

impl CPUOperation for Broadcast {
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
        match dst.dt() {
            DType::F32 => apply_broadcast::<f32>(self, dst),
            DType::BF16 => apply_broadcast::<bf16>(self, dst),
            DType::F16 => apply_broadcast::<f16>(self, dst),
            DType::I32 => apply_broadcast::<i32>(self, dst),
            DType::U32 => apply_broadcast::<u32>(self, dst),
            _ => todo!(),
        }
    }
}

fn apply_broadcast<T: TensorDType>(b: &Broadcast, dst: Tensor) -> Result<Tensor, OperationError> {
    let result = broadcast(&b.src.to_vec::<T>()?, b.src.shape(), b.to());
    cpu_store_result(&dst, &result);
    Ok(dst)
}

pub(crate) fn broadcast<T: TensorDType>(src: &[T], src_shape: &Shape, dst_shape: &Shape) -> Vec<T> {
    let mut result = vec![T::zero(); dst_shape.numel()];

    if src_shape.is_scalar() {
        // Life is simple
        result.fill(src[0]);
    } else if src_shape.is_vector() {
        // If src is a vector and its first dimension is the broadcast dimension,
        // fill each contiguous chunk of the output with the matching source element.
        if src_shape[0] > 1 && src_shape[0] == dst_shape[0] {
            let chunk_size = result.len() / src_shape.numel();

            (0..result.len())
                .step_by(chunk_size)
                .enumerate()
                .for_each(|(i, chunk)| {
                    result[chunk..chunk + chunk_size].fill(src[i]);
                });
        } else {
            generic_broadcast(src, &mut result, src_shape, dst_shape)
        }
    } else {
        generic_broadcast(src, &mut result, src_shape, dst_shape)
    }

    result
}

#[inline]
fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] {
    let mut indices = [0; 4];
    let mut remaining = offset;

    let idx = remaining / strides[0];
    indices[0] = idx;
    remaining -= idx * strides[0];

    let idx = remaining / strides[1];
    indices[1] = idx;
    remaining -= idx * strides[1];

    let idx = remaining / strides[2];
    indices[2] = idx;
    remaining -= idx * strides[2];

    indices[3] = remaining;
    indices
}

#[inline]
fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize {
    ndindex[0] * strides[0]
        + ndindex[1] * strides[1]
        + ndindex[2] * strides[2]
        + ndindex[3] * strides[3]
}

// TODO: Optimize.
// This generic implementation is almost a direct copy from the gpu impl,
// and can definitely be way more performant.
fn generic_broadcast<T: TensorDType>(
    src: &[T],
    result: &mut [T],
    src_shape: &Shape,
    dst_shape: &Shape,
) {
    // We now know that these will always be len 4, same as gpu impl.
    let src_shape = &Shape::promote(src_shape.clone(), 4);
    let dst_shape = &Shape::promote(dst_shape.clone(), 4);

    let src_strides = &Strides::from(src_shape);
    let dst_strides = &Strides::from(dst_shape);

    let src_shape: [usize; 4] = src_shape.try_into().unwrap();
    let src_strides: [usize; 4] = src_strides.try_into().unwrap();
    let dst_strides: [usize; 4] = dst_strides.try_into().unwrap();

    fn select(a: [usize; 4], b: [usize; 4], t: [bool; 4]) -> [usize; 4] {
        let mut result = [0; 4];
        result[0] = if t[0] { a[0] } else { b[0] };
        result[1] = if t[1] { a[1] } else { b[1] };
        result[2] = if t[2] { a[2] } else { b[2] };
        result[3] = if t[3] { a[3] } else { b[3] };
        result
    }

    let shape_onedim_lookup: [bool; 4] = [
        src_shape[0] != 1,
        src_shape[1] != 1,
        src_shape[2] != 1,
        src_shape[3] != 1,
    ];
    for i in 0..result.len() {
        let dst_index = offset_to_ndindex(i, dst_strides);
        let src_index = select(dst_index, [0; 4], shape_onedim_lookup);
        let src_offset = nd_index_to_offset(src_index, src_strides);
        result[i] = src[src_offset];
    }
}

impl CPUOperation for Permute {
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
        match dst.dt() {
@@ -160,6 +42,9 @@ fn apply_permute<T: TensorDType>(p: &Permute, dst: Tensor) -> Result<Tensor, OperationError> {
    Ok(dst)
}

// TODO: Optimize.
// This generic implementation is almost a direct copy from the gpu impl,
// and can definitely be way more performant.
fn permute<T: TensorDType>(
    src: &[T],
    src_shape: &Shape,
@@ -189,8 +74,7 @@ fn permute<T: TensorDType>(
        let src_offset = nd_index_to_offset(src_index, src_strides);
        result[i] = src[src_offset];
    }

-    return result;
+    result
}

impl CPUOperation for Slice {
@@ -249,3 +133,121 @@ pub(crate) fn slice<T: TensorDType>(

    dst
}

impl CPUOperation for Broadcast {
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
        match dst.dt() {
            DType::F32 => apply_broadcast::<f32>(self, dst),
            DType::BF16 => apply_broadcast::<bf16>(self, dst),
            DType::F16 => apply_broadcast::<f16>(self, dst),
            DType::I32 => apply_broadcast::<i32>(self, dst),
            DType::U32 => apply_broadcast::<u32>(self, dst),
            _ => todo!(),
        }
    }
}

fn apply_broadcast<T: TensorDType>(b: &Broadcast, dst: Tensor) -> Result<Tensor, OperationError> {
    let result = broadcast(&b.src.to_vec::<T>()?, b.src.shape(), b.to());
    cpu_store_result(&dst, &result);
    Ok(dst)
}

pub(crate) fn broadcast<T: TensorDType>(src: &[T], src_shape: &Shape, dst_shape: &Shape) -> Vec<T> {
    let mut result = vec![T::zero(); dst_shape.numel()];

    if src_shape.is_scalar() {
        // Life is simple
        result.fill(src[0]);
    } else if src_shape.is_vector() {
        // If src is a vector and its first dimension is the broadcast dimension,
        // fill each contiguous chunk of the output with the matching source element.
        if src_shape[0] > 1 && src_shape[0] == dst_shape[0] {
            let chunk_size = result.len() / src_shape.numel();

            (0..result.len())
                .step_by(chunk_size)
                .enumerate()
                .for_each(|(i, chunk)| {
                    result[chunk..chunk + chunk_size].fill(src[i]);
                });
        } else {
            generic_broadcast(src, &mut result, src_shape, dst_shape)
        }
    } else {
        generic_broadcast(src, &mut result, src_shape, dst_shape)
    }

    result
}
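
For intuition, the vector fast path above behaves like this minimal standalone sketch (plain Rust, no ratchet types; the concrete shapes and values are illustrative assumptions, not from this commit):

// Broadcasting a shape-[2] vector into dst shape [2, 3]:
// each source element fills one contiguous chunk of the output.
let src = [1.0f32, 2.0];
let mut result = vec![0.0f32; 6]; // dst_shape.numel() == 2 * 3
let chunk_size = result.len() / src.len(); // 6 / 2 == 3
for (i, chunk) in (0..result.len()).step_by(chunk_size).enumerate() {
    result[chunk..chunk + chunk_size].fill(src[i]);
}
assert_eq!(result, vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);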

// TODO: Optimize.
// This generic implementation is almost a direct copy from the gpu impl,
// and can definitely be way more performant.
fn generic_broadcast<T: TensorDType>(
    src: &[T],
    result: &mut [T],
    src_shape: &Shape,
    dst_shape: &Shape,
) {
    // We now know that these will always be len 4, same as gpu impl.
    let src_shape = &Shape::promote(src_shape.clone(), 4);
    let dst_shape = &Shape::promote(dst_shape.clone(), 4);

    let src_strides = &Strides::from(src_shape);
    let dst_strides = &Strides::from(dst_shape);

    let src_shape: [usize; 4] = src_shape.try_into().unwrap();
    let src_strides: [usize; 4] = src_strides.try_into().unwrap();
    let dst_strides: [usize; 4] = dst_strides.try_into().unwrap();

    fn select(a: [usize; 4], b: [usize; 4], t: [bool; 4]) -> [usize; 4] {
        let mut result = [0; 4];
        result[0] = if t[0] { a[0] } else { b[0] };
        result[1] = if t[1] { a[1] } else { b[1] };
        result[2] = if t[2] { a[2] } else { b[2] };
        result[3] = if t[3] { a[3] } else { b[3] };
        result
    }

    let shape_onedim_lookup: [bool; 4] = [
        src_shape[0] != 1,
        src_shape[1] != 1,
        src_shape[2] != 1,
        src_shape[3] != 1,
    ];
    for i in 0..result.len() {
        let dst_index = offset_to_ndindex(i, dst_strides);
        let src_index = select(dst_index, [0; 4], shape_onedim_lookup);
        let src_offset = nd_index_to_offset(src_index, src_strides);
        result[i] = src[src_offset];
    }
}
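
To make the index mapping concrete (a hedged illustration; the shapes here are hypothetical): broadcasting src shape [1, 3, 1, 1] into dst shape [2, 3, 4, 5] gives shape_onedim_lookup = [false, true, false, false], so every dst index [i, j, k, l] reads src index [0, j, 0, 0]; only the axes the source actually has survive the select. If select were hoisted to module scope, the mapping would look like:

let lookup = [false, true, false, false]; // src shape [1, 3, 1, 1]
assert_eq!(select([1, 2, 3, 4], [0; 4], lookup), [0, 2, 0, 0]);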

#[inline]
fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] {
    let mut indices = [0; 4];
    let mut remaining = offset;

    let idx = remaining / strides[0];
    indices[0] = idx;
    remaining -= idx * strides[0];

    let idx = remaining / strides[1];
    indices[1] = idx;
    remaining -= idx * strides[1];

    let idx = remaining / strides[2];
    indices[2] = idx;
    remaining -= idx * strides[2];

    indices[3] = remaining;
    indices
}

#[inline]
fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize {
    ndindex[0] * strides[0]
        + ndindex[1] * strides[1]
        + ndindex[2] * strides[2]
        + ndindex[3] * strides[3]
}
