Feature/cpu rms and layer normalization #264

Merged · 8 commits · Nov 20, 2024
77 changes: 73 additions & 4 deletions crates/ratchet-core/src/cpu/binary.rs
@@ -1,9 +1,64 @@
-use crate::{
-    binary_apply_inplace, Binary, BinaryOp, CPUOperation, DType, OpGuards, Operation,
-    OperationError, RVec, StorageView, Tensor, TensorDType,
-};
+use crate::cpu::cpu_store_result;
+use crate::{Binary, BinaryOp, CPUOperation, DType, OperationError, Tensor, TensorDType};
use core::marker::PhantomData;
use half::{bf16, f16};
use num_traits::NumOps;

#[inline]
pub(crate) fn binary_map<T: TensorDType, U: TensorDType>(
    lhs: &[T],
    rhs: &[T],
    dst: &mut [U],
    f: fn(T, T) -> U,
) {
    assert_eq!(lhs.len(), dst.len());
    assert_eq!(rhs.len(), dst.len());
    for ((l, r), d) in lhs
        .iter()
        .copied()
        .zip(rhs.iter().copied())
        .zip(dst.iter_mut())
    {
        *d = f(l, r);
    }
}

#[inline]
pub(crate) fn binary_map_inplace<T: TensorDType>(lhs: &mut [T], rhs: &[T], f: fn(T, T) -> T) {
    assert_eq!(lhs.len(), rhs.len());
    lhs.iter_mut().zip(rhs.iter()).for_each(|(l, r)| {
        *l = f(*l, *r);
    });
}
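The two helpers above are a plain zip-and-map over slices: one writes into a separate destination, the other overwrites the left operand. A minimal, self-contained sketch of the same pattern on concrete f32 slices (illustration only, not part of this diff):

fn main() {
    let lhs = vec![1.0f32, 2.0, 3.0];
    let rhs = vec![10.0f32, 20.0, 30.0];

    // Out-of-place, as in binary_map: write f(l, r) into a separate destination.
    let mut dst = vec![0.0f32; lhs.len()];
    for ((l, r), d) in lhs.iter().copied().zip(rhs.iter().copied()).zip(dst.iter_mut()) {
        *d = l + r;
    }
    assert_eq!(dst, vec![11.0, 22.0, 33.0]);

    // In-place, as in binary_map_inplace: overwrite the left operand with f(l, r).
    let mut acc = lhs.clone();
    for (l, r) in acc.iter_mut().zip(rhs.iter()) {
        *l = *l * *r;
    }
    assert_eq!(acc, vec![10.0, 40.0, 90.0]);
}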

#[inline]
pub(crate) fn binary_apply<T: TensorDType, U: TensorDType>(
    lhs: &Tensor,
    rhs: &Tensor,
    dst: &Tensor,
    f: fn(T, T) -> U,
) -> Result<(), OperationError> {
    let lhs = lhs.to_vec::<T>()?;
    let rhs = rhs.to_vec::<T>()?;
    let mut result = vec![U::zero(); dst.shape().numel()];
    binary_map(&lhs, &rhs, &mut result, f);
    cpu_store_result(dst, &result);
    Ok(())
}

#[inline]
pub(crate) fn binary_apply_inplace<T: TensorDType>(
    lhs: &Tensor,
    rhs: &Tensor,
    dst: &Tensor,
    f: fn(T, T) -> T,
) -> Result<(), OperationError> {
    let mut lhs = lhs.to_vec::<T>()?;
    let rhs = rhs.to_vec::<T>()?;
    binary_map_inplace(&mut lhs, &rhs, f);
    cpu_store_result(dst, &lhs);
    Ok(())
}
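binary_apply and binary_apply_inplace wrap the slice helpers in a copy-out / map / store-back flow: read the operand tensors into Vecs, map element-wise, then write the result into dst's storage via cpu_store_result. A toy model of that flow, with a hypothetical ToyTensor standing in for the crate's Tensor (the real Tensor, to_vec, and cpu_store_result are crate-specific; the names here are invented for illustration):

// Hypothetical stand-in for the crate's Tensor; illustration only.
struct ToyTensor {
    data: Vec<f32>,
}

fn toy_binary_apply_inplace(lhs: &ToyTensor, rhs: &ToyTensor, dst: &mut ToyTensor, f: fn(f32, f32) -> f32) {
    // 1. Copy the left operand out of its storage.
    let mut buf = lhs.data.clone();
    // 2. Map in place against the right operand.
    for (l, r) in buf.iter_mut().zip(rhs.data.iter()) {
        *l = f(*l, *r);
    }
    // 3. Store the result back into the destination tensor.
    dst.data = buf;
}

fn main() {
    let a = ToyTensor { data: vec![1.0, 2.0] };
    let b = ToyTensor { data: vec![3.0, 4.0] };
    let mut out = ToyTensor { data: vec![] };
    toy_binary_apply_inplace(&a, &b, &mut out, |l, r| l + r);
    assert_eq!(out.data, vec![4.0, 6.0]);
}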

pub struct BinaryOps<T: TensorDType> {
    dtype: PhantomData<T>,
@@ -18,6 +73,20 @@ macro_rules! impl_cpu_binary_op {
    };
}

macro_rules! cpu_binary_op_fn {
    ($method_name:ident, $op:expr) => {
        #[inline]
        pub(crate) fn $method_name<T: TensorDType + NumOps>(lhs: &mut [T], rhs: &[T]) {
            binary_map_inplace::<T>(lhs, rhs, $op);
        }
    };
}

cpu_binary_op_fn!(add, |lhs, rhs| lhs + rhs);
cpu_binary_op_fn!(sub, |lhs, rhs| lhs - rhs);
cpu_binary_op_fn!(mul, |lhs, rhs| lhs * rhs);
cpu_binary_op_fn!(div, |lhs, rhs| lhs / rhs);
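For reference, the first invocation above expands mechanically to the following (hand-expanded; the non-capturing closure coerces to the fn pointer that binary_map_inplace expects):

#[inline]
pub(crate) fn add<T: TensorDType + NumOps>(lhs: &mut [T], rhs: &[T]) {
    binary_map_inplace::<T>(lhs, rhs, |lhs, rhs| lhs + rhs);
}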

macro_rules! impl_cpu_binary {
    ($dtype:ident) => {
        impl BinaryOps<$dtype> {
88 changes: 5 additions & 83 deletions crates/ratchet-core/src/cpu/mod.rs
@@ -1,20 +1,19 @@
mod binary;
pub mod gemm;
mod norm;
pub mod reindex;
pub mod rope;
mod unary;
mod utils;

use crate::{
-    dequantize, Binary, BinaryOp, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError,
-    LazyOp, OpGuards, Operation, OperationError, RVec, Shape, Storage, StorageView, Strides,
-    Tensor, TensorDType, Unary, UnaryOp,
+    dequantize, Cast, Concat, DType, IndexSelect, InvariantError, LazyOp, Operation,
+    OperationError, RVec, Shape, Tensor, TensorDType,
};
use anyhow::anyhow;
use core::marker::PhantomData;
use half::{bf16, f16};
use num_traits::Float;
use rope::cpu_rope;
use unary::unary_apply_fn;
use utils::cpu_store_result;

pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError> {
@@ -27,7 +26,7 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError
        LazyOp::Unary(u) => u.apply_cpu(dst),
        LazyOp::Reindex(r) => r.apply_cpu(dst),
        LazyOp::Concat(c) => cpu_concat(c, dst),
-        LazyOp::Norm(_n) => todo!(),
+        LazyOp::Norm(n) => n.apply_cpu(dst),
        LazyOp::Conv(_c) => todo!(),
        LazyOp::Select(i) => cpu_index_select(i, dst),
        LazyOp::IndexWrite(_i) => todo!(),
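The new LazyOp::Norm arm dispatches to the CPU norm kernels this PR adds; norm.rs itself is not shown in this excerpt. As a reference for the math only, a generic single-row sketch of RMS and layer normalization over f32 (function names, signatures, and eps handling are illustrative, not the crate's API):

// RMSNorm: y[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i]
fn rms_norm(x: &mut [f32], weight: &[f32], eps: f32) {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let scale = (mean_sq + eps).sqrt().recip();
    for (v, w) in x.iter_mut().zip(weight) {
        *v = *v * scale * w;
    }
}

// LayerNorm: y[i] = (x[i] - mean) / sqrt(var + eps) * weight[i] + bias[i]
fn layer_norm(x: &mut [f32], weight: &[f32], bias: &[f32], eps: f32) {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let scale = (var + eps).sqrt().recip();
    for ((v, w), b) in x.iter_mut().zip(weight).zip(bias) {
        *v = (*v - mean) * scale * w + b;
    }
}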
@@ -209,80 +208,3 @@ pub fn cpu_concat(Concat { inputs, dim }: Concat, dst: Tensor) -> Result<Tensor,
        dtype => Err(InvariantError::UnsupportedDType(dtype).into()),
    }
}

#[inline]
fn unary_apply_fn_helper<T: TensorDType, U: TensorDType>(src: &[T], dst: &mut [U], f: fn(T) -> U) {
    assert_eq!(src.len(), dst.len());
    for (s, d) in src.iter().copied().zip(dst.iter_mut()) {
        *d = f(s);
    }
}

#[inline]
pub fn unary_apply_fn<T: TensorDType, U: TensorDType>(
    input: &Tensor,
    dst: &Tensor,
    f: fn(T) -> U,
) -> Result<(), OperationError> {
    let input = input.to_vec::<T>()?;
    let mut result = vec![U::zero(); dst.shape().numel()];
    unary_apply_fn_helper(&input, &mut result, f);
    cpu_store_result(dst, &result);
    Ok(())
}

#[inline]
fn binary_apply_fn_helper<T: TensorDType, U: TensorDType>(
    lhs: &[T],
    rhs: &[T],
    dst: &mut [U],
    f: fn(T, T) -> U,
) {
    assert_eq!(lhs.len(), dst.len());
    assert_eq!(rhs.len(), dst.len());
    for ((l, r), d) in lhs
        .iter()
        .copied()
        .zip(rhs.iter().copied())
        .zip(dst.iter_mut())
    {
        *d = f(l, r);
    }
}

#[inline]
fn binary_apply_inplace_helper<T: TensorDType>(lhs: &mut [T], rhs: &[T], f: fn(T, T) -> T) {
    assert_eq!(lhs.len(), rhs.len());
    lhs.iter_mut().zip(rhs.iter()).for_each(|(l, r)| {
        *l = f(*l, *r);
    });
}

#[inline]
pub fn binary_apply_fn<T: TensorDType, U: TensorDType>(
    lhs: &Tensor,
    rhs: &Tensor,
    dst: &Tensor,
    f: fn(T, T) -> U,
) -> Result<(), OperationError> {
    let lhs = lhs.to_vec::<T>()?;
    let rhs = rhs.to_vec::<T>()?;
    let mut result = vec![U::zero(); dst.shape().numel()];
    binary_apply_fn_helper(&lhs, &rhs, &mut result, f);
    cpu_store_result(dst, &result);
    Ok(())
}

#[inline]
pub fn binary_apply_inplace<T: TensorDType>(
    lhs: &Tensor,
    rhs: &Tensor,
    dst: &Tensor,
    f: fn(T, T) -> T,
) -> Result<(), OperationError> {
    let mut lhs = lhs.to_vec::<T>()?;
    let rhs = rhs.to_vec::<T>()?;
    binary_apply_inplace_helper(&mut lhs, &rhs, f);
    cpu_store_result(dst, &lhs);
    Ok(())
}