Skip to content

Commit

Permalink
Merge pull request #259 from huggingface/yoink-cpu-struct
Browse files Browse the repository at this point in the history
Yoink cpu struct
  • Loading branch information
FL33TW00D authored Oct 4, 2024
2 parents c156344 + ca70c05 commit 721f3c6
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 187 deletions.
54 changes: 54 additions & 0 deletions crates/ratchet-core/src/cpu/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use crate::{
binary_apply_inplace, Binary, BinaryOp, CPUOperation, DType, OpGuards, Operation,
OperationError, RVec, StorageView, Tensor, TensorDType,
};
use core::marker::PhantomData;
use half::{bf16, f16};

/// Type-level namespace for the CPU binary kernels of element type `T`.
///
/// The visible code never constructs a value of this type; it is only used
/// through associated functions such as `BinaryOps::<f32>::apply`.
pub struct BinaryOps<T: TensorDType> {
// Marker that ties the otherwise unused type parameter `T` to the struct.
dtype: PhantomData<T>,
}

/// Expands to a private associated function named `$name`.
///
/// The generated function forwards both operands, a shared reference to the
/// destination tensor and the element-wise closure `$kernel` to
/// `binary_apply_inplace::<$elem>`. Any error is propagated with `?`; on
/// success the destination tensor is handed back to the caller.
macro_rules! impl_cpu_binary_op {
    ($name:ident, $elem:ident, $kernel:expr) => {
        fn $name(
            left: &Tensor,
            right: &Tensor,
            out: Tensor,
        ) -> Result<Tensor, OperationError> {
            binary_apply_inplace::<$elem>(left, right, &out, $kernel)?;
            Ok(out)
        }
    };
}

/// Implements the four arithmetic kernels and the `apply` dispatcher on
/// `BinaryOps<$elem>`.
///
/// `apply` looks at `op.op()` to pick the kernel and then runs it on
/// `op.lhs()`, `op.rhs()` and `dst`, returning whatever the kernel returns.
macro_rules! impl_cpu_binary {
    ($elem:ident) => {
        impl BinaryOps<$elem> {
            impl_cpu_binary_op!(add, $elem, |a, b| a + b);
            impl_cpu_binary_op!(sub, $elem, |a, b| a - b);
            impl_cpu_binary_op!(mul, $elem, |a, b| a * b);
            impl_cpu_binary_op!(div, $elem, |a, b| a / b);

            pub fn apply(op: &Binary, dst: Tensor) -> Result<Tensor, OperationError> {
                // Select the kernel first, then fetch the operands and call it.
                let kernel: fn(&Tensor, &Tensor, Tensor) -> Result<Tensor, OperationError> =
                    match op.op() {
                        BinaryOp::Add => Self::add,
                        BinaryOp::Sub => Self::sub,
                        BinaryOp::Mul => Self::mul,
                        BinaryOp::Div => Self::div,
                    };
                kernel(op.lhs(), op.rhs(), dst)
            }
        }
    };
}

impl CPUOperation for Binary {
    /// Runs this binary operation on the CPU, choosing the element type from
    /// the destination tensor's dtype.
    ///
    /// # Panics
    ///
    /// Any dtype other than `F32`, `F16` or `BF16` reaches `todo!()`.
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
        let dtype = dst.dt();
        match dtype {
            DType::F32 => BinaryOps::<f32>::apply(self, dst),
            DType::F16 => BinaryOps::<f16>::apply(self, dst),
            DType::BF16 => BinaryOps::<bf16>::apply(self, dst),
            _ => todo!(),
        }
    }
}

impl_cpu_binary!(f32);
impl_cpu_binary!(f16);
impl_cpu_binary!(bf16);
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/cpu/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ fn gemm_impl<T: TensorDType>(
}

