Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor turbomind attention by precomputing cos/sin #2801

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions src/turbomind/kernels/attention/attention_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,9 @@ struct AttentionParams {
float inv_sqrt_dh;

// rotary embedding
int rotary_embedding_dim;
float rotary_embedding_base;
float rope_scaling_factor;
float attention_scaling;
int max_position_embeddings;
float rope_ti_scale; // used for linear RoPE scaling
// the following 3 parameters are used by llama3
float llama3_inv_scaling_factor;
float llama3_alpha;
float llama3_beta;
// the following are use by yarn
float yarn_ramp_inv_factor_div_2;
float yarn_ramp_inv_factor_mul_min;
float yarn_inv_scaling_factor;

float* cos_sin;
int rotary_embedding_dim;
int max_position_embeddings;
// log(n) attention
bool use_logn_attn;

Expand Down
44 changes: 22 additions & 22 deletions src/turbomind/kernels/attention/attention_universal.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ struct AttentionUniversal {
Vec vec_K[1][ITER_C];
Vec vec_V[1][ITER_C];

Array<float, kVecSize> vec_cs[ITER_S][ITER_C]; // precomputed cos sin

const int2 offset = Map::get_offset(warp_id, lane_id);

// Load Q
Expand All @@ -217,36 +219,34 @@ struct AttentionUniversal {
Ldg(vec_V[0][c], &params.v[k_idx]);
}
}
if (params.cos_sin) {
float* cos_sin = params.cos_sin;
const int64_t index = qi * params.rotary_embedding_dim + di;
PRAGMA_UNROLL
for (int k = 0; k < kVecSize; k += 4) {
(float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]);
}
}
}
}
}

ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset);

