diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index b6dfaa596..7bc770ab6 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -57,21 +57,9 @@ struct AttentionParams { float inv_sqrt_dh; // rotary embedding - int rotary_embedding_dim; - float rotary_embedding_base; - float rope_scaling_factor; - float attention_scaling; - int max_position_embeddings; - float rope_ti_scale; // used for linear RoPE scaling - // the following 3 parameters are used by llama3 - float llama3_inv_scaling_factor; - float llama3_alpha; - float llama3_beta; - // the following are use by yarn - float yarn_ramp_inv_factor_div_2; - float yarn_ramp_inv_factor_mul_min; - float yarn_inv_scaling_factor; - + float* cos_sin; + int rotary_embedding_dim; + int max_position_embeddings; // log(n) attention bool use_logn_attn; diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 5fb583bd1..67b72f200 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -194,6 +194,8 @@ struct AttentionUniversal { Vec vec_K[1][ITER_C]; Vec vec_V[1][ITER_C]; + Array vec_cs[ITER_S][ITER_C]; // precomputed cos sin + const int2 offset = Map::get_offset(warp_id, lane_id); // Load Q @@ -217,36 +219,34 @@ struct AttentionUniversal { Ldg(vec_V[0][c], ¶ms.v[k_idx]); } } + if (params.cos_sin) { + float* cos_sin = params.cos_sin; + const int64_t index = qi * params.rotary_embedding_dim + di; + PRAGMA_UNROLL + for (int k = 0; k < kVecSize; k += 4) { + (float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]); + } + } } } } ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset); - const float rope_base = params.rope_theta ? 
params.rope_theta[batch_idx] : params.rotary_embedding_base; - PRAGMA_UNROLL - for (int c = 0; c < ITER_C; ++c) { - const int di = offset.x + c * Map::kDeltaC; - FastRoPE rope(di, - params.rotary_embedding_dim, - rope_base, - params.rope_ti_scale, - params.rope_scaling_factor, - params.llama3_inv_scaling_factor, - params.llama3_alpha, - params.llama3_beta, - params.yarn_ramp_inv_factor_div_2, - params.yarn_ramp_inv_factor_mul_min, - params.yarn_inv_scaling_factor, - params.attention_scaling, - std::integral_constant{}); + if (params.cos_sin) { + PrecomputeFastRoPE rope{}; PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { - const int ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len; - rope.apply(vec_Q[s][c], ti); - if constexpr (kProcessKV) { - if (s == 0) { - rope.apply(vec_K[0][c], ti); + PRAGMA_UNROLL + for (int c = 0; c < ITER_C; ++c) { + const int di = offset.x + c * Map::kDeltaC; + if (di < params.rotary_embedding_dim) { + rope.apply(vec_Q[s][c], vec_cs[s][c]); + if constexpr (kProcessKV) { + if (s == 0) { + rope.apply(vec_K[0][c], vec_cs[s][c]); + } + } } } } diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index f2e2faef9..926c13fbd 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -20,17 +20,8 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, const int* cu_q_len, const int* cu_k_len, const int* cu_block_num, - const float* rope_base, + const float* cos_sin, int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -124,28 +115,33 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, } } - if (rope_base) { - float base = rope_base[batch_idx]; + if (cos_sin) { + Array vec_cs[ITER_S][ITER_C]; PRAGMA_UNROLL - for (int c = 0; c < ITER_C; ++c) { - const int di = offset.x + c * Map::kDeltaC; - FastRoPE rope(di, - rope_dim, - base, - rope_ti_scale, - rope_scaling_factor, - llama3_inv_scaling_factor, - llama3_alpha, - llama3_beta, - yarn_ramp_inv_factor_div_2, - yarn_ramp_inv_factor_mul_min, - yarn_inv_scaling_factor, - attention_scaling, - std::integral_constant{}); + for (int s = 0; s < ITER_S; ++s) { + const int qi = offset.y + s * Map::kDeltaS + token_idx; PRAGMA_UNROLL - for (int s = 0; s < ITER_S; ++s) { - const int ti = history_len + offset.y + s * Map::kDeltaS + token_idx; // sequence local - rope.apply(vec_K[s][c], ti); + for (int c = 0; c < ITER_C; ++c) { + const int di = offset.x + c * Map::kDeltaC; + const int64_t index = (qi_beg + qi) * rope_dim + di; + PRAGMA_UNROLL + for (int k = 0; k < kVecSize; k += 4) { + if (qi < q_len) { + (float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]); + } + } + } + } + + PrecomputeFastRoPE rope; + PRAGMA_UNROLL + for (int s = 0; s < ITER_S; ++s) { + PRAGMA_UNROLL + for (int c = 0; c < ITER_C; ++c) { + const int di = offset.x + c * Map::kDeltaC; + if (di < rope_dim) { + rope.apply(vec_K[s][c], vec_cs[s][c]); + } } } } @@ -211,17 +207,8 @@ void invokeProcessKV_v2(char** blocks, const int* cu_q_len, const int* cu_k_len, const int* cu_block_num, - const float* rope_base, + const float* cos_sin, int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - 
float llama3_inv_scaling_factor, - float llama3_1_alpha, - float llama3_1_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -257,17 +244,8 @@ void invokeProcessKV_v2(char** blocks, cu_q_len, cu_k_len, cu_block_num, - rope_base, + cos_sin, rope_dim, - rope_ti_scale, - rope_scaling_factor, - llama3_inv_scaling_factor, - llama3_1_alpha, - llama3_1_beta, - yarn_ramp_inv_factor_div_2, - yarn_ramp_inv_factor_mul_min, - yarn_inv_scaling_factor, - attention_scaling, stride_b, stride_c, stride_h, @@ -309,17 +287,8 @@ void invokeProcessKV_v2(char** blocks, const int* cu_q_len, \ const int* cu_k_len, \ const int* cu_block_num, \ - const float* rope_base, \ + const float* cos_sin, \ int rope_dim, \ - float rope_ti_scale, \ - float rope_scaling_factor, \ - float llama3_inv_scaling_factor, \ - float llama3_1_alpha, \ - float llama3_1_beta, \ - float yarn_ramp_inv_factor_div_2, \ - float yarn_ramp_inv_factor_mul_min, \ - float yarn_inv_scaling_factor, \ - float attention_scaling, \ int64_t stride_b, \ int64_t stride_c, \ int64_t stride_h, \ @@ -339,28 +308,17 @@ INSTANTIATE_invokeProcessKV_v2(nv_bfloat16); #endif template -__global__ void __launch_bounds__(128) flattenKV_v2(T* k, - T* v, - const Tkv** blocks, - const int* cu_k_len, - const int* cu_block_num, - const float* rope_base, - int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, - int64_t stride_b, - int64_t stride_c, - int64_t stride_h, - int64_t stride_s, - int layer_id, - BlockLayout block_layout) +__global__ void __launch_bounds__(128) flattenKV_v2(T* k, + T* v, + const Tkv** blocks, + const int* cu_k_len, + const int* cu_block_num, + int64_t stride_b, + int64_t stride_c, + int64_t stride_h, + int64_t stride_s, + int layer_id, + BlockLayout block_layout) { constexpr int kVecSize = sizeof(uint4) / sizeof(T); @@ -431,32 +389,6 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, } } - if (rope_base) { - float base = rope_base[batch_idx]; - PRAGMA_UNROLL - for (int c = 0; c < ITER_C; ++c) { - const int di = offset.x + c * Map::kDeltaC; - FastRoPE rope(di, - rope_dim, - base, - rope_ti_scale, - rope_scaling_factor, - llama3_inv_scaling_factor, - llama3_alpha, - llama3_beta, - yarn_ramp_inv_factor_div_2, - yarn_ramp_inv_factor_mul_min, - yarn_inv_scaling_factor, - attention_scaling, - std::integral_constant{}); - PRAGMA_UNROLL - for (int s = 0; s < ITER_S; ++s) { - const int ti = offset.y + s * Map::kDeltaS + token_idx; // sequence local - rope.apply(out_K[s][c], ti); - } - } - } - PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { PRAGMA_UNROLL @@ -479,17 +411,6 @@ void invokeFlattenKV_v2(T* k, char** blocks, const int* cu_k_len, const int* cu_block_num, - const float* rope_base, - int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -522,17 +443,6 @@ void invokeFlattenKV_v2(T* k, (const Tkv**)blocks, cu_k_len, cu_block_num, - rope_base, - rope_dim, - rope_ti_scale, - rope_scaling_factor, - 
llama3_inv_scaling_factor, - llama3_alpha, - llama3_beta, - yarn_ramp_inv_factor_div_2, - yarn_ramp_inv_factor_mul_min, - yarn_inv_scaling_factor, - attention_scaling, stride_b, stride_c, stride_h, @@ -571,17 +481,6 @@ void invokeFlattenKV_v2(T* k, char** blocks, \ const int* cu_k_len, \ const int* cu_block_num, \ - const float* rope_base, \ - int rope_dim, \ - float rope_ti_scale, \ - float rope_scaling_factor, \ - float llama3_inv_scaling_factor, \ - float llama3_alpha, \ - float llama3_beta, \ - float yarn_ramp_inv_factor_div_2, \ - float yarn_ramp_inv_factor_mul_min, \ - float yarn_inv_scaling_factor, \ - float attention_scaling, \ int64_t stride_b, \ int64_t stride_c, \ int64_t stride_h, \ diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h index fe45ad7be..0f789fb58 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h @@ -16,17 +16,8 @@ void invokeProcessKV_v2(char** blocks, const int* cu_q_len, const int* cu_k_len, const int* cu_block_num, - const float* rope_base, + const float* cos_sin, int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -51,17 +42,8 @@ void invokeProcessKV_v2_(const AttentionParams& params) params.cu_q_len, params.cu_k_len, params.block_iter_params.cu_block_nums, - params.rope_theta, + params.cos_sin, params.rotary_embedding_dim, - params.rope_ti_scale, - params.rope_scaling_factor, - params.llama3_inv_scaling_factor, - params.llama3_alpha, - params.llama3_beta, - params.yarn_ramp_inv_factor_div_2, - params.yarn_ramp_inv_factor_mul_min, - params.yarn_inv_scaling_factor, - params.attention_scaling, 0, // stride b params.stride / params.size_per_head, // stride c 1, // stride h @@ -82,17 +64,6 @@ void invokeFlattenKV_v2(T* k, char** blocks, const int* cu_k_len, const int* cu_block_num, - const float* rope_base, - int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -116,17 +87,6 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len) (char**)params.block_iter_params.block_ptrs, params.cu_k_len, params.block_iter_params.cu_block_nums, - nullptr, // params.rope_theta, - params.rotary_embedding_dim, - params.rope_ti_scale, - params.rope_scaling_factor, - params.llama3_inv_scaling_factor, - params.llama3_alpha, - params.llama3_beta, - params.yarn_ramp_inv_factor_div_2, - params.yarn_ramp_inv_factor_mul_min, - params.yarn_inv_scaling_factor, - params.attention_scaling, 0, 1, 2 * sum_k_len, diff --git a/src/turbomind/kernels/attention/rotary_embedding.h b/src/turbomind/kernels/attention/rotary_embedding.h index db836ed18..77045290f 100644 --- a/src/turbomind/kernels/attention/rotary_embedding.h +++ b/src/turbomind/kernels/attention/rotary_embedding.h @@ -67,6 +67,21 @@ __device__ void ApplyRotaryEmbedding(Array& x, float base, int dims, int t } } +struct PrecomputeFastRoPE { + + template + __device__ void apply(Array& x, Array& cs) + { + PRAGMA_UNROLL + for (int i 
= 0; i < N; i += 2) { + float tmp0 = cs[i] * (float)x[i] - cs[i + 1] * (float)x[i + 1]; + float tmp1 = cs[i] * (float)x[i + 1] + cs[i + 1] * (float)x[i]; + x[i] = (T)tmp0; + x[i + 1] = (T)tmp1; + } + } +}; + template struct FastRoPE { diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index 804d4815d..788b5d489 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ b/src/turbomind/kernels/attention/test_attention.cu @@ -7,6 +7,7 @@ #include "src/turbomind/kernels/attention/attention_params.h" #include "src/turbomind/kernels/attention/reference.h" #include "src/turbomind/models/llama/llama_utils.h" +#include "src/turbomind/models/llama/rotary_emb.h" #include "src/turbomind/utils/cuda_utils.h" #include "test_utils.h" #include @@ -148,15 +149,6 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, cu_block_cnts.data().get(), nullptr, rope_dim, - 1., - 0., - 0., - 1.0, - 1.0, - 0.0, - 0.0, - 0.0, - 1.0, 2 * head_num * seq_len, 0, seq_len, @@ -180,17 +172,6 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, k_ptrs.data().get(), cu_seq_lens.data().get(), cu_block_cnts.data().get(), - nullptr, - rope_dim, - 1., - 0., - 0., - 1.0, - 1.0, - 0.0, - 0.0, - 0.0, - 1.0, 2 * head_num * seq_len, 0, seq_len, @@ -396,6 +377,26 @@ int test_attention() rope_base[i] = kRoPEBase; } + // precompute cos/sin + const int device_id = 0; + auto allocator = std::make_unique>(device_id, false); + allocator->setStream(nullptr); + AttentionParam attn_param; + attn_param.rope.type = RopeType::kDefault; + attn_param.rope.base = kRoPEBase; + attn_param.rope.dim = kRoPEDim; + attn_param.rope.factor = 1.0f; + auto rotary_emb = std::make_unique(attn_param, nullptr, allocator.get()); + + RotaryEmbeddingV2Param rotary_param; + rotary_param.rope_theta = rope_base.data().get(); + rotary_param.q_len = cu_seqlens.data().get(); + rotary_param.k_ken = cu_kv_lens.data().get(); + rotary_param.batch_size = kBatchSize; + rotary_param.token_num = kTokenNum; + rotary_emb->forward(rotary_param); + params.cos_sin = rotary_emb->cos_sin_; + // getchar(); params.out = output_ref.data().get(); @@ -435,9 +436,7 @@ int test_attention() params.size_per_head = kHeadDim; params.inv_sqrt_dh = (float)std::log2(expf(1.)) / std::sqrt((float)params.size_per_head); - params.rotary_embedding_dim = kRoPEDim; - params.rotary_embedding_base = kRoPEBase; - params.rope_ti_scale = 1.; + params.rotary_embedding_dim = kRoPEDim; params.split_cnt = split_cnt.data().get(); params.partial_L = partial_L.data().get(); @@ -451,10 +450,6 @@ int test_attention() params.qk = qk_buf.data().get(); params.pr = pr_buf.data().get(); - params.attention_scaling = 1.f; - params.llama3_inv_scaling_factor = 0; - params.yarn_ramp_inv_factor_div_2 = 0; - Reference reference(kDump ? Reference::kUNFUSED : Reference::kFLASH_ATTENTION, {}); // Reference reference(Reference::kUNFUSED, {}); reference.Reshape(kInputLen, kContextLen, kHeadNum, kHeadDim, KvHeadNum, kBatchSize); @@ -548,17 +543,6 @@ int test_attention() k_ptrs.data().get(), cu_kv_lens.data().get(), cu_block_cnts.data().get(), - nullptr, // DECODING ? 
nullptr : params.rope_theta, - kRoPEDim, - 1., - 0., - 0., - 1.0, - 1.0, - 0.0, - 0.0, - 0.0, - 1.0, KvHeadNum * kContextLen, 0, kContextLen, diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index 3c714bd23..72444d33d 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -20,8 +20,11 @@ add_library(Llama STATIC unified_attention_layer.cc llama_kernels.cu llama_decoder_kernels.cu + rotary_emb.cu llama_utils.cu - mla_utils.cu) + mla_utils.cu +) + set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Llama PUBLIC CUDA::cudart diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index ea321d06a..9ef669f8d 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -368,15 +368,15 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests) // compute rope scaling factor if (r->start_flag) { - seq.rope_theta = model_->attn_param_.rotary_embedding_base; - if (model_->attn_param_.use_dynamic_ntk) { - auto scaling_factor = model_->attn_param_.rope_scaling_factor; + seq.rope_theta = model_->attn_param_.rope.base; + if (model_->attn_param_.rope.type == RopeType::kDynamic) { + auto scaling_factor = model_->attn_param_.rope.factor; if (scaling_factor >= 1.f) { // infer by current context length auto max_seq_len = state.h_context_length[idx]; - auto max_pos_emb = model_->attn_param_.max_position_embeddings; + auto max_pos_emb = model_->attn_param_.rope.max_position_embeddings; if (max_seq_len > max_pos_emb) { scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1); - float rope_dim = model_->attn_param_.rotary_embedding_dim; + float rope_dim = model_->attn_param_.rope.dim; seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f)); TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f", (long)seq.id, diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 0a505b11a..ff0482eef 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -59,22 +59,45 @@ struct MoeParam { std::vector expert_num; }; +enum class RopeType +{ + kDefault, + kLinear, + kDynamic, + kYarn, + kLlama3, +}; + +struct YarnRopeParam { + float attention_factor; + float beta_fast; + float beta_slow; +}; + +struct Llama3RopeParam { + float low_freq_factor; + float high_freq_factor; + int original_max_position_embeddings; +}; + struct AttentionParam { - int rotary_embedding_dim; - float rotary_embedding_base; - int max_position_embeddings; - float softmax_scale; - std::string rope_scaling_type; - int original_max_position_embeddings; - float rope_scaling_factor; - float low_freq_factor; - float high_freq_factor; - float attention_factor; - float beta_fast; - float beta_slow; - bool use_dynamic_ntk; - bool use_logn_attn; - int cache_block_seq_len; + float softmax_scale; + int cache_block_seq_len; + bool use_logn_attn; + // rope + struct { + // common + RopeType type; + int dim; + float base; + float factor; + int max_position_embeddings; + // special + union { + YarnRopeParam yarn; + Llama3RopeParam llama3; + }; + } rope; }; struct EngineParam { diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu new file mode 100644 index 000000000..efb98bf85 --- 
/dev/null +++ b/src/turbomind/models/llama/rotary_emb.cu @@ -0,0 +1,233 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include "src/turbomind/models/llama/rotary_emb.h" +#include + +namespace turbomind { + +__device__ int get_batch_id(int qi, int* q_len, int batch_size) +{ + int result{}; + int end = (batch_size + blockDim.x - 1) / blockDim.x * blockDim.x; + for (int i = threadIdx.x; i < end; i += blockDim.x) { + int prefix_sum = (i < batch_size) ? q_len[i + 1] : q_len[batch_size]; + auto count = __syncthreads_count(prefix_sum > qi); + if (count != 0) { + result = i / blockDim.x * blockDim.x + blockDim.x - count + 1; + break; + } + } + return result; +} + +__inline__ __device__ float compute_default_parameters(float base, float dim, int di, float factor) +{ + float scale_factor = -log2f(base) / dim; + float inv_freq = exp2f(di * scale_factor) * factor; + return inv_freq; +} + +__global__ void computeCosSinDefault(const float* rope_base, + int* q_len, + int* k_len, + int token_num, + int batch_size, + int dim, + float factor, + float* cos_sin) +{ + int qi = blockIdx.x; + int di = threadIdx.x; + + int bid = get_batch_id(qi, q_len, batch_size); + int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]); + float base = rope_base[bid - 1]; + float ti = history_len + qi - q_len[bid - 1]; + + float inv_freq = compute_default_parameters(base, dim, di * 2, factor); + float c, s; + sincosf(ti * inv_freq, &s, &c); + (float2&)cos_sin[dim * qi + 2 * di] = {c, s}; +} + +__global__ void computeCosSinLlama3(const float* rope_base, + int* q_len, + int* k_len, + int token_num, + int batch_size, + int dim, + float inv_scaling_factor, + float alpha, + float beta, + float* cos_sin) +{ + int qi = blockIdx.x; + int di = threadIdx.x; + + int bid = get_batch_id(qi, q_len, batch_size); + int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]); + float base = rope_base[bid - 1]; + float ti = history_len + qi - q_len[bid - 1]; + + float inv_freq = compute_default_parameters(base, dim, di * 2, 1.0f); + auto smooth = fmaxf(0.f, fminf(1.f, alpha * inv_freq - beta)); + inv_freq = (1 - smooth) * inv_freq * inv_scaling_factor + smooth * inv_freq; + float c, s; + sincosf(ti * inv_freq, &s, &c); + (float2&)cos_sin[dim * qi + 2 * di] = {c, s}; +} + +__global__ void computeCosSinYarn(const float* rope_base, + int* q_len, + int* k_len, + int token_num, + int batch_size, + int dim, + float ramp_inv_factor_div_2, + float ramp_inv_factor_mul_min, + float inv_scaling_factor, + float attention_scaling, + float* cos_sin) +{ + int qi = blockIdx.x; + int di = threadIdx.x; + + int bid = get_batch_id(qi, q_len, batch_size); + int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]); + float base = rope_base[bid - 1]; + float ti = history_len + qi - q_len[bid - 1]; + + float inv_freq = compute_default_parameters(base, dim, di * 2, 1.0f); + float alpha = 2 * di * ramp_inv_factor_div_2 - ramp_inv_factor_mul_min; + alpha = fmaxf(0.f, fminf(1.f, alpha)); + inv_freq = inv_freq - inv_freq * alpha * inv_scaling_factor; + + float c, s; + sincosf(ti * inv_freq, &s, &c); + c *= attention_scaling; + s *= attention_scaling; + (float2&)cos_sin[dim * qi + 2 * di] = {c, s}; +} + +RopeType GetRoPEType(const std::string& type) +{ + std::map lookup = {{"", RopeType::kDefault}, + {"linear", RopeType::kLinear}, + {"dynamic", RopeType::kDynamic}, + {"yarn", RopeType::kYarn}, + {"llama3", RopeType::kLlama3}}; + return lookup.at(type); +} + +void RotaryEmbeddingV2::freeBuffer() +{ + 
allocator_->free((void**)&cos_sin_); +} + +void RotaryEmbeddingV2::allocateBuffer(size_t token_num) +{ + cos_sin_ = (float*)allocator_->reMalloc(cos_sin_, sizeof(float) * token_num * dim_); +} + +RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator): + stream_(stream), allocator_(allocator) +{ + type_ = param.rope.type; + dim_ = param.rope.dim; + + switch (type_) { + case RopeType::kDefault: + break; + case RopeType::kLinear: + inv_factor_ = 1.0f / param.rope.factor; + break; + case RopeType::kDynamic: + inv_factor_ = param.rope.factor; + break; + case RopeType::kYarn: { + const double PI = 3.14159265358979323846; + auto find_correction_dim = [&](float num_rotations) { + return (param.rope.dim * std::log(param.rope.max_position_embeddings / (num_rotations * 2 * PI))) + / (2 * std::log(param.rope.base)); + }; + auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) { + low = std::floor(find_correction_dim(low_rot)); + high = std::ceil(find_correction_dim(high_rot)); + low = std::max(low, 0.f); + high = std::min(high, param.rope.dim - 1.f); + }; + float low, high; + find_correction_range(param.rope.yarn.beta_fast, param.rope.yarn.beta_slow, low, high); + if (low == high) { + high += 0.01f; + } + yarn_.ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0; + yarn_.ramp_inv_factor_mul_min = 1.0 / (high - low) * low; + yarn_.inv_scaling_factor = (1 - 1.0 / param.rope.factor); + yarn_.attention_factor = param.rope.yarn.attention_factor; + break; + } + case RopeType::kLlama3: { + const double PI = 3.14159265358979323846; + float inv_diff_freq_factor = 1.0 / (param.rope.llama3.high_freq_factor - param.rope.llama3.low_freq_factor); + llama3_.inv_scaling_factor = 1.0 / param.rope.factor; + llama3_.alpha = param.rope.llama3.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor; + llama3_.beta = param.rope.llama3.low_freq_factor * inv_diff_freq_factor; + break; + } + default: + FT_CHECK(0); + break; + } +} + +void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Param& params) +{ + allocateBuffer(params.token_num); + + const int grid = params.token_num; + const int block = dim_ / 2; + + switch (type_) { + case RopeType::kDefault: + case RopeType::kLinear: + case RopeType::kDynamic: + computeCosSinDefault<<>>(params.rope_theta, + params.q_len, + params.k_ken, + params.token_num, + params.batch_size, + dim_, + inv_factor_, + cos_sin_); + break; + case RopeType::kLlama3: + computeCosSinLlama3<<>>(params.rope_theta, + params.q_len, + params.k_ken, + params.token_num, + params.batch_size, + dim_, + llama3_.inv_scaling_factor, + llama3_.alpha, + llama3_.beta, + cos_sin_); + break; + case RopeType::kYarn: + computeCosSinYarn<<>>(params.rope_theta, + params.q_len, + params.k_ken, + params.token_num, + params.batch_size, + dim_, + yarn_.ramp_inv_factor_div_2, + yarn_.ramp_inv_factor_mul_min, + yarn_.inv_scaling_factor, + yarn_.attention_factor, + cos_sin_); + break; + default: + FT_CHECK(0); + } +} + +} // namespace turbomind diff --git a/src/turbomind/models/llama/rotary_emb.h b/src/turbomind/models/llama/rotary_emb.h new file mode 100644 index 000000000..db250402a --- /dev/null +++ b/src/turbomind/models/llama/rotary_emb.h @@ -0,0 +1,62 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+#pragma once +#include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/utils/allocator.h" + +namespace turbomind { + +RopeType GetRoPEType(const std::string& type); + +struct RotaryEmbeddingV2Param { + float* rope_theta; + int* q_len; + int* k_ken; + int batch_size; + int token_num; +}; + +struct InnerYarnRopeParam { + float attention_factor; + float ramp_inv_factor_div_2; + float ramp_inv_factor_mul_min; + float inv_scaling_factor; +}; + +struct InnerLlama3RopeParam { + float inv_scaling_factor; + float alpha; + float beta; +}; + +struct RotaryEmbeddingV2 { + + RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator); + + void freeBuffer(); + + void allocateBuffer(size_t token_num); + + ~RotaryEmbeddingV2() + { + freeBuffer(); + } + + void forward(const RotaryEmbeddingV2Param& params); + + cudaStream_t const stream_; + IAllocator* const allocator_; + + int dim_; + RopeType type_; + float inv_factor_{1.0}; + + union { + InnerYarnRopeParam yarn_; + InnerLlama3RopeParam llama3_; + }; + + // output + float* cos_sin_; // num_token x dim, (cos, sin, ...) +}; + +}; // namespace turbomind diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 7a6eddc4b..77d53afd5 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -187,6 +187,8 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa bool* is_finished = inputs->getPtr("finished"); float* rope_theta = inputs->getPtr("rope_theta"); + float* cos_sin = inputs->at("cos_sin", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr(); + void** block_ptrs = outputs->getPtr("block_ptrs"); int* cu_block_count = inputs->getPtr("cu_block_counts"); @@ -310,53 +312,11 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa params.inv_sqrt_dh /= std::sqrt((float)params.size_per_head); } - params.rotary_embedding_dim = param_.rotary_embedding_dim; - params.rotary_embedding_base = param_.rotary_embedding_base; - params.max_position_embeddings = param_.max_position_embeddings; - params.rope_scaling_factor = param_.rope_scaling_factor; - params.attention_scaling = 1.0; - params.rope_ti_scale = 1.f; - if (param_.rope_scaling_type == "linear") { - params.rope_ti_scale /= param_.rope_scaling_factor; - } - if (param_.rope_scaling_type == "llama3") { - const double PI = 3.14159265358979323846; - float inv_diff_freq_factor = 1.0 / (param_.high_freq_factor - param_.low_freq_factor); - params.llama3_inv_scaling_factor = 1.0 / param_.rope_scaling_factor; - params.llama3_alpha = param_.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor; - params.llama3_beta = param_.low_freq_factor * inv_diff_freq_factor; - } - if (param_.rope_scaling_type == "yarn") { - const double PI = 3.14159265358979323846; - auto find_correction_dim = [&](float num_rotations) { - return (param_.rotary_embedding_dim - * std::log(param_.max_position_embeddings / (num_rotations * 2 * PI))) - / (2 * std::log(param_.rotary_embedding_base)); - }; - auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) { - low = std::floor(find_correction_dim(low_rot)); - high = std::ceil(find_correction_dim(high_rot)); - low = std::max(low, 0.f); - high = std::min(high, param_.rotary_embedding_dim - 1.f); - }; - float low, high; - find_correction_range(param_.beta_fast, param_.beta_slow, low, high); - // 
https://github.com/huggingface/transformers/blob/6c3f168b36882f0beebaa9121eafa1928ba29633/src/transformers/modeling_rope_utils.py#L216 - if (low == high) { - high += 0.001f; - } - params.yarn_ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0; - params.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low; - params.yarn_inv_scaling_factor = (1 - 1.0 / param_.rope_scaling_factor); - if (param_.attention_factor < 0) { - params.attention_scaling = 0.1 * std::log(param_.rope_scaling_factor) + 1.0; - } - else { - params.attention_scaling = param_.attention_factor; - } - } - - params.use_logn_attn = param_.use_logn_attn; + // rope + params.rotary_embedding_dim = param_.rope.dim; + params.max_position_embeddings = param_.rope.max_position_embeddings; + params.cos_sin = cos_sin; + params.use_logn_attn = param_.use_logn_attn; // Decoding use only for now FT_CHECK(barriers_); diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index ec0e75b7e..792c4d3b5 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -42,6 +42,8 @@ UnifiedDecoder::UnifiedDecoder(const ModelParam& model, ffn_layer_ = std::make_unique>(model, tp, ctx); } + rotary_emb_ = std::make_unique(attn, ctx.stream, ctx.allocator.get()); + check_cuda_error(cudaEventCreateWithFlags(&ev_h_cu_x_, cudaEventDisableTiming)); } @@ -86,6 +88,7 @@ void UnifiedDecoder::forwardSelfAttn(T* attn_io, inputs.insert("cu_k_len", {MEMORY_GPU, TYPE_INT32, {batch_size + 1}, cu_k_len_}); inputs.insert("h_cu_q_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_q_len_}); inputs.insert("h_cu_k_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_k_len_}); + inputs.insert("cos_sin", {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_}); TensorMap outputs(*_outputs); outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io}); @@ -160,6 +163,16 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con count_and_fix(decoder_output, token_num * hidden_units_, Concat("norm0", 0), 2); + { + RotaryEmbeddingV2Param params; + params.rope_theta = inputs->getPtr("rope_theta"); + params.q_len = cu_q_len_; + params.k_ken = cu_k_len_; + params.batch_size = batch_size; + params.token_num = token_num; + rotary_emb_->forward(params); + } + for (size_t layer = 0; layer < layer_num_; ++layer) { /// TODO: do not skip the layers when they are heterogeneous diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index e08567136..48bac529f 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -5,6 +5,7 @@ #include "src/turbomind/models/llama/context.h" #include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/models/llama/moe_ffn_layer.h" +#include "src/turbomind/models/llama/rotary_emb.h" #include "src/turbomind/models/llama/unified_attention_layer.h" #include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cuda_utils.h" @@ -36,6 +37,7 @@ class UnifiedDecoder { std::unique_ptr> attn_layer_; std::unique_ptr> ffn_layer_; std::unique_ptr> moe_ffn_layer_; + std::unique_ptr rotary_emb_; cudaEvent_t ev_h_cu_x_{}; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 40c5ac890..f95ee8410 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ 
b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -140,10 +140,10 @@ void LlamaTritonModel::handleMissingParams() (int)model_param_.vocab_size); } - if (!attn_param_.max_position_embeddings) { - attn_param_.max_position_embeddings = 2048; + if (!attn_param_.rope.max_position_embeddings) { + attn_param_.rope.max_position_embeddings = 2048; TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to %d.", - (int)attn_param_.max_position_embeddings); + (int)attn_param_.rope.max_position_embeddings); } if (!engine_param_.max_batch_size) { @@ -153,7 +153,7 @@ void LlamaTritonModel::handleMissingParams() } if (!engine_param_.session_len) { - engine_param_.session_len = attn_param_.max_position_embeddings; + engine_param_.session_len = attn_param_.rope.max_position_embeddings; TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)engine_param_.session_len); } @@ -277,22 +277,25 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, model_param_.attn_bias = model_reader["attn_bias"].as(0); model_param_.group_size = model_reader["group_size"].as(0); + attn_param_.softmax_scale = attention_reader["softmax_scale"].as(0); + attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as(0); // rotary embedding parameters - attn_param_.rotary_embedding_dim = attention_reader["rotary_embedding"].as(); - attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as(10000.0f); - attn_param_.softmax_scale = attention_reader["softmax_scale"].as(0); - attn_param_.attention_factor = attention_reader["attention_factor"].as(-1.f); - attn_param_.beta_fast = attention_reader["beta_fast"].as(32.f); - attn_param_.beta_slow = attention_reader["beta_slow"].as(1.f); - attn_param_.rope_scaling_type = attention_reader["rope_scaling_type"].as(""); - attn_param_.rope_scaling_factor = attention_reader["rope_scaling_factor"].as(0.f); - attn_param_.low_freq_factor = attention_reader["low_freq_factor"].as(1.0); - attn_param_.high_freq_factor = attention_reader["high_freq_factor"].as(1.0); - attn_param_.max_position_embeddings = attention_reader["max_position_embeddings"].as(0); - attn_param_.use_dynamic_ntk = attention_reader["use_dynamic_ntk"].as(0); - attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as(0); - - attn_param_.original_max_position_embeddings = attention_reader["original_max_position_embeddings"].as(0); + attn_param_.rope.type = GetRoPEType(attention_reader["rope_scaling_type"].as("")); + attn_param_.rope.dim = attention_reader["rotary_embedding"].as(); + attn_param_.rope.base = attention_reader["rope_theta"].as(10000.0f); + attn_param_.rope.max_position_embeddings = attention_reader["max_position_embeddings"].as(0); + attn_param_.rope.factor = attention_reader["rope_scaling_factor"].as(0.f); + if (attn_param_.rope.type == RopeType::kYarn) { + attn_param_.rope.yarn.attention_factor = attention_reader["attention_factor"].as(-1.f); + attn_param_.rope.yarn.beta_fast = attention_reader["beta_fast"].as(32.f); + attn_param_.rope.yarn.beta_slow = attention_reader["beta_slow"].as(1.f); + } + else if (attn_param_.rope.type == RopeType::kLlama3) { + attn_param_.rope.llama3.low_freq_factor = attention_reader["low_freq_factor"].as(1.0); + attn_param_.rope.llama3.high_freq_factor = attention_reader["high_freq_factor"].as(1.0); + attn_param_.rope.llama3.original_max_position_embeddings = + attention_reader["original_max_position_embeddings"].as(0); + } engine_param_.max_batch_size = engine_reader["max_batch_size"].as(0); 
engine_param_.max_prefill_token_num = engine_reader["max_prefill_token_num"].as<int>(0);
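
The new PrecomputeFastRoPE::apply consumes one interleaved (cos, sin) pair per two channels and performs the usual complex rotation on (x[i], x[i+1]). A minimal host-side reference of that rotation, assuming the cos_sin layout written by rotary_emb.cu; the helper name rotate_reference is illustrative and not part of the patch:

#include <cstddef>
#include <vector>

// Reference of the pairwise rotation in PrecomputeFastRoPE::apply.
// `x` holds the channels of one q/k head for one token; `cos_sin` holds the
// matching floats written by rotary_emb.cu, interleaved as
// (cos_0, sin_0, cos_1, sin_1, ...).
void rotate_reference(std::vector<float>& x, const std::vector<float>& cos_sin)
{
    for (std::size_t i = 0; i + 1 < x.size(); i += 2) {
        const float c  = cos_sin[i];
        const float s  = cos_sin[i + 1];
        const float x0 = x[i];
        const float x1 = x[i + 1];
        x[i]     = c * x0 - s * x1;  // x0*cos - x1*sin
        x[i + 1] = c * x1 + s * x0;  // x1*cos + x0*sin
    }
}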
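
computeCosSinDefault fills cos_sin with dim floats per token: pair j stores cos(ti * f_j) and sin(ti * f_j), where f_j = base^(-2j/dim) (exp2f of -2j*log2(base)/dim in the kernel) and ti is the token's position in its sequence including history. A host-side sketch of the same table entry, where `factor` mirrors inv_factor_ from the RotaryEmbeddingV2 constructor (1.0 by default, 1/rope.factor for kLinear); cos_sin_for_token is an illustrative name, not part of the patch:

#include <cmath>
#include <vector>

// Reference of what computeCosSinDefault stores for one token at position `ti`.
std::vector<float> cos_sin_for_token(int ti, int dim, float base, float factor)
{
    std::vector<float> cs(dim);
    for (int j = 0; j < dim / 2; ++j) {
        const float inv_freq = std::pow(base, -2.f * j / dim) * factor;
        cs[2 * j]     = std::cos(ti * inv_freq);
        cs[2 * j + 1] = std::sin(ti * inv_freq);
    }
    return cs;
}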
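
get_batch_id in rotary_emb.cu maps a flattened token index qi to its sequence by scanning the cumulative query lengths with __syncthreads_count; the returned index is 1-based, which is why the kernels read q_len[bid - 1], k_len[bid - 1] and rope_base[bid - 1]. A scalar host equivalent, assuming cu_q_len[0] == 0; batch_id_reference is an illustrative name only:

// 1-based index b such that cu_q_len[b - 1] <= qi < cu_q_len[b].
int batch_id_reference(int qi, const int* cu_q_len, int batch_size)
{
    for (int b = 1; b <= batch_size; ++b) {
        if (qi < cu_q_len[b]) {
            return b;
        }
    }
    return batch_size;  // qi beyond the last token; not expected in practice
}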