diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 58f3d0cd..5908020b 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -141,17 +141,16 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V &[batches, num_heads, seq_len, dim], ); - //zip and repeat //`multiply` as an operation that deals with broadcasting let x1_cos = x1 .iter() - .enumerate() - .map(|(i, x)| x * cos[i % cos.len()]) + .zip(cos.iter().cycle()) + .map(|(x, c)| x * c) .collect::>(); let x2_sin = x2 .iter() - .enumerate() - .map(|(i, x)| x * sin[i % sin.len()]) + .zip(sin.iter().cycle()) + .map(|(x, s)| x * s) .collect::>(); let mut r1 = x1_cos @@ -163,13 +162,13 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V let x1_sin = x1 .iter() - .enumerate() - .map(|(i, x)| x * sin[i % sin.len()]) + .zip(sin.iter().cycle()) + .map(|(x, s)| x * s) .collect::>(); let x2_cos = x2 .iter() - .enumerate() - .map(|(i, x)| x * cos[i % cos.len()]) + .zip(cos.iter().cycle()) + .map(|(x, c)| x * c) .collect::>(); let mut r2 = x1_sin .iter()