Skip to content

Commit

Permalink
Merge pull request #261 from huggingface/feature/cpu-slice
Browse files Browse the repository at this point in the history
Feature/cpu slice
  • Loading branch information
FL33TW00D authored Nov 4, 2024
2 parents a0bd0f1 + 08900c2 commit b9d106f
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 35 deletions.
3 changes: 2 additions & 1 deletion crates/ratchet-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod binary;
pub mod gemm;
pub mod reindex;
pub mod rope;
mod unary;
mod utils;
Expand All @@ -24,7 +25,7 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError
LazyOp::Softmax(_s) => todo!(),
LazyOp::RoPE(r) => cpu_rope(r, dst),
LazyOp::Unary(u) => u.apply_cpu(dst),
LazyOp::Reindex(_r) => todo!(),
LazyOp::Reindex(r) => r.apply_cpu(dst),
LazyOp::Concat(c) => cpu_concat(c, dst),
LazyOp::Norm(_n) => todo!(),
LazyOp::Conv(_c) => todo!(),
Expand Down
64 changes: 64 additions & 0 deletions crates/ratchet-core/src/cpu/reindex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use super::utils::cpu_store_result;
use crate::{CPUOperation, DType, OperationError, Reindex, Slice, Strides, Tensor, TensorDType};
use half::{bf16, f16};

impl CPUOperation for Reindex {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match self {
Reindex::Slice(s) => s.apply_cpu(dst),
_ => todo!(),
}
}
}

impl CPUOperation for Slice {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
DType::F32 => apply_slice::<f32>(self, dst),
DType::BF16 => apply_slice::<bf16>(self, dst),
DType::F16 => apply_slice::<f16>(self, dst),
DType::I32 => apply_slice::<i32>(self, dst),
DType::U32 => apply_slice::<u32>(self, dst),
_ => todo!(),
}
}
}

fn apply_slice<T: TensorDType>(s: &Slice, dst: Tensor) -> Result<Tensor, OperationError> {
let (start, stop): (Vec<_>, Vec<_>) = s.indices().iter().map(|r| (r.start, r.end)).unzip();
let result = slice(&s.src.to_vec::<T>()?, s.src.strides(), &start, &stop);

cpu_store_result(&dst, &result);
Ok(dst)
}

pub(crate) fn slice<T: TensorDType>(
src: &[T],
src_strides: &Strides,
start: &[usize],
stop: &[usize],
) -> Vec<T> {
assert!(start.len() == stop.len());
assert!(start.len() == src_strides.rank());
start.iter().zip(stop.iter()).for_each(|(s, t)| {
assert!(s < t);
});

let dst_shape: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
let dst_numel: usize = dst_shape.iter().product();

let mut dst = vec![T::zero(); dst_numel];

for i in 0..dst_numel {
let mut src_index = 0;
let mut tmp = i;
for d in 0..dst_shape.len() {
let coord = tmp / dst_shape[d + 1..].iter().product::<usize>().max(1);
tmp %= dst_shape[d + 1..].iter().product::<usize>().max(1);
src_index += (coord + start[d]) * src_strides[d] as usize;
}
dst[i] = src[src_index];
}

dst
}
28 changes: 1 addition & 27 deletions crates/ratchet-core/src/cpu/rope.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
concat,
cpu::{cpu_store_result, gemm::gemm},
cpu::{cpu_store_result, gemm::gemm, reindex::slice},
shape, DType, OperationError, RoPE, Shape, Strides, Tensor,
};
use anyhow::anyhow;
Expand Down Expand Up @@ -61,32 +61,6 @@ fn compute_theta(
Ok(theta)
}

fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) -> Vec<f32> {
assert!(start.len() == stop.len());
assert!(start.len() == src_strides.rank());
start.iter().zip(stop.iter()).for_each(|(s, t)| {
assert!(s < t);
});

let dst_shape: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
let dst_numel: usize = dst_shape.iter().product();

let mut dst = vec![0.0; dst_numel];

for i in 0..dst_numel {
let mut src_index = 0;
let mut tmp = i;
for d in 0..dst_shape.len() {
let coord = tmp / dst_shape[d + 1..].iter().product::<usize>().max(1);
tmp %= dst_shape[d + 1..].iter().product::<usize>().max(1);
src_index += (coord + start[d]) * src_strides[d] as usize;
}
dst[i] = src[src_index];
}

dst
}

fn rope(
src: Vec<f32>,
shape: &Shape,
Expand Down
3 changes: 0 additions & 3 deletions crates/ratchet-core/src/cpu/slice.rs

This file was deleted.

15 changes: 11 additions & 4 deletions crates/ratchet-core/src/ops/reindex/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,9 @@ def slice(a):
run_py_prg(prg.to_string(), &[a], &[], a.dt())
}

fn run_reindex_trial(prob: SliceProblem) -> anyhow::Result<()> {
fn run_reindex_trial(prob: SliceProblem, device: Device) -> anyhow::Result<()> {
let SliceProblem { op } = prob;
println!("SLICE PROBLEM: {:?}", op);
let device = Device::request_device(DeviceRequest::GPU).unwrap();
let a = op.src.clone();

let a_gpu = a.to(&device)?;
Expand All @@ -173,8 +172,16 @@ def slice(a):
}

#[proptest(cases = 16)]
fn test_slice(prob: SliceProblem) {
fn test_slice_gpu(prob: SliceProblem) {
let _ = env_logger::builder().is_test(true).try_init();
let device = Device::request_device(DeviceRequest::GPU).unwrap();
run_reindex_trial(prob, device).unwrap();
}

#[proptest(cases = 16)]
fn test_slice_cpu(prob: SliceProblem) {
let _ = env_logger::builder().is_test(true).try_init();
run_reindex_trial(prob).unwrap();
let device = Device::request_device(DeviceRequest::CPU).unwrap();
run_reindex_trial(prob, device).unwrap();
}
}

0 comments on commit b9d106f

Please sign in to comment.