Skip to content

Commit

Permalink
chore: ? > unwrap
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Oct 29, 2024
1 parent 5a018ce commit 99a074e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
28 changes: 19 additions & 9 deletions crates/ratchet-core/src/cpu/rope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result<Tensor, OperationError> {
let base = op.base();
let offset = op.offset();
let src = op.input().to_vec::<f32>()?;
let result = rope(src, op.input().shape(), dim, base, offset);
let result = rope(src, op.input().shape(), dim, base, offset)?;
cpu_store_result(&dst, &result)
}
_ => todo!(),
Expand All @@ -21,7 +21,12 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result<Tensor, OperationError> {
Ok(dst)
}

fn compute_theta(dim: usize, seq_len: usize, base: f32, offset: usize) -> Vec<f32> {
fn compute_theta(
dim: usize,
seq_len: usize,
base: f32,
offset: usize,
) -> Result<Vec<f32>, OperationError> {
let half_dim = dim / 2;

let positions = (offset..seq_len + offset)
Expand Down Expand Up @@ -51,10 +56,9 @@ fn compute_theta(dim: usize, seq_len: usize, base: f32, offset: usize) -> Vec<f3
seq_len,
half_dim,
1,
)
.unwrap();
)?;

theta
Ok(theta)
}

fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) -> Vec<f32> {
Expand Down Expand Up @@ -83,11 +87,17 @@ fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) ->
dst
}

fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec<f32> {
fn rope(
src: Vec<f32>,
shape: &Shape,
dim: usize,
base: f32,
offset: usize,
) -> Result<Vec<f32>, OperationError> {
let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap();

let half_dim = dim / 2;
let theta = compute_theta(dim, seq_len, base, offset);
let theta = compute_theta(dim, seq_len, base, offset)?;
let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|i| i.sin_cos()).unzip();
let src_strides = Strides::from(shape);
let x1 = slice(
Expand Down Expand Up @@ -155,6 +165,6 @@ fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> V

let dst_shape = shape![batches, num_heads, seq_len, head_dim];
let mut dst = vec![0.0f32; dst_shape.numel()];
concat(to_cat.as_slice(), 3, &dst_shape, &mut dst).unwrap();
dst
concat(to_cat.as_slice(), 3, &dst_shape, &mut dst)?;
Ok(dst)
}
4 changes: 1 addition & 3 deletions crates/ratchet-core/src/storage/cpu_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use bytemuck::{NoUninit, Pod};
use half::f16;

use crate::{storage::DeviceStorage, Device, DeviceError, GPUBuffer, Shape, TensorDType};
use crate::{storage::DeviceStorage, DType, Device, DeviceError, GPUBuffer, Shape, TensorDType};

use std::{alloc::Layout, fmt::Debug, mem::MaybeUninit, sync::Arc};

use crate::DType;

#[derive(derive_new::new, Debug, PartialEq, Eq)]
pub struct RawCPUBuffer(*mut u8, Layout);

Expand Down

0 comments on commit 99a074e

Please sign in to comment.