Feature/cpu rope #256

Merged on Oct 29, 2024 (34 commits)
Commits (34), showing changes from all commits
ab223ca
initial rope setup
ivarflakstad Sep 6, 2024
2534dfb
tidy
ivarflakstad Sep 6, 2024
f7d3a32
chore: refactor cpu gemm for use in rope
ivarflakstad Sep 6, 2024
5f15257
debugging cpu RoPE
ivarflakstad Sep 6, 2024
70dc153
Rope cpu is almost there
ivarflakstad Sep 6, 2024
d1f31da
debugging gemm/rope interaction
ivarflakstad Sep 9, 2024
2366e9d
More debugging gemm/rope
ivarflakstad Sep 9, 2024
9b4bd6c
Revert gemm back to binary mul
ivarflakstad Sep 9, 2024
f998176
chore: turns out while I was confused, it was not about gemm
ivarflakstad Sep 9, 2024
a09f986
chore: Add interleave_by_offset
ivarflakstad Sep 9, 2024
46861fa
close
ivarflakstad Sep 9, 2024
23f2687
Most RoPE test cases are passing
ivarflakstad Sep 9, 2024
328d8ce
getting there
ivarflakstad Sep 18, 2024
6e39c34
Merge branch 'master' into feature/cpu-rope
ivarflakstad Sep 20, 2024
33b097e
testing a bunch of different things. really messy :)
ivarflakstad Oct 1, 2024
d5fb9f8
chore: focus on theta
FL33TW00D Oct 2, 2024
44bc1ec
chore: theta matches
FL33TW00D Oct 2, 2024
81f4bfc
chore: theta matches
FL33TW00D Oct 2, 2024
82435eb
chore: R1 and R2 match
FL33TW00D Oct 3, 2024
ca5f5a7
chore: cleaning
FL33TW00D Oct 3, 2024
88e7c07
chore: RoPE works but is shit
FL33TW00D Oct 4, 2024
4d63692
chore: RoPE doesn't work
FL33TW00D Oct 4, 2024
1d93205
chore: not quite right
FL33TW00D Oct 4, 2024
572e7d1
chore: rope concat dynamic outs length
ivarflakstad Oct 16, 2024
ce991ba
chore: simplify rope concat
ivarflakstad Oct 16, 2024
67a40c9
chore: padding r1/r2 with 0s works. Not optimal
ivarflakstad Oct 18, 2024
c932fd5
chore: use randn in rope test to avoid precision issues
ivarflakstad Oct 20, 2024
022db5a
chore: remove redundant "outs" vec
ivarflakstad Oct 20, 2024
508b5ed
chore: use iter cycle instead of % check
ivarflakstad Oct 20, 2024
52863d2
Merge branch 'master' into feature/cpu-rope
ivarflakstad Oct 28, 2024
be77442
chore: remove unused strided iterator. may be useful later
ivarflakstad Oct 28, 2024
5a018ce
chore: tidy up
ivarflakstad Oct 29, 2024
99a074e
chore: ? > unwrap
ivarflakstad Oct 29, 2024
b74e4b2
chore: Add back default debug_struct in Debug for Tensor impl
ivarflakstad Oct 29, 2024
59 changes: 43 additions & 16 deletions crates/ratchet-core/src/cpu/gemm.rs
@@ -1,10 +1,10 @@
use crate::{
cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError,
Shape, Tensor, TensorDType,
cpu::cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError,
Shape, Strides, Tensor, TensorDType,
};
use anyhow::{anyhow, Result};
use core::str::FromStr;
use gemm::{gemm, Parallelism};
use gemm::{gemm as gemm_kernel, Parallelism};
use half::{bf16, f16};
use std::num::NonZeroUsize;

@@ -56,21 +56,19 @@ fn calculate_skips(
Ok((lhs_skip, rhs_skip))
}

fn gemm_impl<T: TensorDType>(
spec: MatmulSpec,
pub(crate) fn gemm<T: TensorDType>(
lhs: &[T],
lhs_shape: &Shape,
lhs_strides: &Strides,
rhs: &[T],
rhs_shape: &Shape,
rhs_strides: &Strides,
dst_strides: &Strides,
b: usize,
m: usize,
n: usize,
k: usize,
) -> Result<Vec<T>, OperationError> {
let lhs_shape = spec.lhs_shape();
let rhs_shape = spec.rhs_shape();
let lhs_strides = spec.lhs_strides();
let rhs_strides = spec.rhs_strides();
let dst_strides = spec.dst_strides();
let b = spec.stacks();
let m = spec.m();
let n = spec.n();
let k = spec.k();

let lhs_strides = lhs_strides.to_vec();
let rhs_strides = rhs_strides.to_vec();
let rank = lhs_shape.rank();
@@ -102,7 +100,7 @@ fn gemm_impl<T: TensorDType>(
let rhs_p = &rhs[step * rhs_skip..];
let dst_p = &mut dst[step * dst_skip..];
unsafe {
gemm(
gemm_kernel(
m,
n,
k,
@@ -128,6 +126,35 @@
Ok(dst)
}

fn gemm_impl<T: TensorDType>(
spec: MatmulSpec,
lhs: &[T],
rhs: &[T],
) -> Result<Vec<T>, OperationError> {
let lhs_shape = spec.lhs_shape();
let rhs_shape = spec.rhs_shape();
let lhs_strides = spec.lhs_strides();
let rhs_strides = spec.rhs_strides();
let dst_strides = spec.dst_strides();
let b = spec.stacks();
let m = spec.m();
let n = spec.n();
let k = spec.k();
gemm(
lhs,
lhs_shape,
lhs_strides,
rhs,
rhs_shape,
rhs_strides,
dst_strides,
b,
m,
n,
k,
)
}

impl CPUOperation for Matmul {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
fn run_gemm<T: TensorDType>(
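
The net effect of this refactor: the old spec-driven gemm_impl becomes a thin wrapper around a pub(crate) gemm that takes explicit buffers, shapes, strides and (b, m, n, k), so other CPU kernels can reuse the inner loop without building a MatmulSpec. A minimal sketch of a call site inside ratchet-core (hypothetical helper, not part of the PR; it mirrors the outer-product call compute_theta makes in rope.rs below):

use crate::{cpu::gemm::gemm, shape, OperationError, Strides};

// theta[p][i] = positions[p] * inv_freqs[i], as a (seq_len x 1) @ (1 x half_dim) matmul.
fn outer_product(positions: &[f32], inv_freqs: &[f32]) -> Result<Vec<f32>, OperationError> {
    let (seq_len, half_dim) = (positions.len(), inv_freqs.len());
    let p_shape = shape!(seq_len, 1);
    let i_shape = shape!(1, half_dim);
    let dst_strides = Strides::from(&shape!(seq_len, half_dim));
    gemm(
        positions,
        &p_shape,
        &Strides::from(&p_shape),
        inv_freqs,
        &i_shape,
        &Strides::from(&i_shape),
        &dst_strides,
        1,        // b: a single stack
        seq_len,  // m
        half_dim, // n
        1,        // k
    )
}
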
73 changes: 44 additions & 29 deletions crates/ratchet-core/src/cpu/mod.rs
@@ -1,22 +1,28 @@
mod binary;
pub mod gemm;
pub mod rope;
mod unary;
mod utils;

use crate::{
dequantize, Binary, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError, LazyOp,
Operation, OperationError, RVec, Storage, Tensor, TensorDType,
dequantize, Binary, BinaryOp, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError,
LazyOp, OpGuards, Operation, OperationError, RVec, Shape, Storage, StorageView, Strides,
Tensor, TensorDType, Unary, UnaryOp,
};
use anyhow::anyhow;
use bytemuck::NoUninit;
use core::marker::PhantomData;
use half::{bf16, f16};
use num_traits::Float;
use rope::cpu_rope;
use utils::cpu_store_result;

pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError> {
match op {
LazyOp::Binary(b) => b.apply_cpu(dst),
LazyOp::Cast(c) => cpu_cast(c, dst),
LazyOp::Matmul(m) => m.apply_cpu(dst),
LazyOp::Softmax(_s) => todo!(),
LazyOp::RoPE(_r) => todo!(),
LazyOp::RoPE(r) => cpu_rope(r, dst),
LazyOp::Unary(u) => u.apply_cpu(dst),
LazyOp::Reindex(_r) => todo!(),
LazyOp::Concat(c) => cpu_concat(c, dst),
@@ -148,44 +154,57 @@ pub fn cpu_cast(cast: Cast, dst: Tensor) -> Result<Tensor, OperationError> {
Ok(dst)
}

fn concat_inner<T: TensorDType>(
inputs: RVec<Tensor>,
pub(crate) fn concat<T: TensorDType>(
inputs: &[(Shape, Vec<T>)],
dim: usize,
dst: Tensor,
) -> Result<Tensor, OperationError> {
let dst_size = dst.shape().clone().product();
let mut result = vec![T::zero(); dst_size];

let dst_dim_len = dst.shape()[dim];
let block: usize = dst.shape().iter().skip(1 + dim).product();
dst_shape: &Shape,
dst: &mut [T],
) -> Result<(), OperationError> {
let dst_dim_len = dst_shape[dim];
let block: usize = dst_shape.iter().skip(1 + dim).product();
let dst_s = block * dst_dim_len;
let src_o = 0;
let mut dst_o = 0;
for t in inputs {
let src = t.to_vec::<T>()?;

let t_dims = t.shape().as_slice();
let a_dim: usize = t_dims.iter().take(dim).product();
let b_dim = block * t_dims[dim];

for (src_s, src) in inputs {
let a_dim: usize = src_s.iter().take(dim).product();
let b_dim = block * src_s[dim];
for idx in 0..a_dim {
let dst_idx = idx * dst_s + dst_o;
let src_idx = idx * b_dim + src_o;
let dst = &mut result[dst_idx..dst_idx + b_dim];
let dst_t = &mut dst[dst_idx..dst_idx + b_dim];
let src = &src[src_idx..src_idx + b_dim];
dst.copy_from_slice(src)
dst_t.copy_from_slice(src)
}
dst_o += b_dim;
}
Ok(())
}
pub(crate) fn apply_concat<T: TensorDType>(
inputs: RVec<Tensor>,
dim: usize,
dst: Tensor,
) -> Result<Tensor, OperationError> {
let dst_size = dst.shape().numel();
let mut result = vec![T::zero(); dst_size];

let inputs = inputs
.iter()
.map(|t| match t.to_vec::<T>() {
Ok(v) => Ok((t.shape().clone(), v)),
Err(e) => Err(e.into()),
})
.collect::<Result<Vec<_>, OperationError>>();

concat(&inputs?, dim, dst.shape(), &mut result)?;
cpu_store_result(&dst, &result);
Ok(dst)
}

pub fn cpu_concat(Concat { inputs, dim }: Concat, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
DType::F32 => concat_inner::<f32>(inputs, dim, dst),
DType::F16 => concat_inner::<f16>(inputs, dim, dst),
DType::BF16 => concat_inner::<bf16>(inputs, dim, dst),
DType::F32 => apply_concat::<f32>(inputs, dim, dst),
DType::F16 => apply_concat::<f16>(inputs, dim, dst),
DType::BF16 => apply_concat::<bf16>(inputs, dim, dst),
dtype => Err(InvariantError::UnsupportedDType(dtype).into()),
}
}
@@ -266,7 +285,3 @@ pub fn binary_apply_inplace<T: TensorDType>(
cpu_store_result(dst, &lhs);
Ok(())
}

pub fn cpu_store_result<T: NoUninit>(dst: &Tensor, data: &[T]) {
dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape())));
}
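
concat_inner is split the same way: the tensor-facing apply_concat converts each input into a (Shape, Vec<T>) pair and hands them to a pub(crate) concat that writes into a caller-provided buffer, which is what lets rope.rs stitch r1, r2 and the untouched tail together without intermediate Tensors. A small sketch of the new helper in isolation (hypothetical test, assuming it runs inside ratchet-core):

use crate::{concat, shape, OperationError};

fn concat_last_dim() -> Result<(), OperationError> {
    // Two [1, 1, 2, 2] inputs concatenated along dim 3 into a [1, 1, 2, 4] destination.
    let a = (shape![1, 1, 2, 2], vec![1.0f32, 2.0, 3.0, 4.0]);
    let b = (shape![1, 1, 2, 2], vec![5.0f32, 6.0, 7.0, 8.0]);
    let dst_shape = shape![1, 1, 2, 4];
    let mut dst = vec![0.0f32; dst_shape.numel()];
    concat(&[a, b], 3, &dst_shape, &mut dst)?;
    // Each destination row takes a block from `a` followed by a block from `b`.
    assert_eq!(dst, vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]);
    Ok(())
}
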
170 changes: 170 additions & 0 deletions crates/ratchet-core/src/cpu/rope.rs
@@ -0,0 +1,170 @@
use crate::{
concat,
cpu::{cpu_store_result, gemm::gemm},
shape, DType, OperationError, RoPE, Shape, Strides, Tensor,
};
use anyhow::anyhow;

pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result<Tensor, OperationError> {
match op.input().dt() {
DType::F32 => {
let dim = op.dim();
let base = op.base();
let offset = op.offset();
let src = op.input().to_vec::<f32>()?;
let result = rope(src, op.input().shape(), dim, base, offset)?;
cpu_store_result(&dst, &result)
}
_ => todo!(),
}

Ok(dst)
}

fn compute_theta(
dim: usize,
seq_len: usize,
base: f32,
offset: usize,
) -> Result<Vec<f32>, OperationError> {
let half_dim = dim / 2;

let positions = (offset..seq_len + offset)
.map(|x| x as f32)
.collect::<Vec<f32>>();

let inv_freqs = (0..half_dim)
.map(|i| -(i as f32))
.map(|i| i * base.ln() / half_dim as f32)
.map(f32::exp)
.collect::<Vec<f32>>();
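// inv_freqs[i] = base^(-i / half_dim); theta below is the outer product positions x inv_freqs,
// i.e. theta[p][i] = (p + offset) * inv_freqs[i], evaluated as a (seq_len x 1) @ (1 x half_dim) gemm.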

let p_shape = shape!(seq_len, 1);
let p_strides = Strides::from(&p_shape);
let i_shape = shape!(1, half_dim);
let i_strides = Strides::from(&i_shape);
let dst_strides = Strides::from(&shape!(seq_len, half_dim));
let theta = gemm(
&positions,
&p_shape,
&p_strides,
&inv_freqs,
&i_shape,
&i_strides,
&dst_strides,
1,
seq_len,
half_dim,
1,
)?;

Ok(theta)
}

fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) -> Vec<f32> {
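// Materialize the hyper-rectangle [start, stop) of a strided source into a new contiguous buffer.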
assert!(start.len() == stop.len());
assert!(start.len() == src_strides.rank());
start.iter().zip(stop.iter()).for_each(|(s, t)| {
assert!(s < t);
});

let dst_shape: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
let dst_numel: usize = dst_shape.iter().product();

let mut dst = vec![0.0; dst_numel];

for i in 0..dst_numel {
let mut src_index = 0;
let mut tmp = i;
for d in 0..dst_shape.len() {
let coord = tmp / dst_shape[d + 1..].iter().product::<usize>().max(1);
tmp %= dst_shape[d + 1..].iter().product::<usize>().max(1);
src_index += (coord + start[d]) * src_strides[d] as usize;
}
dst[i] = src[src_index];
}

dst
}

fn rope(
src: Vec<f32>,
shape: &Shape,
dim: usize,
base: f32,
offset: usize,
) -> Result<Vec<f32>, OperationError> {
let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap();

let half_dim = dim / 2;
let theta = compute_theta(dim, seq_len, base, offset)?;
let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|i| i.sin_cos()).unzip();
let src_strides = Strides::from(shape);
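// x1 = first half of the rotary span [0, half_dim), x2 = second half [half_dim, dim), along the last axis.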
let x1 = slice(
&src,
&src_strides,
&[0, 0, 0, 0],
&[batches, num_heads, seq_len, half_dim],
);
let x2 = slice(
&src,
&src_strides,
&[0, 0, 0, half_dim],
&[batches, num_heads, seq_len, dim],
);

// TODO: use a `multiply` operation that handles broadcasting instead of cycling cos/sin by hand.
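// Rotation: r1 = x1 * cos(theta) - x2 * sin(theta), r2 = x1 * sin(theta) + x2 * cos(theta);
// cycling cos/sin reuses the (seq_len x half_dim) theta table across every batch and head.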
let x1_cos = x1
.iter()
.zip(cos.iter().cycle())
.map(|(x, c)| x * c)
.collect::<Vec<f32>>();
let x2_sin = x2
.iter()
.zip(sin.iter().cycle())
.map(|(x, s)| x * s)
.collect::<Vec<f32>>();

let mut r1 = x1_cos
.iter()
.zip(x2_sin.iter())
.map(|(x1, x2)| x1 - x2)
.collect::<Vec<f32>>();
r1.extend(vec![0.0; shape.numel() - r1.len()]);

let x1_sin = x1
.iter()
.zip(sin.iter().cycle())
.map(|(x, s)| x * s)
.collect::<Vec<f32>>();
let x2_cos = x2
.iter()
.zip(cos.iter().cycle())
.map(|(x, c)| x * c)
.collect::<Vec<f32>>();
let mut r2 = x1_sin
.iter()
.zip(x2_cos.iter())
.map(|(x1, x2)| x1 + x2)
.collect::<Vec<f32>>();
r2.extend(vec![0.0; shape.numel() - r2.len()]);

let mut to_cat = vec![
(shape![batches, num_heads, seq_len, half_dim], r1),
(shape![batches, num_heads, seq_len, half_dim], r2),
];
if dim < shape[3] {
let r3 = slice(
&src,
&src_strides,
&[0, 0, 0, dim],
&[batches, num_heads, seq_len, head_dim],
);
to_cat.push((shape![batches, num_heads, seq_len, head_dim - dim], r3));
}

let dst_shape = shape![batches, num_heads, seq_len, head_dim];
let mut dst = vec![0.0f32; dst_shape.numel()];
concat(to_cat.as_slice(), 3, &dst_shape, &mut dst)?;
Ok(dst)
}
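
For reference, the computation the new kernel performs, written out as a standalone scalar sketch over a single (seq_len x head_dim) head: positions and inverse frequencies form theta, each pair taken from the two halves of the rotary span [0, dim) is rotated by theta, and features past dim pass through unchanged. This is a hypothetical reference for sanity-checking, not code from the PR, and it ignores the batch/head dimensions that cpu_rope handles by cycling the theta table.

// src and the returned Vec are row-major [seq_len, head_dim] buffers.
fn rope_reference(src: &[f32], seq_len: usize, head_dim: usize, dim: usize, base: f32, offset: usize) -> Vec<f32> {
    assert_eq!(src.len(), seq_len * head_dim);
    let half = dim / 2;
    let mut out = src.to_vec(); // features in [dim, head_dim) are left untouched
    for s in 0..seq_len {
        let pos = (s + offset) as f32;
        for i in 0..half {
            let theta = pos * base.powf(-(i as f32) / half as f32);
            let (sin, cos) = theta.sin_cos();
            let x1 = src[s * head_dim + i];
            let x2 = src[s * head_dim + half + i];
            out[s * head_dim + i] = x1 * cos - x2 * sin;
            out[s * head_dim + half + i] = x1 * sin + x2 * cos;
        }
    }
    out
}
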
3 changes: 3 additions & 0 deletions crates/ratchet-core/src/cpu/slice.rs
@@ -0,0 +1,3 @@
use crate::{OperationError, Slice, Tensor};

pub fn cpu_slice(_op: Slice, _dst: Tensor) -> Result<Tensor, OperationError> {
    todo!("CPU slice is not yet implemented")
}
6 changes: 6 additions & 0 deletions crates/ratchet-core/src/cpu/utils.rs
@@ -0,0 +1,6 @@
use crate::{CPUBuffer, Storage, Tensor};
use bytemuck::NoUninit;

pub fn cpu_store_result<T: NoUninit>(dst: &Tensor, data: &[T]) {
dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape())));
}