impl CPUOperation for Matmul {
fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
fn run_gemm<T: TensorDType>(
spec: MatmulSpec,
lhs: &Tensor,
Expand Down
185 changes: 23 additions & 162 deletions crates/ratchet-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,176 +1,37 @@
mod binary;
pub mod gemm;
mod unary;

use crate::{
dequantize, Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, Concat, DType, IndexSelect,
InvariantError, OpGuards, Operation, OperationError, RVec, Storage, StorageView, Tensor,
TensorDType, Unary, UnaryOp,
dequantize, Binary, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError, LazyOp,
Operation, OperationError, RVec, Storage, Tensor, TensorDType,
};
use anyhow::anyhow;
use bytemuck::NoUninit;
use core::marker::PhantomData;
use half::{bf16, f16};
use num_traits::Float;

/// Pairs an operation `OP` with an element type `T` chosen at the type level.
///
/// `T` is not stored; it is only recorded through the `PhantomData` field.
#[derive(Debug)]
pub struct CPU<T: TensorDType, OP: Operation> {
// The wrapped operation.
op: OP,
// Marker that ties the otherwise unused type parameter `T` to the struct.
dtype: PhantomData<T>,
}

impl<T: TensorDType, OP: Operation> CPU<T, OP> {
pub fn new(op: OP) -> Self {
Self {
op,
dtype: PhantomData,
}
}
}

impl<T: TensorDType, OP: Operation> OpGuards for CPU<T, OP> {
fn check_shapes(&self) {
self.op.check_shapes();
}

fn check_dtypes(&self) {
self.op.check_dtypes();
}
}

impl<T: TensorDType, OP: Operation> Operation for CPU<T, OP> {
fn name(&self) -> &'static str {
self.op.name()
}

fn compute_view(&self) -> Result<StorageView, OperationError> {
self.op.compute_view()
}

fn srcs(&self) -> RVec<&Tensor> {
self.op.srcs()
}
}

/// Expands to a private associated function named `$name`.
///
/// The generated function passes the input tensor, a shared reference to the
/// destination tensor and the element-wise closure `$kernel` to
/// `unary_apply_fn`. Any error is propagated with `?`; on success the
/// destination tensor is handed back to the caller.
macro_rules! impl_cpu_unary_op {
    ($name:ident, $kernel:expr) => {
        fn $name(src: &Tensor, out: Tensor) -> Result<Tensor, OperationError> {
            unary_apply_fn(src, &out, $kernel)?;
            Ok(out)
        }
    };
}

/// Implements one private kernel per unary operation on `CPU<$elem, Unary>`.
///
/// `$lit` is applied to every `f32` literal to turn it into `$elem`.
macro_rules! impl_cpu_unary_wrapper {
    ($elem:ident, $lit:expr) => {
        impl CPU<$elem, Unary> {
            // Sign and rounding.
            impl_cpu_unary_op!(neg, |x: $elem| -x);
            impl_cpu_unary_op!(abs, |x: $elem| x.abs());
            impl_cpu_unary_op!(floor, |x: $elem| x.floor());
            impl_cpu_unary_op!(ceil, |x: $elem| x.ceil());

            // Elementary functions.
            impl_cpu_unary_op!(exp, |x: $elem| x.exp());
            impl_cpu_unary_op!(log, |x: $elem| x.ln());
            impl_cpu_unary_op!(sqrt, |x: $elem| x.sqrt());
            impl_cpu_unary_op!(sin, |x: $elem| x.sin());
            impl_cpu_unary_op!(cos, |x: $elem| x.cos());
            impl_cpu_unary_op!(tanh, |x: $elem| x.tanh());

            // Activations.
            impl_cpu_unary_op!(relu, |x: $elem| x.max($lit(0.0)));
            impl_cpu_unary_op!(sigmoid, |x: $elem| $lit(1.0) / ($lit(1.0) + (-x).exp()));
            impl_cpu_unary_op!(silu, |x: $elem| x / ($lit(1.0) + (-x).exp()));
            // tanh-based GELU: 0.5 * x * (1 + tanh(c * x * (1 + 0.044715 * x^2))),
            // with c = 0.797_884_6 (approximately sqrt(2 / pi)).
            impl_cpu_unary_op!(gelu, |x: $elem| $lit(0.5)
                * x
                * ($lit(1.0)
                    + $elem::tanh(
                        $lit(0.797_884_6) * x * ($lit(1.0) + $lit(0.044715) * x * x)
                    )));
        }
    };
}

/// Generates the unary kernels and the `CPUOperation` impl for
/// `CPU<$elem, Unary>`.
///
/// The one-argument form uses the identity closure as the literal converter;
/// the two-argument form takes the converter explicitly.
macro_rules! impl_cpu_unary {
    ($elem:ident) => {
        impl_cpu_unary!($elem, |x| x);
    };
    ($elem:ident, $lit:expr) => {
        impl_cpu_unary_wrapper!($elem, $lit);

        impl CPUOperation for CPU<$elem, Unary> {
            fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
                // Select the kernel first, then fetch the input and call it.
                let kernel: fn(&Tensor, Tensor) -> Result<Tensor, OperationError> =
                    match self.op.op() {
                        UnaryOp::Gelu => Self::gelu,
                        UnaryOp::Tanh => Self::tanh,
                        UnaryOp::Exp => Self::exp,
                        UnaryOp::Log => Self::log,
                        UnaryOp::Sin => Self::sin,
                        UnaryOp::Cos => Self::cos,
                        UnaryOp::Abs => Self::abs,
                        UnaryOp::Sqrt => Self::sqrt,
                        UnaryOp::Relu => Self::relu,
                        UnaryOp::Floor => Self::floor,
                        UnaryOp::Ceil => Self::ceil,
                        UnaryOp::Neg => Self::neg,
                        UnaryOp::Silu => Self::silu,
                        UnaryOp::Sigmoid => Self::sigmoid,
                    };
                kernel(self.op.input(), dst)
            }
        }
    };
}

