From ab223ca2e1eaac7d8900db398dab79add0ac07e1 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 6 Sep 2024 13:06:15 +0200 Subject: [PATCH 01/32] initial rope setup --- crates/ratchet-core/src/cpu/gemm.rs | 2 +- crates/ratchet-core/src/cpu/mod.rs | 9 +-- crates/ratchet-core/src/cpu/rope.rs | 106 +++++++++++++++++++++++++++ crates/ratchet-core/src/cpu/utils.rs | 6 ++ crates/ratchet-core/src/ops/rope.rs | 56 +++++++++++--- crates/ratchet-core/src/tensor.rs | 10 +-- 6 files changed, 168 insertions(+), 21 deletions(-) create mode 100644 crates/ratchet-core/src/cpu/rope.rs create mode 100644 crates/ratchet-core/src/cpu/utils.rs diff --git a/crates/ratchet-core/src/cpu/gemm.rs b/crates/ratchet-core/src/cpu/gemm.rs index b0ac55cf..294829d7 100644 --- a/crates/ratchet-core/src/cpu/gemm.rs +++ b/crates/ratchet-core/src/cpu/gemm.rs @@ -1,5 +1,5 @@ use crate::{ - cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError, + cpu::cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError, Shape, Tensor, TensorDType, }; use anyhow::{anyhow, Result}; diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs index 147af984..56ce8f2d 100644 --- a/crates/ratchet-core/src/cpu/mod.rs +++ b/crates/ratchet-core/src/cpu/mod.rs @@ -1,4 +1,6 @@ pub mod gemm; +pub mod rope; +mod utils; use crate::{ dequantize, Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, DType, IndexSelect, @@ -6,10 +8,11 @@ use crate::{ TensorDType, Unary, UnaryOp, }; use anyhow::anyhow; -use bytemuck::NoUninit; use core::marker::PhantomData; use half::{bf16, f16}; use num_traits::Float; +use rope::*; +use utils::cpu_store_result; #[derive(Debug)] pub struct CPU { @@ -363,7 +366,3 @@ pub fn binary_apply_inplace( cpu_store_result(dst, &lhs); Ok(()) } - -pub fn cpu_store_result(dst: &Tensor, data: &[T]) { - dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape()))); -} diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs new file mode 100644 index 00000000..3eeea8a4 --- /dev/null +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -0,0 +1,106 @@ +use crate::{ + cpu::cpu_store_result, DType, OperationError, RoPE, Shape, Tensor, TensorDType, TensorError, + Unary, +}; +use half::{bf16, f16}; +use num_traits::Float; + +pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { + match op.input().dt() { + DType::F32 => { + let dim = op.dim(); + let base = op.base(); + let offset = op.offset(); + let src = op.input().to_vec::()?; + let result = rope(&src, op.input().shape(), dim, base, offset); + cpu_store_result(&dst, &result) + } + _ => todo!(), + } + + Ok(dst) +} + +fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { + let [b, t, h, d] = shape.try_into().unwrap(); + let el_count = b * h * t * d; + + let src = &src[offset..offset + el_count]; + + let half_dim = dim / 2; + let positions = (offset..el_count + offset) + .map(|x| x as f32) + .collect::>(); + + let log_base = base.log2(); + let inv_freqs = (0..d) + .step_by(2) + .rev() + .map(|i| -(i as f32)) + .map(|i| i * log_base / half_dim as f32) + .map(|i| i.exp()) + .collect::>(); + + let theta = positions + .iter() + .zip(inv_freqs.iter()) + .map(|(p, i)| p * i) + .collect::>(); + + let cos = theta.iter().map(|x| x.cos()).collect::>(); + let sin = theta.iter().map(|x| x.sin()).collect::>(); + + let mut dst = vec![0.0; el_count]; + + println!("cos len: {}", cos.len()); + println!("sin len: {}", sin.len()); + println!("src len: {}", src.len()); + println!("dst len: {}", dst.len()); + + src.chunks(t * h * d) + .zip(dst.chunks_mut(t * h * d)) + .for_each(|(src, dst)| { + for i_t in 0..t { + for i_d in 0..d / 2 { + let i_cs = i_t * (d / 2) + i_d; + for i_h in 0..h { + let i1 = i_t * h * d + i_h * d + i_d; + let i2 = i1 + d / 2; + dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; + dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; + } + } + } + }); + dst +} + +fn old_rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { + let cos = src.iter().map(|x| x.cos()).collect::>(); + let sin = src.iter().map(|x| x.sin()).collect::>(); + + let b = *shape.get(0).unwrap(); + let t = *shape.get(1).unwrap(); + let h = *shape.get(2).unwrap(); + let d = *shape.get(3).unwrap(); + + let el_count = b * h * t * d; + let mut dst = vec![0.0; el_count]; + src.chunks(t * h * d) + .zip(dst.chunks_mut(t * h * d)) + .for_each(|(src, dst)| { + for i_t in 0..t { + for i_d in 0..d / 2 { + let i_cs = i_t * (d / 2) + i_d; + for i_h in 0..h { + let i1 = i_t * h * d + i_h * d + i_d; + let i2 = i1 + d / 2; + dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; + dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; + } + } + } + }); + + dst +} diff --git a/crates/ratchet-core/src/cpu/utils.rs b/crates/ratchet-core/src/cpu/utils.rs new file mode 100644 index 00000000..e61facde --- /dev/null +++ b/crates/ratchet-core/src/cpu/utils.rs @@ -0,0 +1,6 @@ +use crate::{CPUBuffer, Storage, Tensor}; +use bytemuck::NoUninit; + +pub fn cpu_store_result(dst: &Tensor, data: &[T]) { + dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape()))); +} diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 0d45260d..8b7069ee 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -21,7 +21,23 @@ pub struct RoPE { offset: usize, } -impl RoPE {} +impl RoPE { + pub fn input(&self) -> &Tensor { + &self.input + } + + pub fn dim(&self) -> usize { + self.dim + } + + pub fn base(&self) -> f32 { + self.base + } + + pub fn offset(&self) -> usize { + self.offset + } +} #[derive(Debug, derive_new::new, ShaderType, WgslMetadata)] pub struct RoPEMeta { @@ -181,7 +197,7 @@ impl Kernel for RoPEKernels { (&out_strides).into(), SL as u32, inner.offset as u32, - inner.base, + f32::log2(inner.base), 1.0, )) } @@ -273,8 +289,7 @@ def mlx_rope(input, dim, offset): run_py_prg(prg.to_string(), &[a], &[&dim, &offset], a.dt()) } - fn run_rope_trial(problem: RoPEProblem) { - let device = Device::request_device(DeviceRequest::GPU).unwrap(); + fn run_rope_trial(problem: RoPEProblem, device: Device) { let RoPEProblem { BS, NH, @@ -286,12 +301,12 @@ def mlx_rope(input, dim, offset): let a = Tensor::randn::(shape![BS, NH, SL, HD], Device::CPU); let ground = ground_truth(&a, dim, offset).unwrap(); - let a_gpu = a.to(&device).unwrap(); - let b = a_gpu.rope(dim, 10000.0, offset).unwrap().resolve().unwrap(); + let a = a.to(&device).unwrap(); + let b = a.rope(dim, 10000.0, offset).unwrap().resolve().unwrap(); let ours = b.to(&Device::CPU).unwrap(); - //println!("ours = \n{:#?}\n", ours.to_ndarray_view::()); - //println!("ground = \n{:#?}", ground.to_ndarray_view::()); + println!("ours = \n{:#?}\n", ours.to_ndarray_view::()); + println!("ground = \n{:#?}", ground.to_ndarray_view::()); //Weak tolerance because of `ffast-math` ground.all_close(&ours, 1e-3, 1e-3).unwrap(); } @@ -315,7 +330,7 @@ def mlx_rope(input, dim, offset): } #[proptest(cases = 16)] - fn test_rope(prob: RoPEProblem) { + fn test_rope_gpu(prob: RoPEProblem) { let RoPEProblem { BS, NH, @@ -328,6 +343,27 @@ def mlx_rope(input, dim, offset): "BS = {}, NH = {}, SL = {}, HD = {}, rope_dim = {}, offset = {}", BS, NH, SL, HD, dim, offset ); - run_rope_trial(prob); + + let device = Device::request_device(DeviceRequest::GPU).unwrap(); + run_rope_trial(prob, device); + } + + #[proptest(cases = 16)] + fn test_rope_cpu(prob: RoPEProblem) { + let RoPEProblem { + BS, + NH, + SL, + HD, + dim, + offset, + } = prob; + println!( + "BS = {}, NH = {}, SL = {}, HD = {}, rope_dim = {}, offset = {}", + BS, NH, SL, HD, dim, offset + ); + + let device = Device::request_device(DeviceRequest::CPU).unwrap(); + run_rope_trial(prob, device); } } diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index e866eabb..2943f3af 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -1,8 +1,8 @@ use crate::gpu::{BindGroupEntry, CpuUniform, WgpuDevice}; use crate::{ - cpu::*, ops::*, rvec, BufferSegment, CPUBuffer, CPUOperation, CompiledOp, DType, Device, - DeviceStorage, Executable, GPUBuffer, GPUOperation, InvariantError, LazyOp, Operation, - OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType, TensorId, + cpu::rope::cpu_rope, cpu::*, ops::*, rvec, BufferSegment, CPUBuffer, CPUOperation, CompiledOp, + DType, Device, DeviceStorage, Executable, GPUBuffer, GPUOperation, InvariantError, LazyOp, + Operation, OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType, TensorId, }; use derive_new::new; use npyz::WriterBuilder; @@ -372,7 +372,7 @@ impl Tensor { pub fn rope(self, dim: usize, base: f32, offset: usize) -> anyhow::Result { let device = self.device.clone(); - let rope = RoPE::new(self, dim, f32::log2(base), offset); + let rope = RoPE::new(self, dim, base, offset); let new_view = rope.compute_view()?; Ok(Tensor::lazy(LazyOp::RoPE(rope), new_view, device)) } @@ -745,7 +745,7 @@ impl Tensor { LazyOp::Cast(c) => cpu_cast(c, dst).ok(), LazyOp::Matmul(m) => m.apply(dst).ok(), LazyOp::Softmax(_s) => todo!(), - LazyOp::RoPE(_r) => todo!(), + LazyOp::RoPE(r) => cpu_rope(r, dst).ok(), LazyOp::Unary(u) => cpu_unary(u, dst).ok(), LazyOp::Reindex(_r) => todo!(), LazyOp::Concat(_c) => todo!(), From 2534dfbd5c5479a3cf3e42172a047927ff27f73e Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 6 Sep 2024 13:07:46 +0200 Subject: [PATCH 02/32] tidy --- crates/ratchet-core/src/cpu/rope.rs | 30 ----------------------------- 1 file changed, 30 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 3eeea8a4..83e8e544 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -74,33 +74,3 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec }); dst } - -fn old_rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { - let cos = src.iter().map(|x| x.cos()).collect::>(); - let sin = src.iter().map(|x| x.sin()).collect::>(); - - let b = *shape.get(0).unwrap(); - let t = *shape.get(1).unwrap(); - let h = *shape.get(2).unwrap(); - let d = *shape.get(3).unwrap(); - - let el_count = b * h * t * d; - let mut dst = vec![0.0; el_count]; - src.chunks(t * h * d) - .zip(dst.chunks_mut(t * h * d)) - .for_each(|(src, dst)| { - for i_t in 0..t { - for i_d in 0..d / 2 { - let i_cs = i_t * (d / 2) + i_d; - for i_h in 0..h { - let i1 = i_t * h * d + i_h * d + i_d; - let i2 = i1 + d / 2; - dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; - dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; - } - } - } - }); - - dst -} From f7d3a32d6d05682ae04bbe43bdb8a912585d1942 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 6 Sep 2024 14:46:55 +0200 Subject: [PATCH 03/32] chore: refactor cpu gemm for use in rope --- crates/ratchet-core/src/cpu/gemm.rs | 57 +++++++++++++++++++++-------- crates/ratchet-core/src/cpu/rope.rs | 53 +++++++++++++++++---------- 2 files changed, 76 insertions(+), 34 deletions(-) diff --git a/crates/ratchet-core/src/cpu/gemm.rs b/crates/ratchet-core/src/cpu/gemm.rs index 294829d7..fdba0622 100644 --- a/crates/ratchet-core/src/cpu/gemm.rs +++ b/crates/ratchet-core/src/cpu/gemm.rs @@ -1,10 +1,10 @@ use crate::{ cpu::cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError, - Shape, Tensor, TensorDType, + Shape, Strides, Tensor, TensorDType, }; use anyhow::{anyhow, Result}; use core::str::FromStr; -use gemm::{gemm, Parallelism}; +use gemm::{gemm as gemm_kernel, Parallelism}; use half::{bf16, f16}; use std::num::NonZeroUsize; @@ -56,21 +56,19 @@ fn calculate_skips( Ok((lhs_skip, rhs_skip)) } -fn gemm_impl( - spec: MatmulSpec, +pub(crate) fn gemm( lhs: &[T], + lhs_shape: &Shape, + lhs_strides: &Strides, rhs: &[T], + rhs_shape: &Shape, + rhs_strides: &Strides, + dst_strides: &Strides, + b: usize, + m: usize, + n: usize, + k: usize, ) -> Result, OperationError> { - let lhs_shape = spec.lhs_shape(); - let rhs_shape = spec.rhs_shape(); - let lhs_strides = spec.lhs_strides(); - let rhs_strides = spec.rhs_strides(); - let dst_strides = spec.dst_strides(); - let b = spec.stacks(); - let m = spec.m(); - let n = spec.n(); - let k = spec.k(); - let lhs_strides = lhs_strides.to_vec(); let rhs_strides = rhs_strides.to_vec(); let rank = lhs_shape.rank(); @@ -102,7 +100,7 @@ fn gemm_impl( let rhs_p = &rhs[step * rhs_skip..]; let dst_p = &mut dst[step * dst_skip..]; unsafe { - gemm( + gemm_kernel( m, n, k, @@ -128,6 +126,35 @@ fn gemm_impl( Ok(dst) } +fn gemm_impl( + spec: MatmulSpec, + lhs: &[T], + rhs: &[T], +) -> Result, OperationError> { + let lhs_shape = spec.lhs_shape(); + let rhs_shape = spec.rhs_shape(); + let lhs_strides = spec.lhs_strides(); + let rhs_strides = spec.rhs_strides(); + let dst_strides = spec.dst_strides(); + let b = spec.stacks(); + let m = spec.m(); + let n = spec.n(); + let k = spec.k(); + gemm( + lhs, + lhs_shape, + lhs_strides, + rhs, + rhs_shape, + rhs_strides, + dst_strides, + b, + m, + n, + k, + ) +} + impl CPUOperation for Matmul { fn apply(&self, dst: Tensor) -> Result { fn run_gemm( diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 83e8e544..3068aeb9 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -1,6 +1,6 @@ use crate::{ - cpu::cpu_store_result, DType, OperationError, RoPE, Shape, Tensor, TensorDType, TensorError, - Unary, + cpu::{cpu_store_result, gemm::gemm}, + shape, DType, OperationError, RoPE, Shape, Strides, Tensor, TensorDType, TensorError, Unary, }; use half::{bf16, f16}; use num_traits::Float; @@ -21,19 +21,12 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { Ok(dst) } -fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { - let [b, t, h, d] = shape.try_into().unwrap(); - let el_count = b * h * t * d; - - let src = &src[offset..offset + el_count]; - +fn calculate_sincos(dim: usize, seq_len: usize, base: f32) -> (Vec, Vec) { let half_dim = dim / 2; - let positions = (offset..el_count + offset) - .map(|x| x as f32) - .collect::>(); + let positions = (0..seq_len).map(|x| x as f32).collect::>(); let log_base = base.log2(); - let inv_freqs = (0..d) + let inv_freqs = (0..dim) .step_by(2) .rev() .map(|i| -(i as f32)) @@ -41,14 +34,36 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec .map(|i| i.exp()) .collect::>(); - let theta = positions - .iter() - .zip(inv_freqs.iter()) - .map(|(p, i)| p * i) - .collect::>(); + let p_shape = shape!(seq_len, 1); + let p_strides = Strides::from(&p_shape); + let i_shape = shape!(1, half_dim); + let i_strides = Strides::from(&i_shape); + let dst_strides = Strides::from(&shape!(seq_len, half_dim)); + let theta = gemm( + &positions, + &p_shape, + &p_strides, + &inv_freqs, + &i_shape, + &i_strides, + &dst_strides, + 1, + seq_len, + half_dim, + 1, + ) + .unwrap(); + + let (sin_theta, cos_theta) = theta.iter().map(|i| i.sin_cos()).unzip(); + + (sin_theta, cos_theta) +} + +fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { + let [b, t, h, d] = shape.try_into().unwrap(); + let el_count = b * h * t * d; - let cos = theta.iter().map(|x| x.cos()).collect::>(); - let sin = theta.iter().map(|x| x.sin()).collect::>(); + let (sin, cos) = calculate_sincos(dim, el_count, base); let mut dst = vec![0.0; el_count]; From 5f15257a1be0e4b79643ab439f8b6ad4dda7a15f Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 6 Sep 2024 15:21:51 +0200 Subject: [PATCH 04/32] debugging cpu RoPE --- crates/ratchet-core/src/cpu/rope.rs | 12 ++++++++---- crates/ratchet-core/src/ops/rope.rs | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 3068aeb9..a0411dfa 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -21,17 +21,19 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { Ok(dst) } -fn calculate_sincos(dim: usize, seq_len: usize, base: f32) -> (Vec, Vec) { +fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Vec, Vec) { let half_dim = dim / 2; - let positions = (0..seq_len).map(|x| x as f32).collect::>(); + let positions = (offset..seq_len + offset) + .map(|x| x as f32) + .collect::>(); let log_base = base.log2(); let inv_freqs = (0..dim) .step_by(2) .rev() .map(|i| -(i as f32)) .map(|i| i * log_base / half_dim as f32) - .map(|i| i.exp()) + .map(f32::exp) .collect::>(); let p_shape = shape!(seq_len, 1); @@ -63,7 +65,9 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec let [b, t, h, d] = shape.try_into().unwrap(); let el_count = b * h * t * d; - let (sin, cos) = calculate_sincos(dim, el_count, base); + let (sin, cos) = calculate_sincos(dim, el_count, base, offset); + //let sin = &sin[offset..el_count + offset]; + //let cos = &cos[offset..el_count + offset]; let mut dst = vec![0.0; el_count]; diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 8b7069ee..75dbf71d 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -366,4 +366,20 @@ def mlx_rope(input, dim, offset): let device = Device::request_device(DeviceRequest::CPU).unwrap(); run_rope_trial(prob, device); } + + #[test] + fn debug_rope_cpu() { + let prob = RoPEProblem { + BS: 1, + NH: 1, + SL: 2, + HD: 8, + dim: 8, + offset: 3, + }; + println!("{prob:?}"); + + let device = Device::request_device(DeviceRequest::CPU).unwrap(); + run_rope_trial(prob, device); + } } From 70dc1536a5f71746e6cff2ac6f1e364ecf99e714 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 6 Sep 2024 17:14:30 +0200 Subject: [PATCH 05/32] Rope cpu is almost there --- crates/ratchet-core/src/cpu/rope.rs | 36 +++++++++++++++-------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index a0411dfa..739bb512 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -24,23 +24,25 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Vec, Vec) { let half_dim = dim / 2; - let positions = (offset..seq_len + offset) - .map(|x| x as f32) - .collect::>(); - let log_base = base.log2(); - let inv_freqs = (0..dim) - .step_by(2) - .rev() + let p_len = seq_len + offset; + + let positions = (offset..p_len).map(|x| x as f32).collect::>(); + + let log_base = base.ln(); + let inv_freqs = (0..half_dim) .map(|i| -(i as f32)) .map(|i| i * log_base / half_dim as f32) .map(f32::exp) .collect::>(); - let p_shape = shape!(seq_len, 1); + println!("positions: {:?}", positions); + println!("inv_freqs: {:?}", inv_freqs); + + let p_shape = shape!(p_len, 1); let p_strides = Strides::from(&p_shape); let i_shape = shape!(1, half_dim); let i_strides = Strides::from(&i_shape); - let dst_strides = Strides::from(&shape!(seq_len, half_dim)); + let dst_strides = Strides::from(&shape!(p_len, half_dim)); let theta = gemm( &positions, &p_shape, @@ -56,18 +58,18 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve ) .unwrap(); + println!("theta: {:?}", theta); + let (sin_theta, cos_theta) = theta.iter().map(|i| i.sin_cos()).unzip(); (sin_theta, cos_theta) } fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { - let [b, t, h, d] = shape.try_into().unwrap(); - let el_count = b * h * t * d; + let [b, h, sl, d] = shape.try_into().unwrap(); + let el_count = b * h * sl * d; - let (sin, cos) = calculate_sincos(dim, el_count, base, offset); - //let sin = &sin[offset..el_count + offset]; - //let cos = &cos[offset..el_count + offset]; + let (sin, cos) = calculate_sincos(dim, sl, base, offset); let mut dst = vec![0.0; el_count]; @@ -76,10 +78,10 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec println!("src len: {}", src.len()); println!("dst len: {}", dst.len()); - src.chunks(t * h * d) - .zip(dst.chunks_mut(t * h * d)) + src.chunks(sl * h * d) + .zip(dst.chunks_mut(sl * h * d)) .for_each(|(src, dst)| { - for i_t in 0..t { + for i_t in 0..sl { for i_d in 0..d / 2 { let i_cs = i_t * (d / 2) + i_d; for i_h in 0..h { From d1f31da454354d2825932d04c8f00316bf85ff48 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 9 Sep 2024 16:56:43 +0200 Subject: [PATCH 06/32] debugging gemm/rope interaction --- crates/ratchet-core/src/cpu/rope.rs | 156 ++++++++++++++++++++++++---- crates/ratchet-core/src/ops/rope.rs | 29 +++++- 2 files changed, 163 insertions(+), 22 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 739bb512..16951863 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -65,33 +65,151 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve (sin_theta, cos_theta) } +#[inline] +fn split_by_offset(data: &[f32], offset: usize) -> (Vec, Vec) { + let mut x1 = Vec::with_capacity(data.len() / 2); + let mut x2 = Vec::with_capacity(data.len() / 2); + + let mut start = 0; + let mut stop = offset; + while stop < data.len() { + let mut chunk = data[start..stop].to_vec(); + x1.append(&mut chunk); + start += offset; + stop += offset; + + let mut chunk = data[start..stop].to_vec(); + x2.append(&mut chunk); + start += offset; + stop += offset; + } + (x1.to_vec(), x2.to_vec()) +} + fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { - let [b, h, sl, d] = shape.try_into().unwrap(); - let el_count = b * h * sl * d; + let [b, h, t, d] = shape.try_into().unwrap(); + let el_count = b * h * t * d; - let (sin, cos) = calculate_sincos(dim, sl, base, offset); + let (sin, cos) = calculate_sincos(dim, t, base, offset); - let mut dst = vec![0.0; el_count]; + let mut dst = Vec::with_capacity(el_count); println!("cos len: {}", cos.len()); println!("sin len: {}", sin.len()); println!("src len: {}", src.len()); println!("dst len: {}", dst.len()); - src.chunks(sl * h * d) - .zip(dst.chunks_mut(sl * h * d)) - .for_each(|(src, dst)| { - for i_t in 0..sl { - for i_d in 0..d / 2 { - let i_cs = i_t * (d / 2) + i_d; - for i_h in 0..h { - let i1 = i_t * h * d + i_h * d + i_d; - let i2 = i1 + d / 2; - dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; - dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; - } - } - } - }); + //println!("src: {:?}", src); + + let split = sin.len(); + let (x1, x2) = split_by_offset(src, split); + let x_shape = shape!(4, 1); + let x_strides = Strides::from(&x_shape); + let theta_shape = shape!(1, 4); + let theta_strides = Strides::from(&theta_shape); + + /* + a_shape = [2, 4] + [[1, 2, 3, 4], [5, 6, 7, 8]] + b_shape = [2, 4] + [[1, 2, 3, 4], [5, 6, 7, 8]] + + m = 2 + n = 2 + k = 4 + + */ + + let m = 4; + let n = 1; + let k = 4; + + println!("x_shape: {:?}", x_shape); + println!("theta_shape: {:?}", theta_shape); + + println!("m: {}", m); + println!("n: {}", n); + println!("k: {}", k); + + let x1_sin = gemm( + &x1, + &x_shape, + &x_strides, + &sin, + &theta_shape, + &theta_strides, + &x_strides, + 1, + m, + n, + k, + ) + .unwrap(); + + let x1_cos = gemm( + &x1, + &x_shape, + &x_strides, + &cos, + &theta_shape, + &theta_strides, + &x_strides, + 1, + m, + n, + k, + ) + .unwrap(); + + let x2_sin = gemm( + &x2, + &x_shape, + &x_strides, + &sin, + &theta_shape, + &theta_strides, + &x_strides, + 1, + m, + n, + k, + ) + .unwrap(); + + let x2_cos = gemm( + &x2, + &x_shape, + &x_strides, + &cos, + &theta_shape, + &theta_strides, + &x_strides, + 1, + m, + n, + k, + ) + .unwrap(); + + println!("x1: {:?}", x1); + println!("x2: {:?}", x2); + println!("sin: {:?}", sin); + println!("cos: {:?}", cos); + + println!("x1_sin: {:?}", x1_sin); + println!("x1_cos: {:?}", x1_cos); + println!("x2_sin: {:?}", x2_sin); + println!("x2_cos: {:?}", x2_cos); + + x1_cos.iter().zip(x2_sin).for_each(|(x1_cos, x2_sin)| { + dst.push(x1_cos - x2_sin); + }); + + x1_sin.iter().zip(x2_cos).for_each(|(x1_sin, x2_cos)| { + dst.push(x1_sin + x2_cos); + }); + + println!("dst: {:?}", dst); + dst } diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 75dbf71d..c7cb3fef 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -271,7 +271,7 @@ mod tests { use test_strategy::{proptest, Arbitrary}; use crate::test_util::run_py_prg; - use crate::{shape, Device, DeviceRequest, Tensor}; + use crate::{gemm, shape, Device, DeviceRequest, Shape, Strides, Tensor}; fn ground_truth(a: &Tensor, dim: usize, offset: usize) -> anyhow::Result { let prg = r#" @@ -279,6 +279,8 @@ import mlx.core as mx import mlx.nn as nn import numpy as np +mx.set_default_device(mx.cpu) + def mlx_rope(input, dim, offset): rope = nn.RoPE(dim) mx_input = mx.array(input) @@ -372,14 +374,35 @@ def mlx_rope(input, dim, offset): let prob = RoPEProblem { BS: 1, NH: 1, - SL: 2, + SL: 1, HD: 8, dim: 8, - offset: 3, + offset: 0, }; println!("{prob:?}"); let device = Device::request_device(DeviceRequest::CPU).unwrap(); run_rope_trial(prob, device); } + + #[test] + fn im_confused() { + let a = vec![1.0, 2.0, 3.0, 4.0]; + let a_s = shape!(4, 1); + let a_strides = Strides::from(&a_s); + let b = vec![1.0, 1.0, 1.0, 1.0]; + let b_s = shape!(1, 4); + let b_strides = Strides::from(&b_s); + + let m = 4; + let n = 4; + let k = 1; + + let result = gemm::gemm( + &a, &a_s, &a_strides, &b, &b_s, &b_strides, &b_strides, 1, m, n, k, + ) + .unwrap(); + + println!("{:?}", result); + } } From 2366e9d3f387113f499ad65105420d56607a9ffc Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:09:47 +0200 Subject: [PATCH 07/32] More debugging gemm/rope --- crates/ratchet-core/src/cpu/rope.rs | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 16951863..6abff1c5 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -103,26 +103,14 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec let split = sin.len(); let (x1, x2) = split_by_offset(src, split); - let x_shape = shape!(4, 1); + let x_shape = shape!(h, t, d / 2); let x_strides = Strides::from(&x_shape); - let theta_shape = shape!(1, 4); + let theta_shape = shape!(1, t, d / 2); let theta_strides = Strides::from(&theta_shape); - /* - a_shape = [2, 4] - [[1, 2, 3, 4], [5, 6, 7, 8]] - b_shape = [2, 4] - [[1, 2, 3, 4], [5, 6, 7, 8]] - - m = 2 - n = 2 - k = 4 - - */ - - let m = 4; - let n = 1; - let k = 4; + let m = h * t * d / 2; + let n = t * d / 2; + let k = 1; println!("x_shape: {:?}", x_shape); println!("theta_shape: {:?}", theta_shape); From 9b4bd6c20b92c692b6d3c1fcc3ee0cc42c0e40cf Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:52:30 +0200 Subject: [PATCH 08/32] Revert gemm back tobinary mul --- crates/ratchet-core/src/cpu/rope.rs | 90 +++++------------------------ 1 file changed, 14 insertions(+), 76 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 6abff1c5..875e999c 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -87,6 +87,7 @@ fn split_by_offset(data: &[f32], offset: usize) -> (Vec, Vec) { } fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { + println!("Ratchet RoPE"); let [b, h, t, d] = shape.try_into().unwrap(); let el_count = b * h * t * d; @@ -99,85 +100,22 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec println!("src len: {}", src.len()); println!("dst len: {}", dst.len()); - //println!("src: {:?}", src); - let split = sin.len(); let (x1, x2) = split_by_offset(src, split); - let x_shape = shape!(h, t, d / 2); - let x_strides = Strides::from(&x_shape); - let theta_shape = shape!(1, t, d / 2); - let theta_strides = Strides::from(&theta_shape); - - let m = h * t * d / 2; - let n = t * d / 2; - let k = 1; - - println!("x_shape: {:?}", x_shape); - println!("theta_shape: {:?}", theta_shape); - - println!("m: {}", m); - println!("n: {}", n); - println!("k: {}", k); - - let x1_sin = gemm( - &x1, - &x_shape, - &x_strides, - &sin, - &theta_shape, - &theta_strides, - &x_strides, - 1, - m, - n, - k, - ) - .unwrap(); - - let x1_cos = gemm( - &x1, - &x_shape, - &x_strides, - &cos, - &theta_shape, - &theta_strides, - &x_strides, - 1, - m, - n, - k, - ) - .unwrap(); - let x2_sin = gemm( - &x2, - &x_shape, - &x_strides, - &sin, - &theta_shape, - &theta_strides, - &x_strides, - 1, - m, - n, - k, - ) - .unwrap(); - - let x2_cos = gemm( - &x2, - &x_shape, - &x_strides, - &cos, - &theta_shape, - &theta_strides, - &x_strides, - 1, - m, - n, - k, - ) - .unwrap(); + let (x1_cos, x2_cos): (Vec, Vec) = cos + .iter() + .zip(x1.iter()) + .zip(x2.iter()) + .map(|((c, x1), x2)| (c * x1, c * x2)) + .unzip(); + + let (x1_sin, x2_sin): (Vec, Vec) = sin + .iter() + .zip(x1.iter()) + .zip(x2.iter()) + .map(|((s, x1), x2)| (s * x1, s * x2)) + .unzip(); println!("x1: {:?}", x1); println!("x2: {:?}", x2); From f998176fa7f0d5635eaae33e7d41ae729f335f6f Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:54:46 +0200 Subject: [PATCH 09/32] chore: turns out while I was confused, it was not about gemm --- crates/ratchet-core/src/ops/rope.rs | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index c7cb3fef..2c5ffb39 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -384,25 +384,4 @@ def mlx_rope(input, dim, offset): let device = Device::request_device(DeviceRequest::CPU).unwrap(); run_rope_trial(prob, device); } - - #[test] - fn im_confused() { - let a = vec![1.0, 2.0, 3.0, 4.0]; - let a_s = shape!(4, 1); - let a_strides = Strides::from(&a_s); - let b = vec![1.0, 1.0, 1.0, 1.0]; - let b_s = shape!(1, 4); - let b_strides = Strides::from(&b_s); - - let m = 4; - let n = 4; - let k = 1; - - let result = gemm::gemm( - &a, &a_s, &a_strides, &b, &b_s, &b_strides, &b_strides, 1, m, n, k, - ) - .unwrap(); - - println!("{:?}", result); - } } From a09f9869f3cbbb9513b40773105508a5a5d2f204 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 9 Sep 2024 18:28:45 +0200 Subject: [PATCH 10/32] chore: Add interleave_by_offset --- crates/ratchet-core/src/cpu/rope.rs | 40 ++++++++++++++++++++++------- crates/ratchet-core/src/ops/rope.rs | 2 +- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 875e999c..8c6aafb6 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -66,7 +66,7 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve } #[inline] -fn split_by_offset(data: &[f32], offset: usize) -> (Vec, Vec) { +fn chunk_by_offset(data: &[f32], offset: usize) -> (Vec, Vec) { let mut x1 = Vec::with_capacity(data.len() / 2); let mut x2 = Vec::with_capacity(data.len() / 2); @@ -86,22 +86,43 @@ fn split_by_offset(data: &[f32], offset: usize) -> (Vec, Vec) { (x1.to_vec(), x2.to_vec()) } +#[inline] +fn interleave_by_offset(data: &[f32], offset: usize) -> Vec { + let n = data.len(); + let mid = n / 2; + let mut interleaved = Vec::with_capacity(n); + + let mut start = 0; + let mut stop = offset; + while stop <= mid { + let mut chunk = data[start..stop].to_vec(); + interleaved.append(&mut chunk); + + let mut chunk = data[start + mid..stop + mid].to_vec(); + interleaved.append(&mut chunk); + + start += offset; + stop += offset; + } + interleaved +} + fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { println!("Ratchet RoPE"); let [b, h, t, d] = shape.try_into().unwrap(); let el_count = b * h * t * d; let (sin, cos) = calculate_sincos(dim, t, base, offset); - - let mut dst = Vec::with_capacity(el_count); + let mut intermediate = Vec::with_capacity(el_count); println!("cos len: {}", cos.len()); println!("sin len: {}", sin.len()); println!("src len: {}", src.len()); - println!("dst len: {}", dst.len()); - let split = sin.len(); - let (x1, x2) = split_by_offset(src, split); + let offset = el_count / t / 2; + + println!("offset: {}", offset); + let (x1, x2) = chunk_by_offset(src, offset); let (x1_cos, x2_cos): (Vec, Vec) = cos .iter() @@ -128,14 +149,15 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec println!("x2_cos: {:?}", x2_cos); x1_cos.iter().zip(x2_sin).for_each(|(x1_cos, x2_sin)| { - dst.push(x1_cos - x2_sin); + intermediate.push(x1_cos - x2_sin); }); x1_sin.iter().zip(x2_cos).for_each(|(x1_sin, x2_cos)| { - dst.push(x1_sin + x2_cos); + intermediate.push(x1_sin + x2_cos); }); + println!("intermediate: {:?}", intermediate); + let dst = interleave_by_offset(&intermediate, offset); println!("dst: {:?}", dst); - dst } diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 2c5ffb39..9f2192f5 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -374,7 +374,7 @@ def mlx_rope(input, dim, offset): let prob = RoPEProblem { BS: 1, NH: 1, - SL: 1, + SL: 2, HD: 8, dim: 8, offset: 0, From 46861fa1d912a153840431e35a588aec79784368 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:04:30 +0200 Subject: [PATCH 11/32] close --- crates/ratchet-core/src/cpu/rope.rs | 20 +++++++++----------- crates/ratchet-core/src/ops/rope.rs | 2 +- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 8c6aafb6..ce597e74 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -58,8 +58,6 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve ) .unwrap(); - println!("theta: {:?}", theta); - let (sin_theta, cos_theta) = theta.iter().map(|i| i.sin_cos()).unzip(); (sin_theta, cos_theta) @@ -94,7 +92,7 @@ fn interleave_by_offset(data: &[f32], offset: usize) -> Vec { let mut start = 0; let mut stop = offset; - while stop <= mid { + while stop + mid <= n { let mut chunk = data[start..stop].to_vec(); interleaved.append(&mut chunk); @@ -112,6 +110,7 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec let [b, h, t, d] = shape.try_into().unwrap(); let el_count = b * h * t * d; + let half_dim = dim / 2; let (sin, cos) = calculate_sincos(dim, t, base, offset); let mut intermediate = Vec::with_capacity(el_count); @@ -124,18 +123,17 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec println!("offset: {}", offset); let (x1, x2) = chunk_by_offset(src, offset); - let (x1_cos, x2_cos): (Vec, Vec) = cos + let N = sin.len(); + let (x1_cos, x1_sin): (Vec, Vec) = x1 .iter() - .zip(x1.iter()) - .zip(x2.iter()) - .map(|((c, x1), x2)| (c * x1, c * x2)) + .enumerate() + .map(|(i, x)| (x * cos[i % N], x * sin[i % N])) .unzip(); - let (x1_sin, x2_sin): (Vec, Vec) = sin + let (x2_cos, x2_sin): (Vec, Vec) = x2 .iter() - .zip(x1.iter()) - .zip(x2.iter()) - .map(|((s, x1), x2)| (s * x1, s * x2)) + .enumerate() + .map(|(i, x)| (x * cos[i % N], x * sin[i % N])) .unzip(); println!("x1: {:?}", x1); diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 9f2192f5..abb7257d 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -373,7 +373,7 @@ def mlx_rope(input, dim, offset): fn debug_rope_cpu() { let prob = RoPEProblem { BS: 1, - NH: 1, + NH: 2, SL: 2, HD: 8, dim: 8, From 23f268764b0089febcde554b39c5b84aac3a644d Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:13:45 +0200 Subject: [PATCH 12/32] Most RoPE test cases are passing --- crates/ratchet-core/src/cpu/rope.rs | 3 +-- crates/ratchet-core/src/ops/rope.rs | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index ce597e74..ad9390af 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -110,7 +110,6 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec let [b, h, t, d] = shape.try_into().unwrap(); let el_count = b * h * t * d; - let half_dim = dim / 2; let (sin, cos) = calculate_sincos(dim, t, base, offset); let mut intermediate = Vec::with_capacity(el_count); @@ -118,7 +117,7 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec println!("sin len: {}", sin.len()); println!("src len: {}", src.len()); - let offset = el_count / t / 2; + let offset = el_count / b / h / t / 2; println!("offset: {}", offset); let (x1, x2) = chunk_by_offset(src, offset); diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index abb7257d..184addd5 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -372,12 +372,12 @@ def mlx_rope(input, dim, offset): #[test] fn debug_rope_cpu() { let prob = RoPEProblem { - BS: 1, + BS: 2, NH: 2, SL: 2, HD: 8, dim: 8, - offset: 0, + offset: 5, }; println!("{prob:?}"); From 328d8ce2f4104248e57ec23e44fcde20f45a0326 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:07:05 +0200 Subject: [PATCH 13/32] getting there --- crates/ratchet-core/src/cpu/rope.rs | 54 +++++++++++++------ crates/ratchet-core/src/ops/rope.rs | 15 +++--- crates/ratchet-core/src/storage/cpu_buffer.rs | 2 +- 3 files changed, 48 insertions(+), 23 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index ad9390af..3ad5f1e5 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -1,3 +1,5 @@ +use std::num::NonZero; + use crate::{ cpu::{cpu_store_result, gemm::gemm}, shape, DType, OperationError, RoPE, Shape, Strides, Tensor, TensorDType, TensorError, Unary, @@ -38,11 +40,11 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve println!("positions: {:?}", positions); println!("inv_freqs: {:?}", inv_freqs); - let p_shape = shape!(p_len, 1); + let p_shape = shape!(seq_len, 1); let p_strides = Strides::from(&p_shape); let i_shape = shape!(1, half_dim); let i_strides = Strides::from(&i_shape); - let dst_strides = Strides::from(&shape!(p_len, half_dim)); + let dst_strides = Strides::from(&shape!(seq_len, half_dim)); let theta = gemm( &positions, &p_shape, @@ -64,7 +66,7 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve } #[inline] -fn chunk_by_offset(data: &[f32], offset: usize) -> (Vec, Vec) { +fn chunk_by_offset(data: &[f32], offset: usize, skip: usize) -> (Vec, Vec) { let mut x1 = Vec::with_capacity(data.len() / 2); let mut x2 = Vec::with_capacity(data.len() / 2); @@ -80,12 +82,15 @@ fn chunk_by_offset(data: &[f32], offset: usize) -> (Vec, Vec) { x2.append(&mut chunk); start += offset; stop += offset; + + start += skip; + stop += skip; } (x1.to_vec(), x2.to_vec()) } #[inline] -fn interleave_by_offset(data: &[f32], offset: usize) -> Vec { +fn merge(data: &[f32], offset: usize, skip: usize) -> Vec { let n = data.len(); let mid = n / 2; let mut interleaved = Vec::with_capacity(n); @@ -101,38 +106,44 @@ fn interleave_by_offset(data: &[f32], offset: usize) -> Vec { start += offset; stop += offset; + + start += skip; + stop += skip; } interleaved } fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { println!("Ratchet RoPE"); - let [b, h, t, d] = shape.try_into().unwrap(); - let el_count = b * h * t * d; + let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); + let el_count = batches * num_heads * seq_len * head_dim; - let (sin, cos) = calculate_sincos(dim, t, base, offset); + let half_dim = dim / 2; + let (sin, cos) = calculate_sincos(dim, seq_len, base, offset); let mut intermediate = Vec::with_capacity(el_count); + let chunk_offset = half_dim; + let skip = 0; + + println!("chunk_offset: {}", chunk_offset); + let (x1, x2) = chunk_by_offset(src, chunk_offset, skip); + println!("cos len: {}", cos.len()); println!("sin len: {}", sin.len()); println!("src len: {}", src.len()); + println!("x1 len: {}", x1.len()); + println!("x2 len: {}", x2.len()); - let offset = el_count / b / h / t / 2; - - println!("offset: {}", offset); - let (x1, x2) = chunk_by_offset(src, offset); - - let N = sin.len(); let (x1_cos, x1_sin): (Vec, Vec) = x1 .iter() .enumerate() - .map(|(i, x)| (x * cos[i % N], x * sin[i % N])) + .map(|(i, x)| (x * cos[i % cos.len()], x * sin[i % sin.len()])) .unzip(); let (x2_cos, x2_sin): (Vec, Vec) = x2 .iter() .enumerate() - .map(|(i, x)| (x * cos[i % N], x * sin[i % N])) + .map(|(i, x)| (x * cos[i % cos.len()], x * sin[i % sin.len()])) .unzip(); println!("x1: {:?}", x1); @@ -154,7 +165,18 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec }); println!("intermediate: {:?}", intermediate); - let dst = interleave_by_offset(&intermediate, offset); + println!("intermediate len: {}", intermediate.len()); + + let out_shape = shape!(batches, num_heads, seq_len, head_dim); + println!("out_shape: {:?}", out_shape); + + let skip = head_dim.abs_diff(dim); + let mut dst = merge(&intermediate, chunk_offset, skip); + + if dim < head_dim { + dst.append(&mut src[dim..].to_vec()) + } println!("dst: {:?}", dst); + println!("dst len: {}", dst.len()); dst } diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 184addd5..e5ff11f8 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -300,7 +300,10 @@ def mlx_rope(input, dim, offset): dim, offset, } = problem; - let a = Tensor::randn::(shape![BS, NH, SL, HD], Device::CPU); + let shape = shape![BS, NH, SL, HD]; + let n = shape.numel(); + let data = (0..n).map(|x| x as f32).collect::>(); + let a = Tensor::from_data(data, shape, Device::CPU); let ground = ground_truth(&a, dim, offset).unwrap(); let a = a.to(&device).unwrap(); @@ -372,12 +375,12 @@ def mlx_rope(input, dim, offset): #[test] fn debug_rope_cpu() { let prob = RoPEProblem { - BS: 2, + BS: 1, NH: 2, - SL: 2, - HD: 8, - dim: 8, - offset: 5, + SL: 1, + HD: 32, + dim: 16, + offset: 1, }; println!("{prob:?}"); diff --git a/crates/ratchet-core/src/storage/cpu_buffer.rs b/crates/ratchet-core/src/storage/cpu_buffer.rs index 3be6f1ee..4e7fde9a 100644 --- a/crates/ratchet-core/src/storage/cpu_buffer.rs +++ b/crates/ratchet-core/src/storage/cpu_buffer.rs @@ -89,7 +89,7 @@ impl CPUBuffer { } pub fn from_slice(data: &[T], shape: &Shape) -> Self { - assert_eq!(data.len(), shape.numel()); + //assert_eq!(data.len(), shape.numel()); let bytes: &[u8] = bytemuck::cast_slice(data); Self::from_bytes(bytes, std::mem::align_of::()) } From 33b097ea929109401bca0261ad1955db79ca8237 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 1 Oct 2024 12:03:42 +0200 Subject: [PATCH 14/32] testing a bunch of different things. really messy :) --- crates/ratchet-core/src/cpu/mod.rs | 129 +++++++++++++++--- crates/ratchet-core/src/cpu/rope.rs | 155 +++++++++++++++++----- crates/ratchet-core/src/ops/matmul/mod.rs | 4 +- crates/ratchet-core/src/ops/rope.rs | 8 +- crates/ratchet-core/src/strides.rs | 11 ++ 5 files changed, 249 insertions(+), 58 deletions(-) diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs index 0e856826..6d5182ff 100644 --- a/crates/ratchet-core/src/cpu/mod.rs +++ b/crates/ratchet-core/src/cpu/mod.rs @@ -4,8 +4,8 @@ mod utils; use crate::{ dequantize, Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, Concat, DType, IndexSelect, - InvariantError, OpGuards, Operation, OperationError, RVec, Storage, StorageView, Tensor, - TensorDType, Unary, UnaryOp, + InvariantError, OpGuards, Operation, OperationError, RVec, Shape, Storage, StorageView, + Strides, Tensor, TensorDType, Unary, UnaryOp, }; use anyhow::anyhow; use core::marker::PhantomData; @@ -53,6 +53,77 @@ impl Operation for CPU { } } +pub struct StridedIterator<'a> { + shape: &'a Shape, + strides: &'a Strides, + next_index: Option, + multi_index: Vec, +} + +impl<'a> StridedIterator<'a> { + pub fn new(shape: &'a Shape, strides: &'a Strides, start_offset: usize) -> Self { + Self { + shape, + strides, + next_index: if shape.numel() == 0 { + None + } else { + Some(start_offset) + }, + multi_index: vec![0; shape.len()], + } + } +} + +impl<'a> Iterator for StridedIterator<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + let storage_index = match self.next_index { + None => return None, + Some(storage_index) => storage_index, + }; + let mut updated = false; + let mut next_storage_index = storage_index; + for ((multi_i, max_i), stride_i) in self + .multi_index + .iter_mut() + .zip(self.shape.iter()) + .zip(self.strides.iter()) + .rev() + { + let next_i = *multi_i + 1; + if next_i < *max_i { + *multi_i = next_i; + updated = true; + next_storage_index += *stride_i as usize; + break; + } else { + next_storage_index -= *multi_i * *stride_i as usize; + *multi_i = 0 + } + } + self.next_index = if updated { + Some(next_storage_index) + } else { + None + }; + Some(storage_index) + } +} + +impl<'a> From<(&'a Shape, &'a Strides)> for StridedIterator<'a> { + fn from((shape, strides): (&'a Shape, &'a Strides)) -> Self { + StridedIterator::new(shape, strides, 0) + } +} + +impl<'a> From<(&'a Shape, &'a Strides, usize)> for StridedIterator<'a> { + fn from((shape, strides, offset): (&'a Shape, &'a Strides, usize)) -> Self { + StridedIterator::new(shape, strides, offset) + } +} + macro_rules! impl_cpu_unary_op { ($method_name:ident, $op:expr) => { fn $method_name(input: &Tensor, dst: Tensor) -> Result { @@ -290,44 +361,58 @@ pub fn cpu_cast(cast: Cast, dst: Tensor) -> Result { Ok(dst) } -fn concat_inner( - inputs: RVec, +pub(crate) fn concat( + inputs: &[(&Shape, Vec)], dim: usize, - dst: Tensor, -) -> Result { - let dst_size = dst.shape().clone().product(); - let mut result = vec![T::zero(); dst_size]; - - let dst_dim_len = dst.shape()[dim]; - let block: usize = dst.shape().iter().skip(1 + dim).product(); + dst_shape: &Shape, + dst: &mut [T], +) -> Result<(), OperationError> { + let dst_dim_len = dst_shape[dim]; + let block: usize = dst_shape.iter().skip(1 + dim).product(); let dst_s = block * dst_dim_len; let src_o = 0; let mut dst_o = 0; - for t in inputs { - let src = t.to_vec::()?; - - let t_dims = t.shape().as_slice(); - let a_dim: usize = t_dims.iter().take(dim).product(); - let b_dim = block * t_dims[dim]; + for (src_s, src) in inputs { + let a_dim: usize = src_s.iter().take(dim).product(); + let b_dim = block * src_s[dim]; for idx in 0..a_dim { let dst_idx = idx * dst_s + dst_o; let src_idx = idx * b_dim + src_o; - let dst = &mut result[dst_idx..dst_idx + b_dim]; + let dst_t = &mut dst[dst_idx..dst_idx + b_dim]; let src = &src[src_idx..src_idx + b_dim]; - dst.copy_from_slice(src) + dst_t.copy_from_slice(src) } dst_o += b_dim; } + Ok(()) +} +pub(crate) fn apply_concat( + inputs: RVec, + dim: usize, + dst: Tensor, +) -> Result { + let dst_size = dst.shape().numel(); + let mut result = vec![T::zero(); dst_size]; + + let inputs = inputs + .iter() + .map(|t| match t.to_vec::() { + Ok(v) => Ok((t.shape(), v)), + Err(e) => Err(e.into()), + }) + .collect::, OperationError>>(); + + concat::(&inputs?, dim, dst.shape(), &mut result)?; cpu_store_result(&dst, &result); Ok(dst) } pub fn cpu_concat(Concat { inputs, dim }: Concat, dst: Tensor) -> Result { match dst.dt() { - DType::F32 => concat_inner::(inputs, dim, dst), - DType::F16 => concat_inner::(inputs, dim, dst), - DType::BF16 => concat_inner::(inputs, dim, dst), + DType::F32 => apply_concat::(inputs, dim, dst), + DType::F16 => apply_concat::(inputs, dim, dst), + DType::BF16 => apply_concat::(inputs, dim, dst), dtype => Err(InvariantError::UnsupportedDType(dtype).into()), } } diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 3ad5f1e5..5b37954a 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -1,11 +1,9 @@ -use std::num::NonZero; - use crate::{ cpu::{cpu_store_result, gemm::gemm}, - shape, DType, OperationError, RoPE, Shape, Strides, Tensor, TensorDType, TensorError, Unary, + shape, DType, OperationError, RoPE, Shape, StridedIterator, Strides, Tensor, TensorDType, + TensorError, Unary, }; -use half::{bf16, f16}; -use num_traits::Float; +use anyhow::anyhow; pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { match op.input().dt() { @@ -14,7 +12,7 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { let base = op.base(); let offset = op.offset(); let src = op.input().to_vec::()?; - let result = rope(&src, op.input().shape(), dim, base, offset); + let result = rope(src, op.input().shape(), dim, base, offset); cpu_store_result(&dst, &result) } _ => todo!(), @@ -26,9 +24,9 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Vec, Vec) { let half_dim = dim / 2; - let p_len = seq_len + offset; - - let positions = (offset..p_len).map(|x| x as f32).collect::>(); + let positions = (offset..seq_len + offset) + .map(|x| x as f32) + .collect::>(); let log_base = base.ln(); let inv_freqs = (0..half_dim) @@ -37,14 +35,11 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve .map(f32::exp) .collect::>(); - println!("positions: {:?}", positions); - println!("inv_freqs: {:?}", inv_freqs); - - let p_shape = shape!(seq_len, 1); + let p_shape = shape!(half_dim, 1); let p_strides = Strides::from(&p_shape); let i_shape = shape!(1, half_dim); let i_strides = Strides::from(&i_shape); - let dst_strides = Strides::from(&shape!(seq_len, half_dim)); + let dst_strides = Strides::from(&shape!(half_dim, half_dim)); let theta = gemm( &positions, &p_shape, @@ -54,14 +49,13 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve &i_strides, &dst_strides, 1, - seq_len, + half_dim, half_dim, 1, ) .unwrap(); let (sin_theta, cos_theta) = theta.iter().map(|i| i.sin_cos()).unzip(); - (sin_theta, cos_theta) } @@ -113,20 +107,91 @@ fn merge(data: &[f32], offset: usize, skip: usize) -> Vec { interleaved } -fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { +fn slice(src: &[f32], start: &[usize], stop: &[usize]) -> Vec { + let stop_numel: usize = stop.iter().product(); + let start_numel: usize = stop.iter().product(); + assert!(stop_numel >= start_numel); + + let mut dst = vec![0.0; stop_numel - start_numel]; + + /* + start: [0, 0, 0, 8] + stop: [1, 1, 1, 16] + for + */ + + let mut src_idx = 0; + let mut dst_idx = 0; + for i in 0..start.len() { + let mut src_stride = start[i]; + let mut dst_stride = 0; + while src_stride < stop[i] { + dst[dst_idx] = src[src_idx]; + src_idx += src_stride; + dst_idx += dst_stride; + src_stride += 1; + dst_stride += 1; + } + } + + dst +} + +// Generic transpose function +fn transpose( + src: Vec, + shape: &Shape, + dim1: usize, + dim2: usize, +) -> Result, OperationError> { + let rank = shape.rank(); + if dim1 == dim2 { + return Ok(src); + } + if rank <= dim1 || rank <= dim2 { + return Err(anyhow!("Invalid dimensions for transpose operation").into()); + } + let mut dims = shape.to_vec(); + let mut strides = Strides::from(shape).to_vec(); + println!("dims: {:?}", dims); + println!("strides: {:?}", strides); + dims.swap(dim1, dim2); + strides.swap(dim1, dim2); + println!("dims: {:?}", dims); + println!("strides: {:?}", strides); + + let shape_t = Shape::from(dims); + let strides_t = Strides::from(strides); + + let mut result = vec![0.0; src.len()]; + let strided_iter = StridedIterator::new(&shape_t, &strides_t, 0); + let strided_iter2 = StridedIterator::new(&shape_t, &strides_t, 0); + let indices = strided_iter2.collect::>(); + println!("indices: {:?}", indices); + for (index, dst_index) in strided_iter.enumerate() { + result[dst_index] = src[index]; + } + + Ok(result) +} + +fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { println!("Ratchet RoPE"); let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); let el_count = batches * num_heads * seq_len * head_dim; let half_dim = dim / 2; let (sin, cos) = calculate_sincos(dim, seq_len, base, offset); + + println!("cos: {:?}", cos); + println!("sin: {:?}", sin); let mut intermediate = Vec::with_capacity(el_count); let chunk_offset = half_dim; let skip = 0; println!("chunk_offset: {}", chunk_offset); - let (x1, x2) = chunk_by_offset(src, chunk_offset, skip); + let (x1, x2) = chunk_by_offset(&src, chunk_offset, skip); println!("cos len: {}", cos.len()); println!("sin len: {}", sin.len()); @@ -146,16 +211,6 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec .map(|(i, x)| (x * cos[i % cos.len()], x * sin[i % sin.len()])) .unzip(); - println!("x1: {:?}", x1); - println!("x2: {:?}", x2); - println!("sin: {:?}", sin); - println!("cos: {:?}", cos); - - println!("x1_sin: {:?}", x1_sin); - println!("x1_cos: {:?}", x1_cos); - println!("x2_sin: {:?}", x2_sin); - println!("x2_cos: {:?}", x2_cos); - x1_cos.iter().zip(x2_sin).for_each(|(x1_cos, x2_sin)| { intermediate.push(x1_cos - x2_sin); }); @@ -171,12 +226,52 @@ fn rope(src: &[f32], shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec println!("out_shape: {:?}", out_shape); let skip = head_dim.abs_diff(dim); - let mut dst = merge(&intermediate, chunk_offset, skip); + let mut dst = merge(&intermediate, half_dim, skip); if dim < head_dim { - dst.append(&mut src[dim..].to_vec()) + let offset = (el_count / head_dim) * dim; + let appendix = &mut src[offset..].to_vec(); + dst.append(appendix); } println!("dst: {:?}", dst); println!("dst len: {}", dst.len()); dst } + +fn rope_2(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { + println!("Ratchet RoPE"); + let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); + let el_count = batches * num_heads * seq_len * head_dim; + + let half_dim = dim / 2; + let (sin, cos) = calculate_sincos(dim, seq_len, base, offset); + + println!("cos: {:?}", cos); + println!("sin: {:?}", sin); + + let src = transpose(src, &shape, 1, 2).unwrap(); + let mut dst = vec![0.0; el_count]; + let b = batches; + let t = num_heads; + let h = seq_len; + let d = head_dim; + src.chunks(t * h * d) + .zip(dst.chunks_mut(t * h * d)) + .for_each(|(src, dst)| { + for i_t in 0..t { + for i_d in 0..d / 2 { + let i_cs = i_t * (d / 2) + i_d; + for i_h in 0..h { + let i1 = i_t * h * d + i_h * d + i_d; + let i2 = i1 + d / 2; + dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; + dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; + } + } + } + }); + + let dst = transpose(dst, &shape, 1, 2).unwrap(); + + dst +} diff --git a/crates/ratchet-core/src/ops/matmul/mod.rs b/crates/ratchet-core/src/ops/matmul/mod.rs index c92e4bf3..d3a711b0 100644 --- a/crates/ratchet-core/src/ops/matmul/mod.rs +++ b/crates/ratchet-core/src/ops/matmul/mod.rs @@ -12,7 +12,7 @@ use std::{cmp::Ordering, mem}; use crate::{ gpu::{BindGroupLayoutDescriptor, CpuUniform}, - quantize, rvec, DType, Device, GPUOperation, Kernel, KernelElement, KernelKey, KernelMetadata, + rvec, DType, Device, GPUOperation, Kernel, KernelElement, KernelKey, KernelMetadata, KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Shape, StorageView, Strides, Tensor, WorkgroupSize, Workload, Q4_KF, Q4_KH, Q8_0F, Q8_0H, }; @@ -754,7 +754,7 @@ mod tests { use crate::test_util::run_py_prg; - use crate::{shape, Device, DeviceRequest}; + use crate::{quantize, shape, Device, DeviceRequest}; use super::*; diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index e5ff11f8..c84b6b3f 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -376,11 +376,11 @@ def mlx_rope(input, dim, offset): fn debug_rope_cpu() { let prob = RoPEProblem { BS: 1, - NH: 2, - SL: 1, - HD: 32, + NH: 1, + SL: 128, + HD: 16, dim: 16, - offset: 1, + offset: 0, }; println!("{prob:?}"); diff --git a/crates/ratchet-core/src/strides.rs b/crates/ratchet-core/src/strides.rs index 2e57d20d..0762f6c4 100644 --- a/crates/ratchet-core/src/strides.rs +++ b/crates/ratchet-core/src/strides.rs @@ -1,4 +1,5 @@ use std::ops::Index; +use std::slice::Iter; use crate::{rvec, RVec, Shape}; use encase::impl_wrapper; @@ -13,6 +14,10 @@ impl Strides { self.0.to_vec() } + pub fn iter(&self) -> Iter<'_, isize> { + self.0.iter() + } + pub fn transpose(&mut self) { let rank = self.0.len(); if rank < 2 { @@ -53,6 +58,12 @@ impl From<&Shape> for Strides { } } +impl From> for Strides { + fn from(strides: Vec) -> Self { + Self(strides.into()) + } +} + impl From<&Strides> for [u32; 3] { fn from(strides: &Strides) -> Self { assert!(strides.0.len() <= 3); From d5fb9f823199b05698a402fcb4ebead8b315b6f4 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Wed, 2 Oct 2024 15:28:11 +0200 Subject: [PATCH 15/32] chore: focus on theta --- crates/ratchet-core/src/cpu/rope.rs | 20 +++----------------- crates/ratchet-core/src/ops/rope.rs | 6 +++--- crates/ratchet-core/src/tensor.rs | 27 ++++++++++++++++++--------- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 5b37954a..2a05d6ae 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -1,7 +1,6 @@ use crate::{ cpu::{cpu_store_result, gemm::gemm}, - shape, DType, OperationError, RoPE, Shape, StridedIterator, Strides, Tensor, TensorDType, - TensorError, Unary, + shape, DType, OperationError, RoPE, Shape, StridedIterator, Strides, Tensor, }; use anyhow::anyhow; @@ -55,6 +54,8 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve ) .unwrap(); + println!("THETA: {:?}", theta); + let (sin_theta, cos_theta) = theta.iter().map(|i| i.sin_cos()).unzip(); (sin_theta, cos_theta) } @@ -183,22 +184,13 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V let half_dim = dim / 2; let (sin, cos) = calculate_sincos(dim, seq_len, base, offset); - println!("cos: {:?}", cos); - println!("sin: {:?}", sin); let mut intermediate = Vec::with_capacity(el_count); let chunk_offset = half_dim; let skip = 0; - println!("chunk_offset: {}", chunk_offset); let (x1, x2) = chunk_by_offset(&src, chunk_offset, skip); - println!("cos len: {}", cos.len()); - println!("sin len: {}", sin.len()); - println!("src len: {}", src.len()); - println!("x1 len: {}", x1.len()); - println!("x2 len: {}", x2.len()); - let (x1_cos, x1_sin): (Vec, Vec) = x1 .iter() .enumerate() @@ -219,11 +211,7 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V intermediate.push(x1_sin + x2_cos); }); - println!("intermediate: {:?}", intermediate); - println!("intermediate len: {}", intermediate.len()); - let out_shape = shape!(batches, num_heads, seq_len, head_dim); - println!("out_shape: {:?}", out_shape); let skip = head_dim.abs_diff(dim); let mut dst = merge(&intermediate, half_dim, skip); @@ -233,8 +221,6 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V let appendix = &mut src[offset..].to_vec(); dst.append(appendix); } - println!("dst: {:?}", dst); - println!("dst len: {}", dst.len()); dst } diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index c84b6b3f..933c2395 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -376,10 +376,10 @@ def mlx_rope(input, dim, offset): fn debug_rope_cpu() { let prob = RoPEProblem { BS: 1, - NH: 1, - SL: 128, + NH: 2, + SL: 16, HD: 16, - dim: 16, + dim: 8, offset: 0, }; println!("{prob:?}"); diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index c655d2ac..5b8c108e 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -83,15 +83,23 @@ impl Tensor { impl std::fmt::Debug for Tensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let storage_fmt = self.storage().as_ref().map(|s| s.dump(self.dt(), false)); - let (id, op) = (self.id(), self.op()); - f.debug_struct("Tensor") - .field("id", &id) - .field("shape", &self.shape()) - .field("dt", &self.dt()) - .field("op", &op) - .field("storage", &storage_fmt) - .finish() + match self.device() { + Device::CPU => match self.dt() { + DType::F32 => self.to_ndarray_view::().fmt(f), + _ => unimplemented!("Debug not implemented for {:?}", self.dt()), + }, + Device::GPU(_) => { + let storage_fmt = self.storage().as_ref().map(|s| s.dump(self.dt(), false)); + let (id, op) = (self.id(), self.op()); + f.debug_struct("Tensor") + .field("id", &id) + .field("shape", &self.shape()) + .field("dt", &self.dt()) + .field("op", &op) + .field("storage", &storage_fmt) + .finish() + } + } } } @@ -263,6 +271,7 @@ macro_rules! impl_binary_op { macro_rules! impl_unary_op { ($method_name:ident, $op:expr) => { + #[allow(clippy::should_implement_trait)] pub fn $method_name(self) -> anyhow::Result { let device = self.device.clone(); let unary = Unary::new(self.clone(), $op); From 44bc1ec58456012db07c6edc08f090687de180ea Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Wed, 2 Oct 2024 16:53:23 +0200 Subject: [PATCH 16/32] chore: theta matches --- crates/ratchet-core/src/cpu/rope.rs | 14 +++++++++++--- crates/ratchet-core/src/ops/rope.rs | 1 + 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 2a05d6ae..90bbc4dd 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -22,23 +22,31 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Vec, Vec) { let half_dim = dim / 2; + println!("Half dim: {}", half_dim); let positions = (offset..seq_len + offset) .map(|x| x as f32) .collect::>(); + println!("Positions: {:?}", positions); + let log_base = base.ln(); + + println!("Log base: {}", log_base); + let inv_freqs = (0..half_dim) .map(|i| -(i as f32)) .map(|i| i * log_base / half_dim as f32) .map(f32::exp) .collect::>(); - let p_shape = shape!(half_dim, 1); + println!("Inverse Frequencies: {:?}", inv_freqs); + + let p_shape = shape!(seq_len, 1); let p_strides = Strides::from(&p_shape); let i_shape = shape!(1, half_dim); let i_strides = Strides::from(&i_shape); - let dst_strides = Strides::from(&shape!(half_dim, half_dim)); + let dst_strides = Strides::from(&shape!(seq_len, half_dim)); let theta = gemm( &positions, &p_shape, @@ -48,7 +56,7 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve &i_strides, &dst_strides, 1, - half_dim, + seq_len, half_dim, 1, ) diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 933c2395..cd988c6a 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -304,6 +304,7 @@ def mlx_rope(input, dim, offset): let n = shape.numel(); let data = (0..n).map(|x| x as f32).collect::>(); let a = Tensor::from_data(data, shape, Device::CPU); + println!("Input tensor: {:?}", a); let ground = ground_truth(&a, dim, offset).unwrap(); let a = a.to(&device).unwrap(); From 81f4bfc1074f8bc28fdba51f03ce9eeeb25227af Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Wed, 2 Oct 2024 17:50:15 +0200 Subject: [PATCH 17/32] chore: theta matches --- crates/ratchet-core/src/cpu/rope.rs | 32 ++++++++--------------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 90bbc4dd..c72b1d0f 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -20,28 +20,19 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { Ok(dst) } -fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Vec, Vec) { +fn compute_theta(dim: usize, seq_len: usize, base: f32, offset: usize) -> Vec { let half_dim = dim / 2; - println!("Half dim: {}", half_dim); let positions = (offset..seq_len + offset) .map(|x| x as f32) .collect::>(); - println!("Positions: {:?}", positions); - - let log_base = base.ln(); - - println!("Log base: {}", log_base); - let inv_freqs = (0..half_dim) .map(|i| -(i as f32)) - .map(|i| i * log_base / half_dim as f32) + .map(|i| i * base.ln() / half_dim as f32) .map(f32::exp) .collect::>(); - println!("Inverse Frequencies: {:?}", inv_freqs); - let p_shape = shape!(seq_len, 1); let p_strides = Strides::from(&p_shape); let i_shape = shape!(1, half_dim); @@ -62,10 +53,7 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve ) .unwrap(); - println!("THETA: {:?}", theta); - - let (sin_theta, cos_theta) = theta.iter().map(|i| i.sin_cos()).unzip(); - (sin_theta, cos_theta) + theta } #[inline] @@ -190,7 +178,8 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V let el_count = batches * num_heads * seq_len * head_dim; let half_dim = dim / 2; - let (sin, cos) = calculate_sincos(dim, seq_len, base, offset); + let theta = compute_theta(dim, seq_len, base, offset); + let (sin, cos): (Vec, Vec) = theta.iter().map(|i| i.sin_cos()).unzip(); let mut intermediate = Vec::with_capacity(el_count); @@ -219,8 +208,6 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V intermediate.push(x1_sin + x2_cos); }); - let out_shape = shape!(batches, num_heads, seq_len, head_dim); - let skip = head_dim.abs_diff(dim); let mut dst = merge(&intermediate, half_dim, skip); @@ -237,15 +224,14 @@ fn rope_2(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); let el_count = batches * num_heads * seq_len * head_dim; - let half_dim = dim / 2; - let (sin, cos) = calculate_sincos(dim, seq_len, base, offset); + let theta = compute_theta(dim, seq_len, base, offset); + let (sin, cos): (Vec, Vec) = theta.iter().map(|i| i.sin_cos()).unzip(); println!("cos: {:?}", cos); println!("sin: {:?}", sin); let src = transpose(src, &shape, 1, 2).unwrap(); let mut dst = vec![0.0; el_count]; - let b = batches; let t = num_heads; let h = seq_len; let d = head_dim; @@ -265,7 +251,5 @@ fn rope_2(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> } }); - let dst = transpose(dst, &shape, 1, 2).unwrap(); - - dst + transpose(dst, &shape, 1, 2).unwrap() } From 82435ebb2c28beccf274f5b9e2512766aa375e3a Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Thu, 3 Oct 2024 16:26:15 +0200 Subject: [PATCH 18/32] chore: R1 and R2 match --- crates/ratchet-core/src/cpu/gemm.rs | 2 +- crates/ratchet-core/src/cpu/mod.rs | 16 ++-- crates/ratchet-core/src/cpu/rope.rs | 112 +++++++++++++++------------ crates/ratchet-core/src/cpu/slice.rs | 3 + crates/ratchet-core/src/op.rs | 2 +- crates/ratchet-core/src/tensor.rs | 2 +- 6 files changed, 76 insertions(+), 61 deletions(-) create mode 100644 crates/ratchet-core/src/cpu/slice.rs diff --git a/crates/ratchet-core/src/cpu/gemm.rs b/crates/ratchet-core/src/cpu/gemm.rs index fdba0622..b932b37b 100644 --- a/crates/ratchet-core/src/cpu/gemm.rs +++ b/crates/ratchet-core/src/cpu/gemm.rs @@ -156,7 +156,7 @@ fn gemm_impl( } impl CPUOperation for Matmul { - fn apply(&self, dst: Tensor) -> Result { + fn apply_cpu(&self, dst: Tensor) -> Result { fn run_gemm( spec: MatmulSpec, lhs: &Tensor, diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs index 6d5182ff..3646373a 100644 --- a/crates/ratchet-core/src/cpu/mod.rs +++ b/crates/ratchet-core/src/cpu/mod.rs @@ -168,7 +168,7 @@ macro_rules! impl_cpu_unary { impl_cpu_unary_wrapper!($dtype, $conv); impl CPUOperation for CPU<$dtype, Unary> { - fn apply(&self, dst: Tensor) -> Result { + fn apply_cpu(&self, dst: Tensor) -> Result { match self.op.op() { UnaryOp::Gelu => Self::gelu(self.op.input(), dst), UnaryOp::Tanh => Self::tanh(self.op.input(), dst), @@ -196,9 +196,9 @@ impl_cpu_unary!(bf16, bf16::from_f32); pub fn cpu_unary(unary: Unary, dst: Tensor) -> Result { match dst.dt() { - DType::F32 => CPU::::new(unary).apply(dst), - DType::F16 => CPU::::new(unary).apply(dst), - DType::BF16 => CPU::::new(unary).apply(dst), + DType::F32 => CPU::::new(unary).apply_cpu(dst), + DType::F16 => CPU::::new(unary).apply_cpu(dst), + DType::BF16 => CPU::::new(unary).apply_cpu(dst), _ => todo!(), } } @@ -222,7 +222,7 @@ macro_rules! impl_cpu_binary { } impl CPUOperation for CPU<$dtype, Binary> { - fn apply(&self, dst: Tensor) -> Result { + fn apply_cpu(&self, dst: Tensor) -> Result { match self.op.op() { BinaryOp::Add => Self::add(self.op.lhs(), self.op.rhs(), dst), BinaryOp::Sub => Self::sub(self.op.lhs(), self.op.rhs(), dst), @@ -240,9 +240,9 @@ impl_cpu_binary!(bf16); pub fn cpu_binary(binary: Binary, dst: Tensor) -> Result { match dst.dt() { - DType::F32 => CPU::::new(binary).apply(dst), - DType::F16 => CPU::::new(binary).apply(dst), - DType::BF16 => CPU::::new(binary).apply(dst), + DType::F32 => CPU::::new(binary).apply_cpu(dst), + DType::F16 => CPU::::new(binary).apply_cpu(dst), + DType::BF16 => CPU::::new(binary).apply_cpu(dst), _ => todo!(), } } diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index c72b1d0f..8021a8a3 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -105,30 +105,29 @@ fn merge(data: &[f32], offset: usize, skip: usize) -> Vec { } fn slice(src: &[f32], start: &[usize], stop: &[usize]) -> Vec { - let stop_numel: usize = stop.iter().product(); - let start_numel: usize = stop.iter().product(); - assert!(stop_numel >= start_numel); - - let mut dst = vec![0.0; stop_numel - start_numel]; - - /* - start: [0, 0, 0, 8] - stop: [1, 1, 1, 16] - for - */ - - let mut src_idx = 0; - let mut dst_idx = 0; - for i in 0..start.len() { - let mut src_stride = start[i]; - let mut dst_stride = 0; - while src_stride < stop[i] { - dst[dst_idx] = src[src_idx]; - src_idx += src_stride; - dst_idx += dst_stride; - src_stride += 1; - dst_stride += 1; + assert!(start.len() == stop.len()); + start.iter().zip(stop.iter()).for_each(|(s, t)| { + assert!(s < t); + }); + + let src_shape = [2, 16, 16]; // Corrected input shape + let src_strides = [16 * 16, 16, 1]; + + let delta: Vec = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect(); + let dst_shape: Vec = delta.clone(); + let dst_numel: usize = delta.iter().product(); + + let mut dst = vec![0.0; dst_numel]; + + for i in 0..dst_numel { + let mut src_index = 0; + let mut tmp = i; + for d in 0..delta.len() { + let coord = tmp / dst_shape[d + 1..].iter().product::().max(1); + tmp %= dst_shape[d + 1..].iter().product::().max(1); + src_index += (coord + start[d]) * src_strides[d]; } + dst[i] = src[src_index]; } dst @@ -175,48 +174,61 @@ fn transpose( fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { println!("Ratchet RoPE"); let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); - let el_count = batches * num_heads * seq_len * head_dim; let half_dim = dim / 2; let theta = compute_theta(dim, seq_len, base, offset); + println!("Theta: {:?}", theta); let (sin, cos): (Vec, Vec) = theta.iter().map(|i| i.sin_cos()).unzip(); + println!("Cos: {:?}", cos); + println!("Sin: {:?}", sin); - let mut intermediate = Vec::with_capacity(el_count); + println!("Cos length: {:?}", cos.len()); + println!("Sin length: {:?}", sin.len()); - let chunk_offset = half_dim; - let skip = 0; + let x1 = slice(&src, &[0, 0, 0], &[num_heads, seq_len, half_dim]); + let x2 = slice(&src, &[0, 0, half_dim], &[num_heads, seq_len, dim]); + println!("X1: {:?}", x1); + println!("X1 length: {:?}", x1.len()); + println!("X2: {:?}", x2); + println!("X2 length: {:?}", x2.len()); - let (x1, x2) = chunk_by_offset(&src, chunk_offset, skip); - - let (x1_cos, x1_sin): (Vec, Vec) = x1 + let x1_cos = x1 .iter() .enumerate() - .map(|(i, x)| (x * cos[i % cos.len()], x * sin[i % sin.len()])) - .unzip(); - - let (x2_cos, x2_sin): (Vec, Vec) = x2 + .map(|(i, x)| x * cos[i % cos.len()]) + .collect::>(); + let x2_sin = x2 .iter() .enumerate() - .map(|(i, x)| (x * cos[i % cos.len()], x * sin[i % sin.len()])) - .unzip(); + .map(|(i, x)| x * sin[i % sin.len()]) + .collect::>(); - x1_cos.iter().zip(x2_sin).for_each(|(x1_cos, x2_sin)| { - intermediate.push(x1_cos - x2_sin); - }); + let r1 = x1_cos + .iter() + .zip(x2_sin.iter()) + .map(|(x1, x2)| x1 - x2) + .collect::>(); - x1_sin.iter().zip(x2_cos).for_each(|(x1_sin, x2_cos)| { - intermediate.push(x1_sin + x2_cos); - }); + let x1_sin = x1 + .iter() + .enumerate() + .map(|(i, x)| x * sin[i % sin.len()]) + .collect::>(); + let x2_cos = x2 + .iter() + .enumerate() + .map(|(i, x)| x * cos[i % cos.len()]) + .collect::>(); + let r2 = x1_sin + .iter() + .zip(x2_cos.iter()) + .map(|(x1, x2)| x1 + x2) + .collect::>(); - let skip = head_dim.abs_diff(dim); - let mut dst = merge(&intermediate, half_dim, skip); + println!("R1: {:?}", r1); + println!("R2: {:?}", r2); - if dim < head_dim { - let offset = (el_count / head_dim) * dim; - let appendix = &mut src[offset..].to_vec(); - dst.append(appendix); - } - dst + vec![] } fn rope_2(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { diff --git a/crates/ratchet-core/src/cpu/slice.rs b/crates/ratchet-core/src/cpu/slice.rs new file mode 100644 index 00000000..0b1c2b4e --- /dev/null +++ b/crates/ratchet-core/src/cpu/slice.rs @@ -0,0 +1,3 @@ +use crate::{Slice, Tensor}; + +pub fn cpu_slice(op: Slice, dst: Tensor) -> Result {} diff --git a/crates/ratchet-core/src/op.rs b/crates/ratchet-core/src/op.rs index 10bb5ab4..d3598fa4 100644 --- a/crates/ratchet-core/src/op.rs +++ b/crates/ratchet-core/src/op.rs @@ -363,5 +363,5 @@ pub trait GPUOperation: Operation { } pub trait CPUOperation: Operation { - fn apply(&self, dst: Tensor) -> Result; + fn apply_cpu(&self, dst: Tensor) -> Result; } diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index 5b8c108e..99cc5917 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -759,7 +759,7 @@ impl Tensor { match self.op().clone() { LazyOp::Binary(b) => cpu_binary(b, dst).ok(), LazyOp::Cast(c) => cpu_cast(c, dst).ok(), - LazyOp::Matmul(m) => m.apply(dst).ok(), + LazyOp::Matmul(m) => m.apply_cpu(dst).ok(), LazyOp::Softmax(_s) => todo!(), LazyOp::RoPE(r) => cpu_rope(r, dst).ok(), LazyOp::Unary(u) => cpu_unary(u, dst).ok(), From ca5f5a79c4cba2a0b7e43b1fee487f97867e1290 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Thu, 3 Oct 2024 16:53:26 +0200 Subject: [PATCH 19/32] chore: cleaning --- crates/ratchet-core/src/cpu/rope.rs | 37 ++--------------------------- 1 file changed, 2 insertions(+), 35 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 8021a8a3..f7c4271e 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -228,40 +228,7 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V println!("R1: {:?}", r1); println!("R2: {:?}", r2); - vec![] -} - -fn rope_2(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { - println!("Ratchet RoPE"); - let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); - let el_count = batches * num_heads * seq_len * head_dim; + if dim < shape[3] {} - let theta = compute_theta(dim, seq_len, base, offset); - let (sin, cos): (Vec, Vec) = theta.iter().map(|i| i.sin_cos()).unzip(); - - println!("cos: {:?}", cos); - println!("sin: {:?}", sin); - - let src = transpose(src, &shape, 1, 2).unwrap(); - let mut dst = vec![0.0; el_count]; - let t = num_heads; - let h = seq_len; - let d = head_dim; - src.chunks(t * h * d) - .zip(dst.chunks_mut(t * h * d)) - .for_each(|(src, dst)| { - for i_t in 0..t { - for i_d in 0..d / 2 { - let i_cs = i_t * (d / 2) + i_d; - for i_h in 0..h { - let i1 = i_t * h * d + i_h * d + i_d; - let i2 = i1 + d / 2; - dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; - dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; - } - } - } - }); - - transpose(dst, &shape, 1, 2).unwrap() + vec![] } From 88e7c07d94ef835de0de81684770151011183b59 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Fri, 4 Oct 2024 11:14:56 +0200 Subject: [PATCH 20/32] chore: RoPE works but is shit --- crates/ratchet-core/src/cpu/rope.rs | 72 +++++++++-------------------- 1 file changed, 22 insertions(+), 50 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index f7c4271e..9ecb0823 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -1,4 +1,5 @@ use crate::{ + concat, cpu::{cpu_store_result, gemm::gemm}, shape, DType, OperationError, RoPE, Shape, StridedIterator, Strides, Tensor, }; @@ -56,54 +57,6 @@ fn compute_theta(dim: usize, seq_len: usize, base: f32, offset: usize) -> Vec (Vec, Vec) { - let mut x1 = Vec::with_capacity(data.len() / 2); - let mut x2 = Vec::with_capacity(data.len() / 2); - - let mut start = 0; - let mut stop = offset; - while stop < data.len() { - let mut chunk = data[start..stop].to_vec(); - x1.append(&mut chunk); - start += offset; - stop += offset; - - let mut chunk = data[start..stop].to_vec(); - x2.append(&mut chunk); - start += offset; - stop += offset; - - start += skip; - stop += skip; - } - (x1.to_vec(), x2.to_vec()) -} - -#[inline] -fn merge(data: &[f32], offset: usize, skip: usize) -> Vec { - let n = data.len(); - let mid = n / 2; - let mut interleaved = Vec::with_capacity(n); - - let mut start = 0; - let mut stop = offset; - while stop + mid <= n { - let mut chunk = data[start..stop].to_vec(); - interleaved.append(&mut chunk); - - let mut chunk = data[start + mid..stop + mid].to_vec(); - interleaved.append(&mut chunk); - - start += offset; - stop += offset; - - start += skip; - stop += skip; - } - interleaved -} - fn slice(src: &[f32], start: &[usize], stop: &[usize]) -> Vec { assert!(start.len() == stop.len()); start.iter().zip(stop.iter()).for_each(|(s, t)| { @@ -192,6 +145,8 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V println!("X2: {:?}", x2); println!("X2 length: {:?}", x2.len()); + //zip and repeat + //`multiply` as an operation that deals with broadcasting let x1_cos = x1 .iter() .enumerate() @@ -203,11 +158,13 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V .map(|(i, x)| x * sin[i % sin.len()]) .collect::>(); + let mut outs = vec![]; let r1 = x1_cos .iter() .zip(x2_sin.iter()) .map(|(x1, x2)| x1 - x2) .collect::>(); + outs.push(r1.clone()); let x1_sin = x1 .iter() @@ -224,11 +181,26 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V .zip(x2_cos.iter()) .map(|(x1, x2)| x1 + x2) .collect::>(); + outs.push(r2.clone()); println!("R1: {:?}", r1); println!("R2: {:?}", r2); - if dim < shape[3] {} + if dim < shape[3] { + //outs.push_back(slice(x, {0, 0, dims}, x.shape(), s)); + outs.push(slice(&src, &[0, 0, dim], &[num_heads, seq_len, head_dim])); + } - vec![] + let (o0, o1, o2) = (outs[0].clone(), outs[1].clone(), outs[2].clone()); + + let to_cat = [ + (&shape![num_heads, seq_len, half_dim], o0), + (&shape![num_heads, seq_len, half_dim], o1), + (&shape![num_heads, seq_len, head_dim - dim], o2), + ]; + + let dst_shape = shape![num_heads, seq_len, head_dim]; + let mut dst = vec![0.0f32; dst_shape.numel()]; + concat(to_cat.as_slice(), 2, &dst_shape, &mut dst).unwrap(); + dst } From 4d63692d94a308f22b10f0e878a25fd2b5e7c64b Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Fri, 4 Oct 2024 11:33:12 +0200 Subject: [PATCH 21/32] chore: RoPE doesn't work --- crates/ratchet-core/src/cpu/rope.rs | 1 - crates/ratchet-core/src/ops/rope.rs | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 9ecb0823..1da86e6c 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -187,7 +187,6 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V println!("R2: {:?}", r2); if dim < shape[3] { - //outs.push_back(slice(x, {0, 0, dims}, x.shape(), s)); outs.push(slice(&src, &[0, 0, dim], &[num_heads, seq_len, head_dim])); } diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index cd988c6a..fc71ee9c 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -377,10 +377,10 @@ def mlx_rope(input, dim, offset): fn debug_rope_cpu() { let prob = RoPEProblem { BS: 1, - NH: 2, - SL: 16, - HD: 16, - dim: 8, + NH: 1, + SL: 2, + HD: 128, + dim: 96, offset: 0, }; println!("{prob:?}"); From 1d93205cb9721d3fd5e30734ee3b73e29360f55e Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Fri, 4 Oct 2024 14:18:35 +0200 Subject: [PATCH 22/32] chore: not quite right --- crates/ratchet-core/src/cpu/rope.rs | 32 +++++++++++++++++++++-------- crates/ratchet-core/src/ops/rope.rs | 9 ++++---- crates/ratchet-core/src/strides.rs | 4 ++++ 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 1da86e6c..6078a897 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -57,15 +57,13 @@ fn compute_theta(dim: usize, seq_len: usize, base: f32, offset: usize) -> Vec Vec { +fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) -> Vec { assert!(start.len() == stop.len()); + assert!(start.len() == src_strides.rank()); start.iter().zip(stop.iter()).for_each(|(s, t)| { assert!(s < t); }); - let src_shape = [2, 16, 16]; // Corrected input shape - let src_strides = [16 * 16, 16, 1]; - let delta: Vec = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect(); let dst_shape: Vec = delta.clone(); let dst_numel: usize = delta.iter().product(); @@ -78,7 +76,7 @@ fn slice(src: &[f32], start: &[usize], stop: &[usize]) -> Vec { for d in 0..delta.len() { let coord = tmp / dst_shape[d + 1..].iter().product::().max(1); tmp %= dst_shape[d + 1..].iter().product::().max(1); - src_index += (coord + start[d]) * src_strides[d]; + src_index += (coord + start[d]) * src_strides[d] as usize; } dst[i] = src[src_index]; } @@ -138,8 +136,20 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V println!("Cos length: {:?}", cos.len()); println!("Sin length: {:?}", sin.len()); - let x1 = slice(&src, &[0, 0, 0], &[num_heads, seq_len, half_dim]); - let x2 = slice(&src, &[0, 0, half_dim], &[num_heads, seq_len, dim]); + println!("HALF DIM: {:?}", half_dim); + let src_strides = Strides::from(shape); + let x1 = slice( + &src, + &src_strides, + &[0, 0, 0, 0], + &[batches, num_heads, seq_len, half_dim], + ); + let x2 = slice( + &src, + &src_strides, + &[0, 0, 0, half_dim], + &[batches, num_heads, seq_len, dim], + ); println!("X1: {:?}", x1); println!("X1 length: {:?}", x1.len()); println!("X2: {:?}", x2); @@ -187,7 +197,12 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V println!("R2: {:?}", r2); if dim < shape[3] { - outs.push(slice(&src, &[0, 0, dim], &[num_heads, seq_len, head_dim])); + outs.push(slice( + &src, + &src_strides, + &[0, 0, 0, dim], + &[batches, num_heads, seq_len, head_dim], + )); } let (o0, o1, o2) = (outs[0].clone(), outs[1].clone(), outs[2].clone()); @@ -201,5 +216,6 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V let dst_shape = shape![num_heads, seq_len, head_dim]; let mut dst = vec![0.0f32; dst_shape.numel()]; concat(to_cat.as_slice(), 2, &dst_shape, &mut dst).unwrap(); + println!("CONCAT: {:?}", dst); dst } diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index fc71ee9c..2a13370b 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -302,7 +302,7 @@ def mlx_rope(input, dim, offset): } = problem; let shape = shape![BS, NH, SL, HD]; let n = shape.numel(); - let data = (0..n).map(|x| x as f32).collect::>(); + let data = (0..n).map(|x| x as f32 / 100.).collect::>(); let a = Tensor::from_data(data, shape, Device::CPU); println!("Input tensor: {:?}", a); let ground = ground_truth(&a, dim, offset).unwrap(); @@ -314,7 +314,7 @@ def mlx_rope(input, dim, offset): println!("ours = \n{:#?}\n", ours.to_ndarray_view::()); println!("ground = \n{:#?}", ground.to_ndarray_view::()); //Weak tolerance because of `ffast-math` - ground.all_close(&ours, 1e-3, 1e-3).unwrap(); + ground.all_close(&ours, 1e-2, 1e-2).unwrap(); } #[derive(Arbitrary, Debug)] @@ -335,7 +335,7 @@ def mlx_rope(input, dim, offset): offset: usize, } - #[proptest(cases = 16)] + #[proptest(cases = 8)] fn test_rope_gpu(prob: RoPEProblem) { let RoPEProblem { BS, @@ -362,8 +362,9 @@ def mlx_rope(input, dim, offset): SL, HD, dim, - offset, + mut offset, } = prob; + offset = 0; println!( "BS = {}, NH = {}, SL = {}, HD = {}, rope_dim = {}, offset = {}", BS, NH, SL, HD, dim, offset diff --git a/crates/ratchet-core/src/strides.rs b/crates/ratchet-core/src/strides.rs index 0762f6c4..11920ae0 100644 --- a/crates/ratchet-core/src/strides.rs +++ b/crates/ratchet-core/src/strides.rs @@ -25,6 +25,10 @@ impl Strides { } self.0.swap(rank - 2, rank - 1); } + + pub fn rank(&self) -> usize { + self.0.len() + } } impl std::fmt::Debug for Strides { From 572e7d19e4dee59c54ef6041d365faa86c46d0aa Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:53:12 +0200 Subject: [PATCH 23/32] chore: rope concat dynamic outs length --- crates/ratchet-core/src/cpu/mod.rs | 4 ++-- crates/ratchet-core/src/cpu/rope.rs | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs index 3646373a..bdb63d21 100644 --- a/crates/ratchet-core/src/cpu/mod.rs +++ b/crates/ratchet-core/src/cpu/mod.rs @@ -362,7 +362,7 @@ pub fn cpu_cast(cast: Cast, dst: Tensor) -> Result { } pub(crate) fn concat( - inputs: &[(&Shape, Vec)], + inputs: &[(Shape, Vec)], dim: usize, dst_shape: &Shape, dst: &mut [T], @@ -398,7 +398,7 @@ pub(crate) fn apply_concat( let inputs = inputs .iter() .map(|t| match t.to_vec::() { - Ok(v) => Ok((t.shape(), v)), + Ok(v) => Ok((t.shape().clone(), v)), Err(e) => Err(e.into()), }) .collect::, OperationError>>(); diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 6078a897..52bcfe42 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -1,3 +1,5 @@ +use std::borrow::Borrow; + use crate::{ concat, cpu::{cpu_store_result, gemm::gemm}, @@ -196,6 +198,7 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V println!("R1: {:?}", r1); println!("R2: {:?}", r2); + let mut to_cat = vec![]; if dim < shape[3] { outs.push(slice( &src, @@ -204,14 +207,13 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V &[batches, num_heads, seq_len, head_dim], )); } - - let (o0, o1, o2) = (outs[0].clone(), outs[1].clone(), outs[2].clone()); - - let to_cat = [ - (&shape![num_heads, seq_len, half_dim], o0), - (&shape![num_heads, seq_len, half_dim], o1), - (&shape![num_heads, seq_len, head_dim - dim], o2), - ]; + for i in 0..outs.len() - 1 { + to_cat.push((shape![num_heads, seq_len, half_dim], outs[i].clone())); + } + to_cat.push(( + shape![num_heads, seq_len, head_dim - dim], + outs[outs.len() - 1].clone(), + )); let dst_shape = shape![num_heads, seq_len, head_dim]; let mut dst = vec![0.0f32; dst_shape.numel()]; From ce991baea8147af4fb5da57507185e7f94c1d385 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:56:58 +0200 Subject: [PATCH 24/32] chore: simplify rope concat --- crates/ratchet-core/src/cpu/rope.rs | 13 +++++-------- crates/ratchet-core/src/ops/rope.rs | 6 +++--- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 52bcfe42..cf9814f5 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -198,7 +198,10 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V println!("R1: {:?}", r1); println!("R2: {:?}", r2); - let mut to_cat = vec![]; + let mut to_cat = vec![ + (shape![num_heads, seq_len, half_dim], outs[0].clone()), + (shape![num_heads, seq_len, half_dim], outs[1].clone()), + ]; if dim < shape[3] { outs.push(slice( &src, @@ -206,14 +209,8 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V &[0, 0, 0, dim], &[batches, num_heads, seq_len, head_dim], )); + to_cat.push((shape![num_heads, seq_len, head_dim - dim], outs[2].clone())); } - for i in 0..outs.len() - 1 { - to_cat.push((shape![num_heads, seq_len, half_dim], outs[i].clone())); - } - to_cat.push(( - shape![num_heads, seq_len, head_dim - dim], - outs[outs.len() - 1].clone(), - )); let dst_shape = shape![num_heads, seq_len, head_dim]; let mut dst = vec![0.0f32; dst_shape.numel()]; diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 2a13370b..134db03d 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -379,9 +379,9 @@ def mlx_rope(input, dim, offset): let prob = RoPEProblem { BS: 1, NH: 1, - SL: 2, - HD: 128, - dim: 96, + SL: 1, + HD: 32, + dim: 32, offset: 0, }; println!("{prob:?}"); From 67a40c9951e0d84a71dab024603e61195e595143 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:31:31 +0200 Subject: [PATCH 25/32] chore: padding r1/r2 with 0s works. Not optimal --- crates/ratchet-core/src/cpu/mod.rs | 3 +- crates/ratchet-core/src/cpu/rope.rs | 51 ++++++++++++++++++----------- crates/ratchet-core/src/ops/rope.rs | 10 +++--- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs index bdb63d21..00a05545 100644 --- a/crates/ratchet-core/src/cpu/mod.rs +++ b/crates/ratchet-core/src/cpu/mod.rs @@ -375,7 +375,6 @@ pub(crate) fn concat( for (src_s, src) in inputs { let a_dim: usize = src_s.iter().take(dim).product(); let b_dim = block * src_s[dim]; - for idx in 0..a_dim { let dst_idx = idx * dst_s + dst_o; let src_idx = idx * b_dim + src_o; @@ -403,7 +402,7 @@ pub(crate) fn apply_concat( }) .collect::, OperationError>>(); - concat::(&inputs?, dim, dst.shape(), &mut result)?; + concat(&inputs?, dim, dst.shape(), &mut result)?; cpu_store_result(&dst, &result); Ok(dst) } diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index cf9814f5..ecdafd77 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -1,5 +1,3 @@ -use std::borrow::Borrow; - use crate::{ concat, cpu::{cpu_store_result, gemm::gemm}, @@ -125,20 +123,20 @@ fn transpose( } fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { - println!("Ratchet RoPE"); + //println!("Ratchet RoPE"); let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); let half_dim = dim / 2; let theta = compute_theta(dim, seq_len, base, offset); - println!("Theta: {:?}", theta); + //println!("Theta: {:?}", theta); let (sin, cos): (Vec, Vec) = theta.iter().map(|i| i.sin_cos()).unzip(); - println!("Cos: {:?}", cos); - println!("Sin: {:?}", sin); + //println!("Cos: {:?}", cos); + //println!("Sin: {:?}", sin); - println!("Cos length: {:?}", cos.len()); - println!("Sin length: {:?}", sin.len()); + //println!("Cos length: {:?}", cos.len()); + //println!("Sin length: {:?}", sin.len()); - println!("HALF DIM: {:?}", half_dim); + //println!("HALF DIM: {:?}", half_dim); let src_strides = Strides::from(shape); let x1 = slice( &src, @@ -152,10 +150,12 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V &[0, 0, 0, half_dim], &[batches, num_heads, seq_len, dim], ); + /* println!("X1: {:?}", x1); println!("X1 length: {:?}", x1.len()); println!("X2: {:?}", x2); println!("X2 length: {:?}", x2.len()); + */ //zip and repeat //`multiply` as an operation that deals with broadcasting @@ -171,11 +171,12 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V .collect::>(); let mut outs = vec![]; - let r1 = x1_cos + let mut r1 = x1_cos .iter() .zip(x2_sin.iter()) .map(|(x1, x2)| x1 - x2) .collect::>(); + r1.extend(vec![0.0; shape.numel() - r1.len()]); outs.push(r1.clone()); let x1_sin = x1 @@ -188,19 +189,26 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V .enumerate() .map(|(i, x)| x * cos[i % cos.len()]) .collect::>(); - let r2 = x1_sin + let mut r2 = x1_sin .iter() .zip(x2_cos.iter()) .map(|(x1, x2)| x1 + x2) .collect::>(); + r2.extend(vec![0.0; shape.numel() - r2.len()]); outs.push(r2.clone()); - println!("R1: {:?}", r1); - println!("R2: {:?}", r2); + //println!("R1: {:?}", r1); + //println!("R2: {:?}", r2); let mut to_cat = vec![ - (shape![num_heads, seq_len, half_dim], outs[0].clone()), - (shape![num_heads, seq_len, half_dim], outs[1].clone()), + ( + shape![batches, num_heads, seq_len, half_dim], + outs[0].clone(), + ), + ( + shape![batches, num_heads, seq_len, half_dim], + outs[1].clone(), + ), ]; if dim < shape[3] { outs.push(slice( @@ -209,12 +217,17 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V &[0, 0, 0, dim], &[batches, num_heads, seq_len, head_dim], )); - to_cat.push((shape![num_heads, seq_len, head_dim - dim], outs[2].clone())); + to_cat.push(( + shape![batches, num_heads, seq_len, head_dim - dim], + outs[2].clone(), + )); } - let dst_shape = shape![num_heads, seq_len, head_dim]; + let dst_shape = shape![batches, num_heads, seq_len, head_dim]; let mut dst = vec![0.0f32; dst_shape.numel()]; - concat(to_cat.as_slice(), 2, &dst_shape, &mut dst).unwrap(); - println!("CONCAT: {:?}", dst); + //println!("TO CONCAT size: {:?}", dst_shape.numel()); + //println!("TO CONCAT: {:?}", to_cat); + concat(to_cat.as_slice(), 3, &dst_shape, &mut dst).unwrap(); + //println!("CONCAT: {:?}", dst); dst } diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 134db03d..265e6e69 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -377,12 +377,12 @@ def mlx_rope(input, dim, offset): #[test] fn debug_rope_cpu() { let prob = RoPEProblem { - BS: 1, - NH: 1, - SL: 1, + BS: 2, + NH: 16, + SL: 128, HD: 32, - dim: 32, - offset: 0, + dim: 16, + offset: 8, }; println!("{prob:?}"); From c932fd5ef5b81b371e9bd2ec48709089f6da2856 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sun, 20 Oct 2024 13:10:46 +0200 Subject: [PATCH 26/32] chore: use randn in rope test to avoid precision issues --- crates/ratchet-core/src/cpu/rope.rs | 21 --------------------- crates/ratchet-core/src/ops/rope.rs | 19 ++++++++----------- 2 files changed, 8 insertions(+), 32 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index ecdafd77..b6a590ae 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -123,20 +123,11 @@ fn transpose( } fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { - //println!("Ratchet RoPE"); let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); let half_dim = dim / 2; let theta = compute_theta(dim, seq_len, base, offset); - //println!("Theta: {:?}", theta); let (sin, cos): (Vec, Vec) = theta.iter().map(|i| i.sin_cos()).unzip(); - //println!("Cos: {:?}", cos); - //println!("Sin: {:?}", sin); - - //println!("Cos length: {:?}", cos.len()); - //println!("Sin length: {:?}", sin.len()); - - //println!("HALF DIM: {:?}", half_dim); let src_strides = Strides::from(shape); let x1 = slice( &src, @@ -150,12 +141,6 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V &[0, 0, 0, half_dim], &[batches, num_heads, seq_len, dim], ); - /* - println!("X1: {:?}", x1); - println!("X1 length: {:?}", x1.len()); - println!("X2: {:?}", x2); - println!("X2 length: {:?}", x2.len()); - */ //zip and repeat //`multiply` as an operation that deals with broadcasting @@ -197,9 +182,6 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V r2.extend(vec![0.0; shape.numel() - r2.len()]); outs.push(r2.clone()); - //println!("R1: {:?}", r1); - //println!("R2: {:?}", r2); - let mut to_cat = vec![ ( shape![batches, num_heads, seq_len, half_dim], @@ -225,9 +207,6 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V let dst_shape = shape![batches, num_heads, seq_len, head_dim]; let mut dst = vec![0.0f32; dst_shape.numel()]; - //println!("TO CONCAT size: {:?}", dst_shape.numel()); - //println!("TO CONCAT: {:?}", to_cat); concat(to_cat.as_slice(), 3, &dst_shape, &mut dst).unwrap(); - //println!("CONCAT: {:?}", dst); dst } diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 265e6e69..4fed9c05 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -271,7 +271,7 @@ mod tests { use test_strategy::{proptest, Arbitrary}; use crate::test_util::run_py_prg; - use crate::{gemm, shape, Device, DeviceRequest, Shape, Strides, Tensor}; + use crate::{shape, Device, DeviceRequest, Tensor}; fn ground_truth(a: &Tensor, dim: usize, offset: usize) -> anyhow::Result { let prg = r#" @@ -301,10 +301,7 @@ def mlx_rope(input, dim, offset): offset, } = problem; let shape = shape![BS, NH, SL, HD]; - let n = shape.numel(); - let data = (0..n).map(|x| x as f32 / 100.).collect::>(); - let a = Tensor::from_data(data, shape, Device::CPU); - println!("Input tensor: {:?}", a); + let a = Tensor::randn::(shape, Device::CPU); let ground = ground_truth(&a, dim, offset).unwrap(); let a = a.to(&device).unwrap(); @@ -377,12 +374,12 @@ def mlx_rope(input, dim, offset): #[test] fn debug_rope_cpu() { let prob = RoPEProblem { - BS: 2, - NH: 16, - SL: 128, - HD: 32, - dim: 16, - offset: 8, + BS: 1, + NH: 5, + SL: 180, + HD: 112, + dim: 96, + offset: 141, }; println!("{prob:?}"); From 022db5ae228ebe939bdd6e50ce76a126e87d9a7e Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sun, 20 Oct 2024 13:20:57 +0200 Subject: [PATCH 27/32] chore: remove redundant "outs" vec --- crates/ratchet-core/src/cpu/rope.rs | 29 ++++++++--------------------- crates/ratchet-core/src/ops/rope.rs | 4 ++-- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index b6a590ae..58f3d0cd 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -64,16 +64,15 @@ fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) -> assert!(s < t); }); - let delta: Vec = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect(); - let dst_shape: Vec = delta.clone(); - let dst_numel: usize = delta.iter().product(); + let dst_shape: Vec = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect(); + let dst_numel: usize = dst_shape.iter().product(); let mut dst = vec![0.0; dst_numel]; for i in 0..dst_numel { let mut src_index = 0; let mut tmp = i; - for d in 0..delta.len() { + for d in 0..dst_shape.len() { let coord = tmp / dst_shape[d + 1..].iter().product::().max(1); tmp %= dst_shape[d + 1..].iter().product::().max(1); src_index += (coord + start[d]) * src_strides[d] as usize; @@ -155,14 +154,12 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V .map(|(i, x)| x * sin[i % sin.len()]) .collect::>(); - let mut outs = vec![]; let mut r1 = x1_cos .iter() .zip(x2_sin.iter()) .map(|(x1, x2)| x1 - x2) .collect::>(); r1.extend(vec![0.0; shape.numel() - r1.len()]); - outs.push(r1.clone()); let x1_sin = x1 .iter() @@ -180,29 +177,19 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V .map(|(x1, x2)| x1 + x2) .collect::>(); r2.extend(vec![0.0; shape.numel() - r2.len()]); - outs.push(r2.clone()); let mut to_cat = vec![ - ( - shape![batches, num_heads, seq_len, half_dim], - outs[0].clone(), - ), - ( - shape![batches, num_heads, seq_len, half_dim], - outs[1].clone(), - ), + (shape![batches, num_heads, seq_len, half_dim], r1), + (shape![batches, num_heads, seq_len, half_dim], r2), ]; if dim < shape[3] { - outs.push(slice( + let r3 = slice( &src, &src_strides, &[0, 0, 0, dim], &[batches, num_heads, seq_len, head_dim], - )); - to_cat.push(( - shape![batches, num_heads, seq_len, head_dim - dim], - outs[2].clone(), - )); + ); + to_cat.push((shape![batches, num_heads, seq_len, head_dim - dim], r3)); } let dst_shape = shape![batches, num_heads, seq_len, head_dim]; diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 4fed9c05..49ab8a48 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -308,8 +308,8 @@ def mlx_rope(input, dim, offset): let b = a.rope(dim, 10000.0, offset).unwrap().resolve().unwrap(); let ours = b.to(&Device::CPU).unwrap(); - println!("ours = \n{:#?}\n", ours.to_ndarray_view::()); - println!("ground = \n{:#?}", ground.to_ndarray_view::()); + //println!("ours = \n{:#?}\n", ours.to_ndarray_view::()); + //println!("ground = \n{:#?}", ground.to_ndarray_view::()); //Weak tolerance because of `ffast-math` ground.all_close(&ours, 1e-2, 1e-2).unwrap(); } From 508b5edb90a8d151fc2f942635570044bc1ad320 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sun, 20 Oct 2024 13:30:35 +0200 Subject: [PATCH 28/32] chore: use iter cycle instead of % check --- crates/ratchet-core/src/cpu/rope.rs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 58f3d0cd..5908020b 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -141,17 +141,16 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V &[batches, num_heads, seq_len, dim], ); - //zip and repeat //`multiply` as an operation that deals with broadcasting let x1_cos = x1 .iter() - .enumerate() - .map(|(i, x)| x * cos[i % cos.len()]) + .zip(cos.iter().cycle()) + .map(|(x, c)| x * c) .collect::>(); let x2_sin = x2 .iter() - .enumerate() - .map(|(i, x)| x * sin[i % sin.len()]) + .zip(sin.iter().cycle()) + .map(|(x, s)| x * s) .collect::>(); let mut r1 = x1_cos @@ -163,13 +162,13 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V let x1_sin = x1 .iter() - .enumerate() - .map(|(i, x)| x * sin[i % sin.len()]) + .zip(sin.iter().cycle()) + .map(|(x, s)| x * s) .collect::>(); let x2_cos = x2 .iter() - .enumerate() - .map(|(i, x)| x * cos[i % cos.len()]) + .zip(cos.iter().cycle()) + .map(|(x, c)| x * c) .collect::>(); let mut r2 = x1_sin .iter() From be77442b1bf44f413e1b3ec0869c96011e43e00a Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 28 Oct 2024 21:57:55 +0100 Subject: [PATCH 29/32] chore: remove unused strided iterator. may be useful later --- crates/ratchet-core/src/cpu/mod.rs | 71 ----------------------------- crates/ratchet-core/src/cpu/rope.rs | 40 +--------------- 2 files changed, 1 insertion(+), 110 deletions(-) diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs index 80915891..7ed06c69 100644 --- a/crates/ratchet-core/src/cpu/mod.rs +++ b/crates/ratchet-core/src/cpu/mod.rs @@ -36,77 +36,6 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result { - shape: &'a Shape, - strides: &'a Strides, - next_index: Option, - multi_index: Vec, -} - -impl<'a> StridedIterator<'a> { - pub fn new(shape: &'a Shape, strides: &'a Strides, start_offset: usize) -> Self { - Self { - shape, - strides, - next_index: if shape.numel() == 0 { - None - } else { - Some(start_offset) - }, - multi_index: vec![0; shape.len()], - } - } -} - -impl<'a> Iterator for StridedIterator<'a> { - type Item = usize; - - fn next(&mut self) -> Option { - let storage_index = match self.next_index { - None => return None, - Some(storage_index) => storage_index, - }; - let mut updated = false; - let mut next_storage_index = storage_index; - for ((multi_i, max_i), stride_i) in self - .multi_index - .iter_mut() - .zip(self.shape.iter()) - .zip(self.strides.iter()) - .rev() - { - let next_i = *multi_i + 1; - if next_i < *max_i { - *multi_i = next_i; - updated = true; - next_storage_index += *stride_i as usize; - break; - } else { - next_storage_index -= *multi_i * *stride_i as usize; - *multi_i = 0 - } - } - self.next_index = if updated { - Some(next_storage_index) - } else { - None - }; - Some(storage_index) - } -} - -impl<'a> From<(&'a Shape, &'a Strides)> for StridedIterator<'a> { - fn from((shape, strides): (&'a Shape, &'a Strides)) -> Self { - StridedIterator::new(shape, strides, 0) - } -} - -impl<'a> From<(&'a Shape, &'a Strides, usize)> for StridedIterator<'a> { - fn from((shape, strides, offset): (&'a Shape, &'a Strides, usize)) -> Self { - StridedIterator::new(shape, strides, offset) - } -} - pub trait CPUOperation: Operation { fn apply_cpu(&self, dst: Tensor) -> Result; } diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 5908020b..bd818721 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -1,7 +1,7 @@ use crate::{ concat, cpu::{cpu_store_result, gemm::gemm}, - shape, DType, OperationError, RoPE, Shape, StridedIterator, Strides, Tensor, + shape, DType, OperationError, RoPE, Shape, Strides, Tensor, }; use anyhow::anyhow; @@ -83,44 +83,6 @@ fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) -> dst } -// Generic transpose function -fn transpose( - src: Vec, - shape: &Shape, - dim1: usize, - dim2: usize, -) -> Result, OperationError> { - let rank = shape.rank(); - if dim1 == dim2 { - return Ok(src); - } - if rank <= dim1 || rank <= dim2 { - return Err(anyhow!("Invalid dimensions for transpose operation").into()); - } - let mut dims = shape.to_vec(); - let mut strides = Strides::from(shape).to_vec(); - println!("dims: {:?}", dims); - println!("strides: {:?}", strides); - dims.swap(dim1, dim2); - strides.swap(dim1, dim2); - println!("dims: {:?}", dims); - println!("strides: {:?}", strides); - - let shape_t = Shape::from(dims); - let strides_t = Strides::from(strides); - - let mut result = vec![0.0; src.len()]; - let strided_iter = StridedIterator::new(&shape_t, &strides_t, 0); - let strided_iter2 = StridedIterator::new(&shape_t, &strides_t, 0); - let indices = strided_iter2.collect::>(); - println!("indices: {:?}", indices); - for (index, dst_index) in strided_iter.enumerate() { - result[dst_index] = src[index]; - } - - Ok(result) -} - fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); From 5a018ceb6a66e79a28114f53103e5b811b94015b Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 29 Oct 2024 09:55:59 +0100 Subject: [PATCH 30/32] chore: tidy up --- crates/ratchet-core/src/ops/rope.rs | 7 ++----- crates/ratchet-core/src/storage/cpu_buffer.rs | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/crates/ratchet-core/src/ops/rope.rs b/crates/ratchet-core/src/ops/rope.rs index 49ab8a48..00d1c2c1 100644 --- a/crates/ratchet-core/src/ops/rope.rs +++ b/crates/ratchet-core/src/ops/rope.rs @@ -279,8 +279,6 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -mx.set_default_device(mx.cpu) - def mlx_rope(input, dim, offset): rope = nn.RoPE(dim) mx_input = mx.array(input) @@ -332,7 +330,7 @@ def mlx_rope(input, dim, offset): offset: usize, } - #[proptest(cases = 8)] + #[proptest(cases = 16)] fn test_rope_gpu(prob: RoPEProblem) { let RoPEProblem { BS, @@ -359,9 +357,8 @@ def mlx_rope(input, dim, offset): SL, HD, dim, - mut offset, + offset, } = prob; - offset = 0; println!( "BS = {}, NH = {}, SL = {}, HD = {}, rope_dim = {}, offset = {}", BS, NH, SL, HD, dim, offset diff --git a/crates/ratchet-core/src/storage/cpu_buffer.rs b/crates/ratchet-core/src/storage/cpu_buffer.rs index 4e7fde9a..3be6f1ee 100644 --- a/crates/ratchet-core/src/storage/cpu_buffer.rs +++ b/crates/ratchet-core/src/storage/cpu_buffer.rs @@ -89,7 +89,7 @@ impl CPUBuffer { } pub fn from_slice(data: &[T], shape: &Shape) -> Self { - //assert_eq!(data.len(), shape.numel()); + assert_eq!(data.len(), shape.numel()); let bytes: &[u8] = bytemuck::cast_slice(data); Self::from_bytes(bytes, std::mem::align_of::()) } From 99a074ee58cbe6de3be6b7026277deedead87928 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 29 Oct 2024 10:01:11 +0100 Subject: [PATCH 31/32] chore: ? > unwrap --- crates/ratchet-core/src/cpu/rope.rs | 28 +++++++++++++------ crates/ratchet-core/src/storage/cpu_buffer.rs | 4 +-- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index bd818721..f4cfa2b4 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -12,7 +12,7 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { let base = op.base(); let offset = op.offset(); let src = op.input().to_vec::()?; - let result = rope(src, op.input().shape(), dim, base, offset); + let result = rope(src, op.input().shape(), dim, base, offset)?; cpu_store_result(&dst, &result) } _ => todo!(), @@ -21,7 +21,12 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result { Ok(dst) } -fn compute_theta(dim: usize, seq_len: usize, base: f32, offset: usize) -> Vec { +fn compute_theta( + dim: usize, + seq_len: usize, + base: f32, + offset: usize, +) -> Result, OperationError> { let half_dim = dim / 2; let positions = (offset..seq_len + offset) @@ -51,10 +56,9 @@ fn compute_theta(dim: usize, seq_len: usize, base: f32, offset: usize) -> Vec Vec { @@ -83,11 +87,17 @@ fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) -> dst } -fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec { +fn rope( + src: Vec, + shape: &Shape, + dim: usize, + base: f32, + offset: usize, +) -> Result, OperationError> { let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); let half_dim = dim / 2; - let theta = compute_theta(dim, seq_len, base, offset); + let theta = compute_theta(dim, seq_len, base, offset)?; let (sin, cos): (Vec, Vec) = theta.iter().map(|i| i.sin_cos()).unzip(); let src_strides = Strides::from(shape); let x1 = slice( @@ -155,6 +165,6 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V let dst_shape = shape![batches, num_heads, seq_len, head_dim]; let mut dst = vec![0.0f32; dst_shape.numel()]; - concat(to_cat.as_slice(), 3, &dst_shape, &mut dst).unwrap(); - dst + concat(to_cat.as_slice(), 3, &dst_shape, &mut dst)?; + Ok(dst) } diff --git a/crates/ratchet-core/src/storage/cpu_buffer.rs b/crates/ratchet-core/src/storage/cpu_buffer.rs index 3be6f1ee..466ca41c 100644 --- a/crates/ratchet-core/src/storage/cpu_buffer.rs +++ b/crates/ratchet-core/src/storage/cpu_buffer.rs @@ -1,12 +1,10 @@ use bytemuck::{NoUninit, Pod}; use half::f16; -use crate::{storage::DeviceStorage, Device, DeviceError, GPUBuffer, Shape, TensorDType}; +use crate::{storage::DeviceStorage, DType, Device, DeviceError, GPUBuffer, Shape, TensorDType}; use std::{alloc::Layout, fmt::Debug, mem::MaybeUninit, sync::Arc}; -use crate::DType; - #[derive(derive_new::new, Debug, PartialEq, Eq)] pub struct RawCPUBuffer(*mut u8, Layout); From b74e4b2f54adea937ba7256b49b368e0a5c7cf76 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:37:14 +0100 Subject: [PATCH 32/32] chore: Add back default debug_struct in Debug for Tensor impl --- crates/ratchet-core/src/tensor.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index 7e581bed..d8e0f0eb 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -86,7 +86,17 @@ impl std::fmt::Debug for Tensor { match self.device() { Device::CPU => match self.dt() { DType::F32 => self.to_ndarray_view::().fmt(f), - _ => unimplemented!("Debug not implemented for {:?}", self.dt()), + _ => { + let storage_fmt = self.storage().as_ref().map(|s| s.dump(self.dt(), false)); + let (id, op) = (self.id(), self.op()); + f.debug_struct("Tensor") + .field("id", &id) + .field("shape", &self.shape()) + .field("dt", &self.dt()) + .field("op", &op) + .field("storage", &storage_fmt) + .finish() + } }, Device::GPU(_) => { let storage_fmt = self.storage().as_ref().map(|s| s.dump(self.dt(), false));