Skip to content

Commit

Permalink
chore: tidy up broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Nov 13, 2024
1 parent 72975b4 commit dad3300
Showing 1 changed file with 40 additions and 34 deletions.
74 changes: 40 additions & 34 deletions crates/ratchet-core/src/cpu/reindex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,6 @@ fn apply_broadcast<T: TensorDType>(b: &Broadcast, dst: Tensor) -> Result<Tensor,
Ok(dst)
}

/// Decomposes a linear `offset` into one index per axis using `strides`.
///
/// Every axis except the last gets `remaining / stride`, and the consumed part
/// is subtracted before moving on to the next axis. The last axis receives
/// whatever is left over, i.e. it is treated as if its stride were 1.
///
/// Returns an empty vector when `strides` is empty (rank-0), instead of
/// underflowing on `strides.len() - 1`.
///
/// # Panics
///
/// Panics if any stride other than the last one is zero (division by zero).
fn offset_to_ndindex(offset: usize, strides: &[usize]) -> Vec<usize> {
    let mut indices = vec![0; strides.len()];

    // `split_last` yields `None` for an empty slice, so the rank-0 case falls
    // through and returns the (empty) index untouched.
    if let Some((_, leading)) = strides.split_last() {
        let mut remaining = offset;
        for (slot, &stride) in indices.iter_mut().zip(leading) {
            let idx = remaining / stride;
            *slot = idx;
            remaining -= idx * stride;
        }
        // The innermost axis absorbs the remainder without dividing.
        indices[leading.len()] = remaining;
    }

    indices
}

/// Collapses a per-axis index back into a linear offset.
///
/// Each coordinate in `ndindex` is multiplied by the stride at the same
/// position and the products are added up. If the two slices differ in
/// length, the extra trailing entries of the longer one are ignored.
fn nd_index_to_offset(ndindex: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for (&coord, &stride) in ndindex.iter().zip(strides) {
        offset += coord * stride;
    }
    offset
}

pub(crate) fn broadcast<T: TensorDType>(src: &[T], src_shape: &Shape, dst_shape: &Shape) -> Vec<T> {
let src_strides = Strides::from(src_shape);
let dst_strides = Strides::from(dst_shape);
Expand All @@ -79,29 +61,53 @@ pub(crate) fn broadcast<T: TensorDType>(src: &[T], src_shape: &Shape, dst_shape:
result[chunk..chunk + chunk_size].fill(src[i]);
});
} else {
for i in 0..result.len() {
let dst_index = offset_to_ndindex(i, &dst_strides);
let src_index: Vec<usize> = (0..src_shape.len())
.map(|x| if src_shape[x] == 1 { 0 } else { dst_index[x] })
.collect();
let src_offset = nd_index_to_offset(&src_index, &src_strides);
result[i] = src[src_offset]
}
generic_broadcast(src, &mut result, src_shape, &src_strides, &dst_strides)
}
} else {
for i in 0..result.len() {
let dst_index = offset_to_ndindex(i, &dst_strides);
let src_index: Vec<usize> = (0..src_shape.len())
.map(|x| if src_shape[x] == 1 { 0 } else { dst_index[x] })
.collect();
let src_offset = nd_index_to_offset(&src_index, &src_strides);
result[i] = src[src_offset]
}
generic_broadcast(src, &mut result, src_shape, &src_strides, &dst_strides)
}

result
}

// TODO: Optimize further.
// This generic implementation is almost a direct copy from the gpu impl,
// and can definitely be way more performant: it still recomputes the full
// n-dimensional index for every output element.
/// Fills `result` by reading, for every destination offset, the matching
/// element of `src`.
///
/// For each position in `result` the linear offset is decomposed into a
/// per-axis index using `dst_strides`. Axes where `src_shape` is 1 are pinned
/// to index 0; all other axes reuse the destination index. The resulting
/// index is then folded into a source offset with `src_strides`.
///
/// * `src` - elements to read from.
/// * `result` - output buffer; every element is overwritten.
/// * `src_shape` - extent of each source axis, used only to detect size-1 axes.
/// * `src_strides` - strides used to turn the source index into an offset.
/// * `dst_strides` - strides used to decompose destination offsets.
///
/// # Panics
///
/// Panics if a non-final destination stride is zero, if a non-unit source
/// axis has no corresponding destination axis, or if a computed source offset
/// is outside `src`.
fn generic_broadcast<T: TensorDType>(
    src: &[T],
    result: &mut [T],
    src_shape: &[usize],
    src_strides: &[usize],
    dst_strides: &[usize],
) {
    /// Writes the per-axis decomposition of `offset` into `indices`.
    ///
    /// Every axis but the last gets `remaining / stride`; the last axis takes
    /// the remainder as-is. Does nothing when `strides` is empty, which avoids
    /// the `len() - 1` underflow on rank-0 inputs.
    fn write_ndindex(offset: usize, strides: &[usize], indices: &mut [usize]) {
        if let Some((_, leading)) = strides.split_last() {
            let mut remaining = offset;
            for (slot, &stride) in indices.iter_mut().zip(leading) {
                let idx = remaining / stride;
                *slot = idx;
                remaining -= idx * stride;
            }
            indices[leading.len()] = remaining;
        }
    }

    // One scratch buffer reused for every element; each call to
    // `write_ndindex` overwrites all of its entries.
    let mut dst_index = vec![0usize; dst_strides.len()];

    for (dst_offset, out) in result.iter_mut().enumerate() {
        write_ndindex(dst_offset, dst_strides, &mut dst_index);

        // Size-1 source axes contribute nothing; the rest contribute
        // `destination index * source stride`.
        let src_offset: usize = src_shape
            .iter()
            .zip(src_strides)
            .enumerate()
            .map(|(axis, (&dim, &stride))| {
                if dim == 1 {
                    0
                } else {
                    dst_index[axis] * stride
                }
            })
            .sum();

        *out = src[src_offset];
    }
}

impl CPUOperation for Slice {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
Expand Down

0 comments on commit dad3300

Please sign in to comment.