Skip to content

Commit

Permalink
chore: remove redundant "outs" vec
Browse files (browse the repository at this point in the history)
  • Loading branch information
ivarflakstad committed Oct 20, 2024
1 parent c932fd5 commit 022db5a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 23 deletions.
29 changes: 8 additions & 21 deletions crates/ratchet-core/src/cpu/rope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,15 @@ fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) ->
assert!(s < t);
});

let delta: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
let dst_shape: Vec<usize> = delta.clone();
let dst_numel: usize = delta.iter().product();
let dst_shape: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
let dst_numel: usize = dst_shape.iter().product();

let mut dst = vec![0.0; dst_numel];

for i in 0..dst_numel {
let mut src_index = 0;
let mut tmp = i;
for d in 0..delta.len() {
for d in 0..dst_shape.len() {
let coord = tmp / dst_shape[d + 1..].iter().product::<usize>().max(1);
tmp %= dst_shape[d + 1..].iter().product::<usize>().max(1);
src_index += (coord + start[d]) * src_strides[d] as usize;
Expand Down Expand Up @@ -155,14 +154,12 @@ fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> V
.map(|(i, x)| x * sin[i % sin.len()])
.collect::<Vec<f32>>();

let mut outs = vec![];
let mut r1 = x1_cos
.iter()
.zip(x2_sin.iter())
.map(|(x1, x2)| x1 - x2)
.collect::<Vec<f32>>();
r1.extend(vec![0.0; shape.numel() - r1.len()]);
outs.push(r1.clone());

let x1_sin = x1
.iter()
Expand All @@ -180,29 +177,19 @@ fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> V
.map(|(x1, x2)| x1 + x2)
.collect::<Vec<f32>>();
r2.extend(vec![0.0; shape.numel() - r2.len()]);
outs.push(r2.clone());

let mut to_cat = vec![
(
shape![batches, num_heads, seq_len, half_dim],
outs[0].clone(),
),
(
shape![batches, num_heads, seq_len, half_dim],
outs[1].clone(),
),
(shape![batches, num_heads, seq_len, half_dim], r1),
(shape![batches, num_heads, seq_len, half_dim], r2),
];
if dim < shape[3] {
outs.push(slice(
let r3 = slice(
&src,
&src_strides,
&[0, 0, 0, dim],
&[batches, num_heads, seq_len, head_dim],
));
to_cat.push((
shape![batches, num_heads, seq_len, head_dim - dim],
outs[2].clone(),
));
);
to_cat.push((shape![batches, num_heads, seq_len, head_dim - dim], r3));
}

let dst_shape = shape![batches, num_heads, seq_len, head_dim];
Expand Down
4 changes: 2 additions & 2 deletions crates/ratchet-core/src/ops/rope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ def mlx_rope(input, dim, offset):
let b = a.rope(dim, 10000.0, offset).unwrap().resolve().unwrap();

let ours = b.to(&Device::CPU).unwrap();
println!("ours = \n{:#?}\n", ours.to_ndarray_view::<f32>());
println!("ground = \n{:#?}", ground.to_ndarray_view::<f32>());
//println!("ours = \n{:#?}\n", ours.to_ndarray_view::<f32>());
//println!("ground = \n{:#?}", ground.to_ndarray_view::<f32>());
//Weak tolerance because of `ffast-math`
ground.all_close(&ours, 1e-2, 1e-2).unwrap();
}
Expand Down

0 comments on commit 022db5a

Please sign in to comment.