From 508b5edb90a8d151fc2f942635570044bc1ad320 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sun, 20 Oct 2024 13:30:35 +0200 Subject: [PATCH] chore: use iter cycle instead of % check --- crates/ratchet-core/src/cpu/rope.rs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs index 58f3d0cd..5908020b 100644 --- a/crates/ratchet-core/src/cpu/rope.rs +++ b/crates/ratchet-core/src/cpu/rope.rs @@ -141,17 +141,16 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V &[batches, num_heads, seq_len, dim], ); - //zip and repeat //`multiply` as an operation that deals with broadcasting let x1_cos = x1 .iter() - .enumerate() - .map(|(i, x)| x * cos[i % cos.len()]) + .zip(cos.iter().cycle()) + .map(|(x, c)| x * c) .collect::>(); let x2_sin = x2 .iter() - .enumerate() - .map(|(i, x)| x * sin[i % sin.len()]) + .zip(sin.iter().cycle()) + .map(|(x, s)| x * s) .collect::>(); let mut r1 = x1_cos @@ -163,13 +162,13 @@ fn rope(src: Vec, shape: &Shape, dim: usize, base: f32, offset: usize) -> V let x1_sin = x1 .iter() - .enumerate() - .map(|(i, x)| x * sin[i % sin.len()]) + .zip(sin.iter().cycle()) + .map(|(x, s)| x * s) .collect::>(); let x2_cos = x2 .iter() - .enumerate() - .map(|(i, x)| x * cos[i % cos.len()]) + .zip(cos.iter().cycle()) + .map(|(x, c)| x * c) .collect::>(); let mut r2 = x1_sin .iter()