Merge pull request #263 from huggingface/feature/cpu-permute

Feature/cpu permute

FL33TW00D authored Nov 15, 2024
2 parents 5979334 + a4932d5 commit 560ccea
Showing 2 changed files with 151 additions and 83 deletions.
220 changes: 141 additions & 79 deletions crates/ratchet-core/src/cpu/reindex.rs
@@ -3,20 +3,137 @@
 use super::utils::{
     TensorIterator::{Contiguous, Strided},
 };
 use crate::{
-    Broadcast, CPUOperation, DType, OperationError, Reindex, Shape, Slice, Strides, Tensor,
-    TensorDType,
+    Broadcast, CPUOperation, DType, OperationError, Permute, Reindex, Shape, Slice, Strides,
+    Tensor, TensorDType,
 };
 use half::{bf16, f16};
 use ndarray::ShapeBuilder;
 use std::ops::Range;
 
 impl CPUOperation for Reindex {
     fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
         match self {
+            Reindex::Permute(p) => p.apply_cpu(dst),
             Reindex::Slice(s) => s.apply_cpu(dst),
             Reindex::Broadcast(b) => b.apply_cpu(dst),
             _ => todo!(),
         }
     }
 }
 
+impl CPUOperation for Permute {
+    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+        match dst.dt() {
+            DType::F32 => apply_permute::<f32>(self, dst),
+            DType::BF16 => apply_permute::<bf16>(self, dst),
+            DType::F16 => apply_permute::<f16>(self, dst),
+            DType::I32 => apply_permute::<i32>(self, dst),
+            DType::U32 => apply_permute::<u32>(self, dst),
+            _ => todo!(),
+        }
+    }
+}
+
+fn apply_permute<T: TensorDType>(p: &Permute, dst: Tensor) -> Result<Tensor, OperationError> {
+    let perm: [usize; 4] = p.promote().try_into().unwrap();
+    let Permute { src, dims } = p;
+    let result = permute(&src.to_vec::<T>()?, src.shape(), dst.shape(), perm);
+    cpu_store_result(&dst, &result);
+    Ok(dst)
+}
+
+// TODO: Optimize.
+// This generic implementation is almost a direct copy from the gpu impl,
+// and can definitely be way more performant.
+fn permute<T: TensorDType>(
+    src: &[T],
+    src_shape: &Shape,
+    dst_shape: &Shape,
+    perm: [usize; 4],
+) -> Vec<T> {
+    let mut result = vec![T::zero(); src_shape.numel()];
+
+    // We now know that these will always be len 4, same as gpu impl.
+    let src_shape = &Shape::promote(src_shape.clone(), 4);
+    let dst_shape = &Shape::promote(dst_shape.clone(), 4);
+
+    let src_strides = &Strides::from(src_shape);
+    let dst_strides = &Strides::from(dst_shape);
+
+    let src_shape: [usize; 4] = src_shape.try_into().unwrap();
+    let src_strides: [usize; 4] = src_strides.try_into().unwrap();
+    let dst_strides: [usize; 4] = dst_strides.try_into().unwrap();
+
+    for i in 0..result.len() {
+        let dst_index = offset_to_ndindex(i, dst_strides);
+        let mut src_index = [0; 4];
+        src_index[perm[0]] = dst_index[0];
+        src_index[perm[1]] = dst_index[1];
+        src_index[perm[2]] = dst_index[2];
+        src_index[perm[3]] = dst_index[3];
+        let src_offset = nd_index_to_offset(src_index, src_strides);
+        result[i] = src[src_offset]
+    }
+    result
+}
+
+impl CPUOperation for Slice {
+    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+        match dst.dt() {
+            DType::F32 => apply_slice::<f32>(self, dst),
+            DType::BF16 => apply_slice::<bf16>(self, dst),
+            DType::F16 => apply_slice::<f16>(self, dst),
+            DType::I32 => apply_slice::<i32>(self, dst),
+            DType::U32 => apply_slice::<u32>(self, dst),
+            _ => todo!(),
+        }
+    }
+}
+
+fn apply_slice<T: TensorDType>(s: &Slice, dst: Tensor) -> Result<Tensor, OperationError> {
+    let (start, stop): (Vec<_>, Vec<_>) = s.indices().iter().map(|r| (r.start, r.end)).unzip();
+    let result = slice(&s.src.to_vec::<T>()?, s.src.strides(), &start, &stop);
+
+    cpu_store_result(&dst, &result);
+    Ok(dst)
+}
+
+pub(crate) fn slice<T: TensorDType>(
+    src: &[T],
+    src_strides: &Strides,
+    start: &[usize],
+    stop: &[usize],
+) -> Vec<T> {
+    assert!(start.len() == stop.len());
+    assert!(start.len() == src_strides.rank());
+    start.iter().zip(stop.iter()).for_each(|(s, t)| {
+        assert!(s < t);
+    });
+
+    let dst_shape: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
+    let dst_numel: usize = dst_shape.iter().product();
+
+    let mut dst = vec![T::zero(); dst_numel];
+
+    let mut dst_dots = vec![];
+    for d in 0..dst_shape.len() {
+        dst_dots.push(dst_shape[d + 1..].iter().product::<usize>().max(1));
+    }
+
+    for i in 0..dst.len() {
+        let mut src_index = 0;
+        let mut tmp = i;
+        for d in 0..dst_shape.len() {
+            let coord = tmp / dst_dots[d];
+            tmp %= dst_dots[d];
+            src_index += (coord + start[d]) * src_strides[d] as usize;
+        }
+        dst[i] = src[src_index];
+    }
+
+    dst
+}
 
 impl CPUOperation for Broadcast {
     fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
         match dst.dt() {
@@ -83,26 +200,6 @@ fn generic_broadcast<T: TensorDType>(
     let src_strides: [usize; 4] = src_strides.try_into().unwrap();
     let dst_strides: [usize; 4] = dst_strides.try_into().unwrap();
 
-    fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] {
-        let mut indices = [0; 4];
-        let mut remaining = offset;
-
-        let idx = remaining / strides[0];
-        indices[0] = idx;
-        remaining -= idx * strides[0];
-
-        let idx = remaining / strides[1];
-        indices[1] = idx;
-        remaining -= idx * strides[1];
-
-        let idx = remaining / strides[2];
-        indices[2] = idx;
-        remaining -= idx * strides[2];
-
-        indices[3] = remaining;
-        indices
-    }
-
     fn select(a: [usize; 4], b: [usize; 4], t: [bool; 4]) -> [usize; 4] {
         let mut result = [0; 4];
         result[0] = if t[0] { a[0] } else { b[0] };
@@ -112,13 +209,6 @@ fn generic_broadcast<T: TensorDType>(
         result
     }
 
-    fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize {
-        ndindex[0] * strides[0]
-            + ndindex[1] * strides[1]
-            + ndindex[2] * strides[2]
-            + ndindex[3] * strides[3]
-    }
-
     let shape_onedim_lookup: [bool; 4] = [
         src_shape[0] != 1,
         src_shape[1] != 1,
@@ -133,59 +223,31 @@ fn generic_broadcast<T: TensorDType>(
     }
 }
 
-impl CPUOperation for Slice {
-    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
-        match dst.dt() {
-            DType::F32 => apply_slice::<f32>(self, dst),
-            DType::BF16 => apply_slice::<bf16>(self, dst),
-            DType::F16 => apply_slice::<f16>(self, dst),
-            DType::I32 => apply_slice::<i32>(self, dst),
-            DType::U32 => apply_slice::<u32>(self, dst),
-            _ => todo!(),
-        }
-    }
-}
-
-fn apply_slice<T: TensorDType>(s: &Slice, dst: Tensor) -> Result<Tensor, OperationError> {
-    let (start, stop): (Vec<_>, Vec<_>) = s.indices().iter().map(|r| (r.start, r.end)).unzip();
-    let result = slice(&s.src.to_vec::<T>()?, s.src.strides(), &start, &stop);
-
-    cpu_store_result(&dst, &result);
-    Ok(dst)
-}
-
-pub(crate) fn slice<T: TensorDType>(
-    src: &[T],
-    src_strides: &Strides,
-    start: &[usize],
-    stop: &[usize],
-) -> Vec<T> {
-    assert!(start.len() == stop.len());
-    assert!(start.len() == src_strides.rank());
-    start.iter().zip(stop.iter()).for_each(|(s, t)| {
-        assert!(s < t);
-    });
-
-    let dst_shape: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
-    let dst_numel: usize = dst_shape.iter().product();
-
-    let mut dst = vec![T::zero(); dst_numel];
-
-    let mut dst_dots = vec![];
-    for d in 0..dst_shape.len() {
-        dst_dots.push(dst_shape[d + 1..].iter().product::<usize>().max(1));
-    }
-
-    for i in 0..dst.len() {
-        let mut src_index = 0;
-        let mut tmp = i;
-        for d in 0..dst_shape.len() {
-            let coord = tmp / dst_dots[d];
-            tmp %= dst_dots[d];
-            src_index += (coord + start[d]) * src_strides[d] as usize;
-        }
-        dst[i] = src[src_index];
-    }
-
-    dst
-}
+#[inline]
+fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] {
+    let mut indices = [0; 4];
+    let mut remaining = offset;
+
+    let idx = remaining / strides[0];
+    indices[0] = idx;
+    remaining -= idx * strides[0];
+
+    let idx = remaining / strides[1];
+    indices[1] = idx;
+    remaining -= idx * strides[1];
+
+    let idx = remaining / strides[2];
+    indices[2] = idx;
+    remaining -= idx * strides[2];
+
+    indices[3] = remaining;
+    indices
+}
+
+#[inline]
+fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize {
+    ndindex[0] * strides[0]
+        + ndindex[1] * strides[1]
+        + ndindex[2] * strides[2]
+        + ndindex[3] * strides[3]
+}
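
The dst-to-src walk in `permute` is easiest to check with concrete numbers. Below is a minimal standalone sketch of the same index arithmetic; plain `[usize; 4]` arrays stand in for ratchet's `Shape` and `Strides`, and the strides are hand-computed for the example, so none of this is the crate's API:

// Standalone sketch of the index arithmetic in `permute` above. The helpers
// mirror offset_to_ndindex / nd_index_to_offset from the diff.

fn offset_to_ndindex(offset: usize, strides: [usize; 4]) -> [usize; 4] {
    let mut indices = [0; 4];
    let mut remaining = offset;
    for d in 0..3 {
        indices[d] = remaining / strides[d];
        remaining -= indices[d] * strides[d];
    }
    indices[3] = remaining;
    indices
}

fn nd_index_to_offset(ndindex: [usize; 4], strides: [usize; 4]) -> usize {
    (0..4).map(|d| ndindex[d] * strides[d]).sum()
}

fn main() {
    // Transpose the last two axes of a [1, 1, 2, 3] tensor: perm = [0, 1, 3, 2],
    // i.e. destination axis d ranges over source axis perm[d].
    // Row-major strides: src [6, 6, 3, 1]; dst shape [1, 1, 3, 2] -> [6, 6, 2, 1].
    let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // [[1, 2, 3], [4, 5, 6]]
    let (src_strides, dst_strides) = ([6, 6, 3, 1], [6, 6, 2, 1]);
    let perm = [0, 1, 3, 2];

    let mut dst = [0.0f32; 6];
    for i in 0..dst.len() {
        // Same dst -> src walk as the kernel: decompose the destination
        // offset, scatter it through `perm`, re-linearize against src strides.
        let dst_index = offset_to_ndindex(i, dst_strides);
        let mut src_index = [0; 4];
        for d in 0..4 {
            src_index[perm[d]] = dst_index[d];
        }
        dst[i] = src[nd_index_to_offset(src_index, src_strides)];
    }
    assert_eq!(dst, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]); // [[1, 4], [2, 5], [3, 6]]
}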
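
`slice` follows the same pattern: decompose each destination offset with the destination's row-major strides (`dst_dots`), shift each coordinate by `start`, and re-linearize against the source strides. A minimal sketch with plain slices in place of ratchet's `Strides` (the original's assertions are elided):

// Standalone sketch of the strided copy in `slice` above: take
// src[start..stop] along each axis.

fn slice(src: &[f32], src_strides: &[usize], start: &[usize], stop: &[usize]) -> Vec<f32> {
    let dst_shape: Vec<usize> = stop.iter().zip(start.iter()).map(|(e, s)| e - s).collect();
    let dst_numel: usize = dst_shape.iter().product();

    // dst_dots[d] = number of elements covered by one step along axis d,
    // i.e. the row-major strides of the destination shape.
    let mut dst_dots = vec![];
    for d in 0..dst_shape.len() {
        dst_dots.push(dst_shape[d + 1..].iter().product::<usize>().max(1));
    }

    let mut dst = vec![0.0; dst_numel];
    for i in 0..dst_numel {
        // Decompose the destination offset, shift by `start`, re-linearize.
        let (mut src_index, mut tmp) = (0, i);
        for d in 0..dst_shape.len() {
            let coord = tmp / dst_dots[d];
            tmp %= dst_dots[d];
            src_index += (coord + start[d]) * src_strides[d];
        }
        dst[i] = src[src_index];
    }
    dst
}

fn main() {
    // A 3x4 row-major matrix (strides [4, 1]); take rows 1..3 and columns 1..3.
    let src: Vec<f32> = (0..12).map(|x| x as f32).collect();
    let out = slice(&src, &[4, 1], &[1, 1], &[3, 3]);
    assert_eq!(out, vec![5.0, 6.0, 9.0, 10.0]);
}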
14 changes: 10 additions & 4 deletions crates/ratchet-core/src/ops/reindex/permute.rs
@@ -111,9 +111,8 @@ def permute(a):
         run_py_prg(prg.to_string(), &[a], &[], a.dt())
     }
 
-    fn run_reindex_trial(prob: PermuteProblem) -> anyhow::Result<()> {
+    fn run_reindex_trial(prob: PermuteProblem, device: Device) -> anyhow::Result<()> {
         let PermuteProblem { op } = prob;
-        let device = Device::request_device(DeviceRequest::GPU).unwrap();
         let a = op.src.clone();
 
         let a_gpu = a.to(&device)?;
@@ -125,7 +124,6 @@ def permute(a):
     }
 
     #[proptest(cases = 16)]
-    fn test_permute(prob: PermuteProblem) {
-        run_reindex_trial(prob).unwrap();
+    fn test_permute_gpu(prob: PermuteProblem) {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
+        run_reindex_trial(prob, device).unwrap();
+    }
+
+    #[proptest(cases = 16)]
+    fn test_permute_cpu(prob: PermuteProblem) {
+        let device = Device::request_device(DeviceRequest::CPU).unwrap();
+        run_reindex_trial(prob, device).unwrap();
     }
 }