impl_cpu_unary!(f32);
impl_cpu_unary!(f16, f16::from_f32);
impl_cpu_unary!(bf16, bf16::from_f32);

pub fn cpu_unary(unary: Unary, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
DType::F32 => CPU::<f32, _>::new(unary).apply(dst),
DType::F16 => CPU::<f16, _>::new(unary).apply(dst),
DType::BF16 => CPU::<bf16, _>::new(unary).apply(dst),
_ => todo!(),
/// Executes `op` on the CPU with `dst` as the destination tensor.
///
/// `Binary`, `Matmul` and `Unary` go through their `apply_cpu` method;
/// `Cast`, `Concat` and `Select` go through the `cpu_cast`, `cpu_concat` and
/// `cpu_index_select` functions.
///
/// # Panics
///
/// Every other variant, including `Const` and `View`, reaches `todo!()`.
pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError> {
    match op {
        LazyOp::Binary(binary) => binary.apply_cpu(dst),
        LazyOp::Matmul(matmul) => matmul.apply_cpu(dst),
        LazyOp::Unary(unary) => unary.apply_cpu(dst),

        LazyOp::Cast(cast) => cpu_cast(cast, dst),
        LazyOp::Concat(concat) => cpu_concat(concat, dst),
        LazyOp::Select(select) => cpu_index_select(select, dst),

        LazyOp::Softmax(_)
        | LazyOp::RoPE(_)
        | LazyOp::Reindex(_)
        | LazyOp::Norm(_)
        | LazyOp::Conv(_)
        | LazyOp::IndexWrite(_)
        | LazyOp::Cache(_)
        | LazyOp::View(_)
        | LazyOp::Const => todo!(),
    }
}

/// Expands to a private associated function named `$name`.
///
/// The generated function forwards both operands, a shared reference to the
/// destination tensor and the element-wise closure `$kernel` to
/// `binary_apply_inplace::<$elem>`. Any error is propagated with `?`; on
/// success the destination tensor is handed back to the caller.
macro_rules! impl_cpu_binary_op {
    ($name:ident, $elem:ident, $kernel:expr) => {
        fn $name(
            left: &Tensor,
            right: &Tensor,
            out: Tensor,
        ) -> Result<Tensor, OperationError> {
            binary_apply_inplace::<$elem>(left, right, &out, $kernel)?;
            Ok(out)
        }
    };
}

/// Implements the four arithmetic kernels on `CPU<$elem, Binary>` together
/// with its `CPUOperation` impl.
///
/// `apply` looks at the wrapped operation's `op()` to pick the kernel and
/// runs it on the wrapped operation's `lhs()`, `rhs()` and `dst`.
macro_rules! impl_cpu_binary {
    ($elem:ident) => {
        impl CPU<$elem, Binary> {
            impl_cpu_binary_op!(add, $elem, |a, b| a + b);
            impl_cpu_binary_op!(sub, $elem, |a, b| a - b);
            impl_cpu_binary_op!(mul, $elem, |a, b| a * b);
            impl_cpu_binary_op!(div, $elem, |a, b| a / b);
        }

        impl CPUOperation for CPU<$elem, Binary> {
            fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
                // Select the kernel first, then fetch the operands and call it.
                let kernel: fn(&Tensor, &Tensor, Tensor) -> Result<Tensor, OperationError> =
                    match self.op.op() {
                        BinaryOp::Add => Self::add,
                        BinaryOp::Sub => Self::sub,
                        BinaryOp::Mul => Self::mul,
                        BinaryOp::Div => Self::div,
                    };
                kernel(self.op.lhs(), self.op.rhs(), dst)
            }
        }
    };
}

