Skip to content

Commit

Permalink
chore: simplify rope concat
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Oct 16, 2024
1 parent 572e7d1 commit ce991ba
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
13 changes: 5 additions & 8 deletions crates/ratchet-core/src/cpu/rope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,22 +198,19 @@ fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> V
println!("R1: {:?}", r1);
println!("R2: {:?}", r2);

let mut to_cat = vec![];
let mut to_cat = vec![
(shape![num_heads, seq_len, half_dim], outs[0].clone()),
(shape![num_heads, seq_len, half_dim], outs[1].clone()),
];
if dim < shape[3] {
outs.push(slice(
&src,
&src_strides,
&[0, 0, 0, dim],
&[batches, num_heads, seq_len, head_dim],
));
to_cat.push((shape![num_heads, seq_len, head_dim - dim], outs[2].clone()));
}
for i in 0..outs.len() - 1 {
to_cat.push((shape![num_heads, seq_len, half_dim], outs[i].clone()));
}
to_cat.push((
shape![num_heads, seq_len, head_dim - dim],
outs[outs.len() - 1].clone(),
));

let dst_shape = shape![num_heads, seq_len, head_dim];
let mut dst = vec![0.0f32; dst_shape.numel()];
Expand Down
6 changes: 3 additions & 3 deletions crates/ratchet-core/src/ops/rope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ def mlx_rope(input, dim, offset):
let prob = RoPEProblem {
BS: 1,
NH: 1,
SL: 2,
HD: 128,
dim: 96,
SL: 1,
HD: 32,
dim: 32,
offset: 0,
};
println!("{prob:?}");
Expand Down

0 comments on commit ce991ba

Please sign in to comment.