Merge pull request #264 from huggingface/feature/cpu-layernorm
Feature/cpu rms and layer normalization
FL33TW00D authored Nov 20, 2024
2 parents 560ccea + 8292992 commit 6d8fb10
Showing 9 changed files with 377 additions and 116 deletions.
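The norm kernel itself lands in a new norm module under crates/ratchet-core/src/cpu/ (declared via mod norm; below), but its diff is not shown in this excerpt. As rough orientation for the two operations the PR brings to the CPU backend, here is a minimal, self-contained Rust sketch of LayerNorm and RMSNorm over the last dimension of a row-major f32 buffer; the function names, signatures, and eps placement are illustrative assumptions, not the PR's actual code:

// Illustrative sketch only -- not the code from this PR's norm module.
// Each row of length `dim` is normalized independently.

/// LayerNorm: y = (x - mean) / sqrt(var + eps) * gamma + beta
fn layer_norm(x: &mut [f32], gamma: &[f32], beta: &[f32], dim: usize, eps: f32) {
    for row in x.chunks_mut(dim) {
        let mean = row.iter().sum::<f32>() / dim as f32;
        let var = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / dim as f32;
        let rstd = (var + eps).sqrt().recip();
        for (v, (g, b)) in row.iter_mut().zip(gamma.iter().zip(beta.iter())) {
            *v = (*v - mean) * rstd * g + b;
        }
    }
}

/// RMSNorm: y = x / sqrt(mean(x^2) + eps) * gamma -- no mean subtraction, no bias.
fn rms_norm(x: &mut [f32], gamma: &[f32], dim: usize, eps: f32) {
    for row in x.chunks_mut(dim) {
        let ms = row.iter().map(|v| v * v).sum::<f32>() / dim as f32;
        let rrms = (ms + eps).sqrt().recip();
        for (v, g) in row.iter_mut().zip(gamma.iter()) {
            *v = *v * rrms * g;
        }
    }
}

fn main() {
    let gamma = [1.0f32; 4];
    let beta = [0.0f32; 4];
    let mut a = [1.0f32, 2.0, 3.0, 4.0];
    layer_norm(&mut a, &gamma, &beta, 4, 1e-5);
    let mut b = [1.0f32, 2.0, 3.0, 4.0];
    rms_norm(&mut b, &gamma, 4, 1e-5);
    println!("layer_norm: {a:?}\nrms_norm: {b:?}");
}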
77 changes: 73 additions & 4 deletions crates/ratchet-core/src/cpu/binary.rs
@@ -1,9 +1,64 @@
-use crate::{
-    binary_apply_inplace, Binary, BinaryOp, CPUOperation, DType, OpGuards, Operation,
-    OperationError, RVec, StorageView, Tensor, TensorDType,
-};
+use crate::cpu::cpu_store_result;
+use crate::{Binary, BinaryOp, CPUOperation, DType, OperationError, Tensor, TensorDType};
 use core::marker::PhantomData;
 use half::{bf16, f16};
+use num_traits::NumOps;
+
+#[inline]
+pub(crate) fn binary_map<T: TensorDType, U: TensorDType>(
+    lhs: &[T],
+    rhs: &[T],
+    dst: &mut [U],
+    f: fn(T, T) -> U,
+) {
+    assert_eq!(lhs.len(), dst.len());
+    assert_eq!(rhs.len(), dst.len());
+    for ((l, r), d) in lhs
+        .iter()
+        .copied()
+        .zip(rhs.iter().copied())
+        .zip(dst.iter_mut())
+    {
+        *d = f(l, r);
+    }
+}
+
+#[inline]
+pub(crate) fn binary_map_inplace<T: TensorDType>(lhs: &mut [T], rhs: &[T], f: fn(T, T) -> T) {
+    assert_eq!(lhs.len(), rhs.len());
+    lhs.iter_mut().zip(rhs.iter()).for_each(|(l, r)| {
+        *l = f(*l, *r);
+    });
+}
+
+#[inline]
+pub(crate) fn binary_apply<T: TensorDType, U: TensorDType>(
+    lhs: &Tensor,
+    rhs: &Tensor,
+    dst: &Tensor,
+    f: fn(T, T) -> U,
+) -> Result<(), OperationError> {
+    let lhs = lhs.to_vec::<T>()?;
+    let rhs = rhs.to_vec::<T>()?;
+    let mut result = vec![U::zero(); dst.shape().numel()];
+    binary_map(&lhs, &rhs, &mut result, f);
+    cpu_store_result(dst, &result);
+    Ok(())
+}
+
+#[inline]
+pub(crate) fn binary_apply_inplace<T: TensorDType>(
+    lhs: &Tensor,
+    rhs: &Tensor,
+    dst: &Tensor,
+    f: fn(T, T) -> T,
+) -> Result<(), OperationError> {
+    let mut lhs = lhs.to_vec::<T>()?;
+    let rhs = rhs.to_vec::<T>()?;
+    binary_map_inplace(&mut lhs, &rhs, f);
+    cpu_store_result(dst, &lhs);
+    Ok(())
+}
+
 pub struct BinaryOps<T: TensorDType> {
     dtype: PhantomData<T>,
@@ -18,6 +73,20 @@ macro_rules! impl_cpu_binary_op {
     };
 }
 
+macro_rules! cpu_binary_op_fn {
+    ($method_name:ident, $op:expr) => {
+        #[inline]
+        pub(crate) fn $method_name<T: TensorDType + NumOps>(lhs: &mut [T], rhs: &[T]) {
+            binary_map_inplace::<T>(lhs, rhs, $op);
+        }
+    };
+}
+
+cpu_binary_op_fn!(add, |lhs, rhs| lhs + rhs);
+cpu_binary_op_fn!(sub, |lhs, rhs| lhs - rhs);
+cpu_binary_op_fn!(mul, |lhs, rhs| lhs * rhs);
+cpu_binary_op_fn!(div, |lhs, rhs| lhs / rhs);
+
 macro_rules! impl_cpu_binary {
     ($dtype:ident) => {
         impl BinaryOps<$dtype> {
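The cpu_binary_op_fn! macro above stamps out one in-place slice kernel per arithmetic operator on top of binary_map_inplace, with num_traits::NumOps supplying the +, -, *, / bounds. A standalone sketch of the same pattern, with the crate's TensorDType bound swapped for a plain Copy bound so it runs on bare slices (the names here are illustrative, not part of ratchet):

// Sketch of the macro-generated kernels; requires the num_traits crate.
use num_traits::NumOps;

// In-place elementwise combine: lhs[i] = f(lhs[i], rhs[i]).
fn binary_map_inplace<T: Copy + NumOps>(lhs: &mut [T], rhs: &[T], f: fn(T, T) -> T) {
    assert_eq!(lhs.len(), rhs.len());
    lhs.iter_mut().zip(rhs.iter()).for_each(|(l, r)| *l = f(*l, *r));
}

// One function per operator, mirroring cpu_binary_op_fn! above.
macro_rules! binary_op_fn {
    ($name:ident, $op:expr) => {
        fn $name<T: Copy + NumOps>(lhs: &mut [T], rhs: &[T]) {
            binary_map_inplace(lhs, rhs, $op);
        }
    };
}

binary_op_fn!(add, |l, r| l + r);
binary_op_fn!(mul, |l, r| l * r);

fn main() {
    let mut a = [1.0f32, 2.0, 3.0];
    let b = [10.0f32, 20.0, 30.0];
    add(&mut a, &b); // a = [11, 22, 33]
    mul(&mut a, &b); // a = [110, 440, 990]
    assert_eq!(a, [110.0, 440.0, 990.0]);
}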
88 changes: 5 additions & 83 deletions crates/ratchet-core/src/cpu/mod.rs
@@ -1,20 +1,19 @@
 mod binary;
 pub mod gemm;
+mod norm;
 pub mod reindex;
 pub mod rope;
 mod unary;
+mod utils;
 
 use crate::{
-    dequantize, Binary, BinaryOp, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError,
-    LazyOp, OpGuards, Operation, OperationError, RVec, Shape, Storage, StorageView, Strides,
-    Tensor, TensorDType, Unary, UnaryOp,
+    dequantize, Cast, Concat, DType, IndexSelect, InvariantError, LazyOp, Operation,
+    OperationError, RVec, Shape, Tensor, TensorDType,
 };
 use anyhow::anyhow;
-use core::marker::PhantomData;
 use half::{bf16, f16};
-use num_traits::Float;
 use rope::cpu_rope;
+use unary::unary_apply_fn;
+use utils::cpu_store_result;
 
 pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError> {
@@ -27,7 +26,7 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError
         LazyOp::Unary(u) => u.apply_cpu(dst),
         LazyOp::Reindex(r) => r.apply_cpu(dst),
         LazyOp::Concat(c) => cpu_concat(c, dst),
-        LazyOp::Norm(_n) => todo!(),
+        LazyOp::Norm(n) => n.apply_cpu(dst),
         LazyOp::Conv(_c) => todo!(),
         LazyOp::Select(i) => cpu_index_select(i, dst),
         LazyOp::IndexWrite(_i) => todo!(),
@@ -209,80 +208,3 @@ pub fn cpu_concat(Concat { inputs, dim }: Concat, dst: Tensor) -> Result<Tensor,
         dtype => Err(InvariantError::UnsupportedDType(dtype).into()),
     }
 }
-
-#[inline]
-fn unary_apply_fn_helper<T: TensorDType, U: TensorDType>(src: &[T], dst: &mut [U], f: fn(T) -> U) {
-    assert_eq!(src.len(), dst.len());
-    for (s, d) in src.iter().copied().zip(dst.iter_mut()) {
-        *d = f(s);
-    }
-}
-
-#[inline]
-pub fn unary_apply_fn<T: TensorDType, U: TensorDType>(
-    input: &Tensor,
-    dst: &Tensor,
-    f: fn(T) -> U,
-) -> Result<(), OperationError> {
-    let input = input.to_vec::<T>()?;
-    let mut result = vec![U::zero(); dst.shape().numel()];
-    unary_apply_fn_helper(&input, &mut result, f);
-    cpu_store_result(dst, &result);
-    Ok(())
-}
-
-#[inline]
-fn binary_apply_fn_helper<T: TensorDType, U: TensorDType>(
-    lhs: &[T],
-    rhs: &[T],
-    dst: &mut [U],
-    f: fn(T, T) -> U,
-) {
-    assert_eq!(lhs.len(), dst.len());
-    assert_eq!(rhs.len(), dst.len());
-    for ((l, r), d) in lhs
-        .iter()
-        .copied()
-        .zip(rhs.iter().copied())
-        .zip(dst.iter_mut())
-    {
-        *d = f(l, r);
-    }
-}
-
-#[inline]
-fn binary_apply_inplace_helper<T: TensorDType>(lhs: &mut [T], rhs: &[T], f: fn(T, T) -> T) {
-    assert_eq!(lhs.len(), rhs.len());
-    lhs.iter_mut().zip(rhs.iter()).for_each(|(l, r)| {
-        *l = f(*l, *r);
-    });
-}
-
-#[inline]
-pub fn binary_apply_fn<T: TensorDType, U: TensorDType>(
-    lhs: &Tensor,
-    rhs: &Tensor,
-    dst: &Tensor,
-    f: fn(T, T) -> U,
-) -> Result<(), OperationError> {
-    let lhs = lhs.to_vec::<T>()?;
-    let rhs = rhs.to_vec::<T>()?;
-    let mut result = vec![U::zero(); dst.shape().numel()];
-    binary_apply_fn_helper(&lhs, &rhs, &mut result, f);
-    cpu_store_result(dst, &result);
-    Ok(())
-}
-
-#[inline]
-pub fn binary_apply_inplace<T: TensorDType>(
-    lhs: &Tensor,
-    rhs: &Tensor,
-    dst: &Tensor,
-    f: fn(T, T) -> T,
-) -> Result<(), OperationError> {
-    let mut lhs = lhs.to_vec::<T>()?;
-    let rhs = rhs.to_vec::<T>()?;
-    binary_apply_inplace_helper(&mut lhs, &rhs, f);
-    cpu_store_result(dst, &lhs);
-    Ok(())
-}
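The helpers removed above were moved, not deleted: the binary versions reappear in crates/ratchet-core/src/cpu/binary.rs earlier in this diff, and the unary ones now come from the cpu::unary module (visible here only through the added use unary::unary_apply_fn; import). A minimal standalone sketch of the unary map pattern, assuming plain slices rather than ratchet's Tensor type (names illustrative):

// Elementwise map src -> dst, the shape of the unary_apply_fn helper
// that moved out of cpu/mod.rs.
fn unary_map<T: Copy, U>(src: &[T], dst: &mut [U], f: fn(T) -> U) {
    assert_eq!(src.len(), dst.len());
    for (s, d) in src.iter().copied().zip(dst.iter_mut()) {
        *d = f(s);
    }
}

fn main() {
    let src = [1.0f32, 4.0, 9.0];
    let mut dst = [0.0f32; 3];
    unary_map(&src, &mut dst, f32::sqrt);
    assert_eq!(dst, [1.0, 2.0, 3.0]);
}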