impl_cpu_binary!(f32);
impl_cpu_binary!(f16);
impl_cpu_binary!(bf16);

pub fn cpu_binary(binary: Binary, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
DType::F32 => CPU::<f32, _>::new(binary).apply(dst),
DType::F16 => CPU::<f16, _>::new(binary).apply(dst),
DType::BF16 => CPU::<bf16, _>::new(binary).apply(dst),
_ => todo!(),
}
/// An [`Operation`] that can be executed on the CPU.
pub trait CPUOperation: Operation {
/// Runs the operation with `dst` as the destination tensor and returns the
/// resulting tensor, or an `OperationError` on failure.
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError>;
}

fn index_select<T: TensorDType>(
Expand Down
89 changes: 89 additions & 0 deletions crates/ratchet-core/src/cpu/unary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use crate::{
unary_apply_fn, CPUOperation, DType, OperationError, Tensor, TensorDType, Unary, UnaryOp,
};
use core::marker::PhantomData;
use half::{bf16, f16};
use num_traits::Float;

/// Type-level namespace for the CPU unary kernels of element type `T`.
///
/// The visible code never constructs a value of this type; it is only used
/// through associated functions such as `UnaryOps::<f32>::apply`.
struct UnaryOps<T: TensorDType> {
// Marker that ties the otherwise unused type parameter `T` to the struct.
dtype: PhantomData<T>,
}

/// Implements one private kernel per unary operation plus the `apply`
/// dispatcher on `UnaryOps<$elem>`.
///
/// `$lit` is applied to every `f32` literal to turn it into `$elem`.
/// `apply` looks at `op.op()` to pick the kernel and then runs it on
/// `op.input()` and `dst`.
macro_rules! impl_unary_ops {
    ($elem:ident, $lit:expr) => {
        impl UnaryOps<$elem> {
            // Sign and rounding.
            impl_cpu_unary_op!(neg, |x: $elem| -x);
            impl_cpu_unary_op!(abs, |x: $elem| x.abs());
            impl_cpu_unary_op!(floor, |x: $elem| x.floor());
            impl_cpu_unary_op!(ceil, |x: $elem| x.ceil());

            // Elementary functions.
            impl_cpu_unary_op!(exp, |x: $elem| x.exp());
            impl_cpu_unary_op!(log, |x: $elem| x.ln());
            impl_cpu_unary_op!(sqrt, |x: $elem| x.sqrt());
            impl_cpu_unary_op!(sin, |x: $elem| x.sin());
            impl_cpu_unary_op!(cos, |x: $elem| x.cos());
            impl_cpu_unary_op!(tanh, |x: $elem| x.tanh());

            // Activations.
            impl_cpu_unary_op!(relu, |x: $elem| x.max($lit(0.0)));
            impl_cpu_unary_op!(sigmoid, |x: $elem| $lit(1.0) / ($lit(1.0) + (-x).exp()));
            impl_cpu_unary_op!(silu, |x: $elem| x / ($lit(1.0) + (-x).exp()));
            // tanh-based GELU: 0.5 * x * (1 + tanh(c * x * (1 + 0.044715 * x^2))),
            // with c = 0.797_884_6 (approximately sqrt(2 / pi)).
            impl_cpu_unary_op!(gelu, |x: $elem| $lit(0.5)
                * x
                * ($lit(1.0)
                    + $elem::tanh(
                        $lit(0.797_884_6) * x * ($lit(1.0) + $lit(0.044715) * x * x)
                    )));

            fn apply(op: &Unary, dst: Tensor) -> Result<Tensor, OperationError> {
                // Select the kernel first, then fetch the input and call it.
                let kernel: fn(&Tensor, Tensor) -> Result<Tensor, OperationError> =
                    match op.op() {
                        UnaryOp::Gelu => Self::gelu,
                        UnaryOp::Tanh => Self::tanh,
                        UnaryOp::Exp => Self::exp,
                        UnaryOp::Log => Self::log,
                        UnaryOp::Sin => Self::sin,
                        UnaryOp::Cos => Self::cos,
                        UnaryOp::Abs => Self::abs,
                        UnaryOp::Sqrt => Self::sqrt,
                        UnaryOp::Relu => Self::relu,
                        UnaryOp::Floor => Self::floor,
                        UnaryOp::Ceil => Self::ceil,
                        UnaryOp::Neg => Self::neg,
                        UnaryOp::Silu => Self::silu,
                        UnaryOp::Sigmoid => Self::sigmoid,
                    };
                kernel(op.input(), dst)
            }
        }
    };
}

/// Expands to a private associated function named `$name`.
///
/// The generated function passes the input tensor, a shared reference to the
/// destination tensor and the element-wise closure `$kernel` to
/// `unary_apply_fn`. Any error is propagated with `?`; on success the
/// destination tensor is handed back to the caller.
macro_rules! impl_cpu_unary_op {
    ($name:ident, $kernel:expr) => {
        fn $name(src: &Tensor, out: Tensor) -> Result<Tensor, OperationError> {
            unary_apply_fn(src, &out, $kernel)?;
            Ok(out)
        }
    };
}

impl CPUOperation for Unary {
    /// Runs this unary operation on the CPU, choosing the element type from
    /// the destination tensor's dtype.
    ///
    /// # Panics
    ///
    /// Any dtype other than `F32`, `F16` or `BF16` reaches `todo!()`.
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
        let dtype = dst.dt();
        match dtype {
            DType::F32 => UnaryOps::<f32>::apply(self, dst),
            DType::F16 => UnaryOps::<f16>::apply(self, dst),
            DType::BF16 => UnaryOps::<bf16>::apply(self, dst),
            _ => todo!(),
        }
    }
}