const float rope_base = params.rope_theta ? params.rope_theta[batch_idx] : params.rotary_embedding_base;
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
const int di = offset.x + c * Map::kDeltaC;
FastRoPE rope(di,
params.rotary_embedding_dim,
rope_base,
params.rope_ti_scale,
params.rope_scaling_factor,
params.llama3_inv_scaling_factor,
params.llama3_alpha,
params.llama3_beta,
params.yarn_ramp_inv_factor_div_2,
params.yarn_ramp_inv_factor_mul_min,
params.yarn_inv_scaling_factor,
params.attention_scaling,
std::integral_constant<int, kVecSize>{});
if (params.cos_sin) {
PrecomputeFastRoPE rope{};
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
const int ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len;
rope.apply(vec_Q[s][c], ti);
if constexpr (kProcessKV) {
if (s == 0) {
rope.apply(vec_K[0][c], ti);
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
const int di = offset.x + c * Map::kDeltaC;
if (di < params.rotary_embedding_dim) {
rope.apply(vec_Q[s][c], vec_cs[s][c]);
if constexpr (kProcessKV) {
if (s == 0) {
rope.apply(vec_K[0][c], vec_cs[s][c]);
}
}
}
}
}
Expand Down
181 changes: 40 additions & 141 deletions src/turbomind/kernels/attention/kv_cache_utils_v2.cu

Large diffs are not rendered by default.

44 changes: 2 additions & 42 deletions src/turbomind/kernels/attention/kv_cache_utils_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,8 @@ void invokeProcessKV_v2(char** blocks,
const int* cu_q_len,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
const float* cos_sin,
int rope_dim,
float rope_ti_scale,
float rope_scaling_factor,
float llama3_inv_scaling_factor,
float llama3_alpha,
float llama3_beta,
float yarn_ramp_inv_factor_div_2,
float yarn_ramp_inv_factor_mul_min,
float yarn_inv_scaling_factor,
float attention_scaling,
int64_t stride_b,
int64_t stride_c,
int64_t stride_h,
Expand All @@ -51,17 +42,8 @@ void invokeProcessKV_v2_(const AttentionParams<T>& params)
params.cu_q_len,
params.cu_k_len,
params.block_iter_params.cu_block_nums,
params.rope_theta,
params.cos_sin,
params.rotary_embedding_dim,
params.rope_ti_scale,
params.rope_scaling_factor,
params.llama3_inv_scaling_factor,
params.llama3_alpha,
params.llama3_beta,
params.yarn_ramp_inv_factor_div_2,
params.yarn_ramp_inv_factor_mul_min,
params.yarn_inv_scaling_factor,
params.attention_scaling,
0, // stride b
params.stride / params.size_per_head, // stride c
1, // stride h
Expand All @@ -82,17 +64,6 @@ void invokeFlattenKV_v2(T* k,
char** blocks,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
int rope_dim,
float rope_ti_scale,
float rope_scaling_factor,
float llama3_inv_scaling_factor,
float llama3_alpha,
float llama3_beta,
float yarn_ramp_inv_factor_div_2,
float yarn_ramp_inv_factor_mul_min,
float yarn_inv_scaling_factor,
float attention_scaling,
int64_t stride_b,
int64_t stride_c,
int64_t stride_h,
Expand All @@ -116,17 +87,6 @@ void invokeFlattenKV_v2_(const AttentionParams<T>& params, int sum_k_len)
(char**)params.block_iter_params.block_ptrs,
params.cu_k_len,
params.block_iter_params.cu_block_nums,
nullptr, // params.rope_theta,
params.rotary_embedding_dim,
params.rope_ti_scale,
params.rope_scaling_factor,
params.llama3_inv_scaling_factor,
params.llama3_alpha,
params.llama3_beta,
params.yarn_ramp_inv_factor_div_2,
params.yarn_ramp_inv_factor_mul_min,
params.yarn_inv_scaling_factor,
params.attention_scaling,
0,
1,
2 * sum_k_len,
Expand Down
15 changes: 15 additions & 0 deletions src/turbomind/kernels/attention/rotary_embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ __device__ void ApplyRotaryEmbedding(Array<T, 4>& x, float base, int dims, int t
}
}

struct PrecomputeFastRoPE {

template<typename T, int N>
__device__ void apply(Array<T, N>& x, Array<float, N>& cs)
{
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
float tmp0 = cs[i] * (float)x[i] - cs[i + 1] * (float)x[i + 1];
float tmp1 = cs[i] * (float)x[i + 1] + cs[i + 1] * (float)x[i];
x[i] = (T)tmp0;
x[i + 1] = (T)tmp1;
}
}
};

template<class D, int N>
struct FastRoPE {

Expand Down
35 changes: 1 addition & 34 deletions src/turbomind/kernels/attention/test_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,6 @@ void TestBlocks(const thrust::universal_vector<T>& k_cache, // [B, H, S,
cu_block_cnts.data().get(),
nullptr,
rope_dim,
1.,
0.,
0.,
1.0,
1.0,
0.0,
0.0,
0.0,
1.0,
2 * head_num * seq_len,
0,
seq_len,
Expand All @@ -180,17 +171,6 @@ void TestBlocks(const thrust::universal_vector<T>& k_cache, // [B, H, S,
k_ptrs.data().get(),
cu_seq_lens.data().get(),
cu_block_cnts.data().get(),
nullptr,
rope_dim,
1.,
0.,
0.,
1.0,
1.0,
0.0,
0.0,
0.0,
1.0,
2 * head_num * seq_len,
0,
seq_len,
Expand Down Expand Up @@ -435,9 +415,7 @@ int test_attention()
params.size_per_head = kHeadDim;
params.inv_sqrt_dh = (float)std::log2(expf(1.)) / std::sqrt((float)params.size_per_head);

params.rotary_embedding_dim = kRoPEDim;
params.rotary_embedding_base = kRoPEBase;
params.rope_ti_scale = 1.;
params.rotary_embedding_dim = kRoPEDim;

params.split_cnt = split_cnt.data().get();
params.partial_L = partial_L.data().get();
Expand Down Expand Up @@ -544,17 +522,6 @@ int test_attention()
k_ptrs.data().get(),
cu_kv_lens.data().get(),
cu_block_cnts.data().get(),
nullptr, // DECODING ? nullptr : params.rope_theta,
kRoPEDim,
1.,
0.,
0.,
1.0,
1.0,
0.0,
0.0,
0.0,
1.0,
KvHeadNum * kContextLen,
0,
kContextLen,
Expand Down
1 change: 1 addition & 0 deletions src/turbomind/models/llama/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_library(Llama STATIC
unified_attention_layer.cc
llama_kernels.cu
llama_decoder_kernels.cu
rotary_emb.cu
llama_utils.cu)
set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
Expand Down
Loading
Loading