Skip to content

Commit

Permalink
chore: use iter cycle instead of % check
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Oct 20, 2024
1 parent 022db5a commit 508b5ed
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions crates/ratchet-core/src/cpu/rope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,16 @@ fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> V
&[batches, num_heads, seq_len, dim],
);

//zip and repeat
//`multiply` as an operation that deals with broadcasting
let x1_cos = x1
.iter()
.enumerate()
.map(|(i, x)| x * cos[i % cos.len()])
.zip(cos.iter().cycle())
.map(|(x, c)| x * c)
.collect::<Vec<f32>>();
let x2_sin = x2
.iter()
.enumerate()
.map(|(i, x)| x * sin[i % sin.len()])
.zip(sin.iter().cycle())
.map(|(x, s)| x * s)
.collect::<Vec<f32>>();

let mut r1 = x1_cos
Expand All @@ -163,13 +162,13 @@ fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> V

let x1_sin = x1
.iter()
.enumerate()
.map(|(i, x)| x * sin[i % sin.len()])
.zip(sin.iter().cycle())
.map(|(x, s)| x * s)
.collect::<Vec<f32>>();
let x2_cos = x2
.iter()
.enumerate()
.map(|(i, x)| x * cos[i % cos.len()])
.zip(cos.iter().cycle())
.map(|(x, c)| x * c)
.collect::<Vec<f32>>();
let mut r2 = x1_sin
.iter()
Expand Down

0 comments on commit 508b5ed

Please sign in to comment.