/// Entry point that instantiates the unary kernels for one element type.
///
/// `impl_cpu_unary!(T)` uses the identity closure as the literal converter;
/// `impl_cpu_unary!(T, conv)` passes `conv` through unchanged.
macro_rules! impl_cpu_unary {
    ($elem:ident) => {
        impl_unary_ops!($elem, |x| x);
    };
    ($elem:ident, $lit:expr) => {
        impl_unary_ops!($elem, $lit);
    };
}

impl_cpu_unary!(f32);
impl_cpu_unary!(f16, f16::from_f32);
impl_cpu_unary!(bf16, bf16::from_f32);
4 changes: 0 additions & 4 deletions crates/ratchet-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,3 @@ pub trait GPUOperation: Operation {
))
}
}

/// An [`Operation`] that can be executed on the CPU.
pub trait CPUOperation: Operation {
/// Runs the operation with `dst` as the destination tensor and returns the
/// resulting tensor, or an `OperationError` on failure.
fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError>;
}
24 changes: 4 additions & 20 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::gpu::{BindGroupEntry, CpuUniform, WgpuDevice};
use crate::{
cpu::*, ops::*, rvec, BufferSegment, CPUBuffer, CPUOperation, CompiledOp, DType, Device,
DeviceStorage, Executable, GPUBuffer, GPUOperation, InvariantError, LazyOp, Operation,
OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType, TensorId,
cpu, ops::*, rvec, BufferSegment, CPUBuffer, CompiledOp, DType, Device, DeviceStorage,
Executable, GPUBuffer, GPUOperation, InvariantError, LazyOp, Operation, OperationError, RVec,
RawCPUBuffer, Shape, Storage, Strides, TensorDType, TensorId,
};
use derive_new::new;
use npyz::WriterBuilder;
Expand Down Expand Up @@ -747,23 +747,7 @@ impl Tensor {
}

pub fn cpu_apply(self, dst: Tensor) -> Option<Tensor> {
match self.op().clone() {
LazyOp::Binary(b) => cpu_binary(b, dst).ok(),
LazyOp::Cast(c) => cpu_cast(c, dst).ok(),
LazyOp::Matmul(m) => m.apply(dst).ok(),
LazyOp::Softmax(_s) => todo!(),
LazyOp::RoPE(_r) => todo!(),
LazyOp::Unary(u) => cpu_unary(u, dst).ok(),
LazyOp::Reindex(_r) => todo!(),
LazyOp::Concat(c) => cpu_concat(c, dst).ok(),
LazyOp::Norm(_n) => todo!(),
LazyOp::Conv(_c) => todo!(),
LazyOp::Select(i) => cpu_index_select(i, dst).ok(),
LazyOp::IndexWrite(_i) => todo!(),
LazyOp::Cache(_c) => todo!(),
LazyOp::Const => None,
LazyOp::View(_) => None,
}
cpu::apply_operation(self.op().clone(), dst).ok()
}

fn resolve_inner(self, debug: bool) -> Result<Tensor, TensorError> {
Expand Down

0 comments on commit 721f3c6

Please sign in to comment.