From 05d011c0733149452a1e75b3c568921b93a74598 Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 25 Nov 2024 02:41:04 +0000 Subject: [PATCH 1/5] use precomputed cos sin --- .../kernels/attention/attention_params.h | 1 + .../kernels/attention/attention_universal.h | 29 +++ .../kernels/attention/kv_cache_utils_v2.cu | 34 +++ .../kernels/attention/kv_cache_utils_v2.h | 2 + .../kernels/attention/rotary_embedding.h | 15 ++ .../kernels/attention/test_attention.cu | 1 + src/turbomind/models/llama/CMakeLists.txt | 1 + src/turbomind/models/llama/rotary_emb.cu | 226 ++++++++++++++++++ src/turbomind/models/llama/rotary_emb.h | 65 +++++ .../models/llama/unified_attention_layer.cc | 4 + src/turbomind/models/llama/unified_decoder.cc | 16 ++ src/turbomind/models/llama/unified_decoder.h | 2 + 12 files changed, 396 insertions(+) create mode 100644 src/turbomind/models/llama/rotary_emb.cu create mode 100644 src/turbomind/models/llama/rotary_emb.h diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index b6dfaa596..f0f2e1af4 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -56,6 +56,7 @@ struct AttentionParams { int size_per_head; float inv_sqrt_dh; + float* cos_sin; // rotary embedding int rotary_embedding_dim; float rotary_embedding_base; diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 5fb583bd1..5fbeb7581 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -194,6 +194,8 @@ struct AttentionUniversal { Vec vec_K[1][ITER_C]; Vec vec_V[1][ITER_C]; + Array vec_cs[ITER_S][ITER_C]; // precomputed cos sin + const int2 offset = Map::get_offset(warp_id, lane_id); // Load Q @@ -217,12 +219,38 @@ struct AttentionUniversal { Ldg(vec_V[0][c], ¶ms.v[k_idx]); } } + if (params.cos_sin) { + float* cos_sin = params.cos_sin; + const int64_t index = qi * kHeadDim + di; + PRAGMA_UNROLL + for (int k = 0; k < kVecSize; k += 4) { + (float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]); + } + } } } } ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset); + if (params.cos_sin) { + PrecomputeFastRoPE rope{}; + PRAGMA_UNROLL + for (int c = 0; c < ITER_C; ++c) { + const int di = offset.x + c * Map::kDeltaC; + PRAGMA_UNROLL + for (int s = 0; s < ITER_S; ++s) { + rope.apply(vec_Q[s][c], vec_cs[s][c]); + if constexpr (kProcessKV) { + if (s == 0) { + rope.apply(vec_K[0][c], vec_cs[s][c]); + } + } + } + } + } + +#if 0 const float rope_base = params.rope_theta ? 
params.rope_theta[batch_idx] : params.rotary_embedding_base; PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { @@ -251,6 +279,7 @@ struct AttentionUniversal { } } } +#endif if (params.use_logn_attn) { PRAGMA_UNROLL diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index 20bb00fde..d552a6380 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -20,6 +20,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, const int* cu_q_len, const int* cu_k_len, const int* cu_block_num, + const float* cos_sin, const float* rope_base, int rope_dim, float rope_ti_scale, @@ -124,6 +125,35 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, } } + if (cos_sin) { + Array vec_cs[ITER_S][ITER_C]; + PRAGMA_UNROLL + for (int s = 0; s < ITER_S; ++s) { + const int qi = offset.y + s * Map::kDeltaS + token_idx; + PRAGMA_UNROLL + for (int c = 0; c < ITER_C; ++c) { + const int di = offset.x + c * Map::kDeltaC; + const int64_t index = (qi_beg + qi) * HeadDim + di; + PRAGMA_UNROLL + for (int k = 0; k < kVecSize; k += 4) { + if (qi < q_len) { + (float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]); + } + } + } + } + + PrecomputeFastRoPE rope; + PRAGMA_UNROLL + for (int s = 0; s < ITER_S; ++s) { + PRAGMA_UNROLL + for (int c = 0; c < ITER_C; ++c) { + rope.apply(vec_K[s][c], vec_cs[s][c]); + } + } + } + +#if 0 if (rope_base) { float base = rope_base[batch_idx]; PRAGMA_UNROLL @@ -149,6 +179,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, } } } +#endif Array param_K[ITER_S]; Array param_V[ITER_S]; @@ -211,6 +242,7 @@ void invokeProcessKV_v2(char** blocks, const int* cu_q_len, const int* cu_k_len, const int* cu_block_num, + const float* cos_sin, const float* rope_base, int rope_dim, float rope_ti_scale, @@ -257,6 +289,7 @@ void invokeProcessKV_v2(char** blocks, cu_q_len, cu_k_len, cu_block_num, + cos_sin, rope_base, rope_dim, rope_ti_scale, @@ -306,6 +339,7 @@ void invokeProcessKV_v2(char** blocks, const int* cu_q_len, \ const int* cu_k_len, \ const int* cu_block_num, \ + const float* cos_sin, \ const float* rope_base, \ int rope_dim, \ float rope_ti_scale, \ diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h index fe45ad7be..408310ba9 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h @@ -16,6 +16,7 @@ void invokeProcessKV_v2(char** blocks, const int* cu_q_len, const int* cu_k_len, const int* cu_block_num, + const float* cos_sin, const float* rope_base, int rope_dim, float rope_ti_scale, @@ -51,6 +52,7 @@ void invokeProcessKV_v2_(const AttentionParams& params) params.cu_q_len, params.cu_k_len, params.block_iter_params.cu_block_nums, + params.cos_sin, params.rope_theta, params.rotary_embedding_dim, params.rope_ti_scale, diff --git a/src/turbomind/kernels/attention/rotary_embedding.h b/src/turbomind/kernels/attention/rotary_embedding.h index 8e09da22c..849027508 100644 --- a/src/turbomind/kernels/attention/rotary_embedding.h +++ b/src/turbomind/kernels/attention/rotary_embedding.h @@ -67,6 +67,21 @@ __device__ void ApplyRotaryEmbedding(Array& x, float base, int dims, int t } } +struct PrecomputeFastRoPE { + + template + __device__ void apply(Array& x, Array& cs) + { + PRAGMA_UNROLL + for (int i = 0; i < N; i += 2) { + float tmp0 = cs[i] * (float)x[i] - cs[i + 1] * (float)x[i + 
1]; + float tmp1 = cs[i] * (float)x[i + 1] + cs[i + 1] * (float)x[i]; + x[i] = (T)tmp0; + x[i + 1] = (T)tmp1; + } + } +}; + template struct FastRoPE { diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index c6d7b4063..1d8e51103 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ b/src/turbomind/kernels/attention/test_attention.cu @@ -147,6 +147,7 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, cu_seq_lens.data().get(), cu_block_cnts.data().get(), nullptr, + nullptr, rope_dim, 1., 0., diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index 285fcea31..10d3d36b0 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -20,6 +20,7 @@ add_library(Llama STATIC unified_attention_layer.cc llama_kernels.cu llama_decoder_kernels.cu + rotary_emb.cu llama_utils.cu) set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu new file mode 100644 index 000000000..2ecec40a7 --- /dev/null +++ b/src/turbomind/models/llama/rotary_emb.cu @@ -0,0 +1,226 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include "src/turbomind/models/llama/rotary_emb.h" +#include + +namespace turbomind { + +__device__ int get_batch_id(int qi, int* q_len, int batch_size) +{ + int result{}; + int end = (batch_size + blockDim.x - 1) / blockDim.x * blockDim.x; + for (int i = threadIdx.x; i < end; i += blockDim.x) { + int prefix_sum = (i < batch_size) ? q_len[i + 1] : q_len[batch_size]; + auto count = __syncthreads_count(prefix_sum > qi); + if (count != 0) { + result = i / blockDim.x * blockDim.x + blockDim.x - count + 1; + break; + } + } + return result; +} + +__inline__ __device__ float compute_default_parameters(float base, float dim, int di, float factor) +{ + float scale_factor = -log2f(base) / dim; + float inv_freq = exp2f(di * scale_factor) * factor; + return inv_freq; +} + +__global__ void computeCosSinDefault(const float* rope_base, + int* q_len, + int* k_len, + int token_num, + int batch_size, + int dim, + float factor, + float* cos_sin) +{ + int qi = blockIdx.x; + int di = threadIdx.x; + + int bid = get_batch_id(qi, q_len, batch_size); + int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]); + float base = rope_base[bid - 1]; + float ti = history_len + qi - q_len[bid - 1]; + + float inv_freq = compute_default_parameters(base, dim, di * 2, factor); + float c, s; + sincosf(ti * inv_freq, &s, &c); + (float2&)cos_sin[dim * qi + 2 * di] = {c, s}; +} + +__global__ void computeCosSinLlama3(const float* rope_base, + int* q_len, + int* k_len, + int token_num, + int batch_size, + int dim, + float llama3_inv_scaling_factor, + float llama3_alpha, + float llama3_beta, + float* cos_sin) +{ + int qi = blockIdx.x; + int di = threadIdx.x; + + int bid = get_batch_id(qi, q_len, batch_size); + int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]); + float base = rope_base[bid - 1]; + float ti = history_len + qi - q_len[bid - 1]; + + float inv_freq = compute_default_parameters(base, dim, di * 2, 1.0f); + auto smooth = fmaxf(0.f, fminf(1.f, llama3_alpha * inv_freq - llama3_beta)); + inv_freq = (1 - smooth) * inv_freq * llama3_inv_scaling_factor + smooth * inv_freq; + float c, s; + sincosf(ti * inv_freq, &s, &c); + 
(float2&)cos_sin[dim * qi + 2 * di] = {c, s}; +} + +__global__ void computeCosSinYarn(const float* rope_base, + int* q_len, + int* k_len, + int token_num, + int batch_size, + int dim, + float yarn_ramp_inv_factor_div_2, + float yarn_ramp_inv_factor_mul_min, + float yarn_inv_scaling_factor, + float attention_scaling, + float* cos_sin) +{ + int qi = blockIdx.x; + int di = threadIdx.x; + + int bid = get_batch_id(qi, q_len, batch_size); + int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]); + float base = rope_base[bid - 1]; + float ti = history_len + qi - q_len[bid - 1]; + + float inv_freq = compute_default_parameters(base, dim, di * 2, 1.0f); + float alpha = 2 * di * yarn_ramp_inv_factor_div_2 - yarn_ramp_inv_factor_mul_min; + alpha = fmaxf(0.f, fminf(1.f, alpha)); + inv_freq = inv_freq - inv_freq * alpha * yarn_inv_scaling_factor; + + float c, s; + sincosf(ti * inv_freq, &s, &c); + c *= attention_scaling; + s *= attention_scaling; + (float2&)cos_sin[dim * qi + 2 * di] = {c, s}; +} + +RotaryScalingType GetRoPEType(const std::string& type) +{ + std::map lookup = {{"", RotaryScalingType::kDefault}, + {"linear", RotaryScalingType::kLinear}, + {"dynamic", RotaryScalingType::kDynamic}, + {"yarn", RotaryScalingType::kYarn}, + {"llama3", RotaryScalingType::kLlama3}, + {"mrope", RotaryScalingType::kMrope}}; + return lookup.at(type); +} + +void RotaryEmbeddingV2::freeBuffer() +{ + allocator_->free((void**)&cos_sin_); +} + +void RotaryEmbeddingV2::allocateBuffer(size_t token_num) +{ + cos_sin_ = (float*)allocator_->reMalloc(cos_sin_, sizeof(float) * token_num * dim_); +} + +RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator): + stream_(stream), allocator_(allocator) +{ + type_ = GetRoPEType(param.rope_scaling_type); + dim_ = param.rotary_embedding_dim; + rope_scaling_factor_ = 1.0f; + attention_factor_ = 1.0f; + + if (type_ == RotaryScalingType::kLinear) { + rope_scaling_factor_ /= param.rope_scaling_factor; + } + else if (type_ == RotaryScalingType::kLlama3) { + const double PI = 3.14159265358979323846; + float inv_diff_freq_factor = 1.0 / (param.high_freq_factor - param.low_freq_factor); + llama3_inv_scaling_factor_ = 1.0 / param.rope_scaling_factor; + llama3_alpha_ = param.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor; + llama3_beta_ = param.low_freq_factor * inv_diff_freq_factor; + } + else if (type_ == RotaryScalingType::kYarn) { + const double PI = 3.14159265358979323846; + auto find_correction_dim = [&](float num_rotations) { + return (param.rotary_embedding_dim * std::log(param.max_position_embeddings / (num_rotations * 2 * PI))) + / (2 * std::log(param.rotary_embedding_base)); + }; + auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) { + low = std::floor(find_correction_dim(low_rot)); + high = std::ceil(find_correction_dim(high_rot)); + low = std::max(low, 0.f); + high = std::min(high, param.rotary_embedding_dim - 1.f); + }; + float low, high; + find_correction_range(param.beta_fast, param.beta_slow, low, high); + if (low == high) { + high += 0.01f; + } + yarn_ramp_inv_factor_div_2_ = 1.0 / (high - low) / 2.0; + yarn_ramp_inv_factor_mul_min_ = 1.0 / (high - low) * low; + yarn_inv_scaling_factor_ = (1 - 1.0 / param.rope_scaling_factor); + attention_factor_ = param.attention_factor; + } +} + +void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params) +{ + allocateBuffer(params.token_num); + + const int grid = params.token_num; + const int 
block = dim_ / 2; + + switch (type_) { + case RotaryScalingType::kDefault: + case RotaryScalingType::kLinear: + case RotaryScalingType::kDynamic: + computeCosSinDefault<<>>(params.rope_theta, + params.q_len, + params.k_ken, + params.token_num, + params.batch_size, + dim_, + rope_scaling_factor_, + cos_sin_); + break; + case RotaryScalingType::kLlama3: + computeCosSinLlama3<<>>(params.rope_theta, + params.q_len, + params.k_ken, + params.token_num, + params.batch_size, + dim_, + llama3_inv_scaling_factor_, + llama3_alpha_, + llama3_beta_, + cos_sin_); + break; + case RotaryScalingType::kYarn: + computeCosSinYarn<<>>(params.rope_theta, + params.q_len, + params.k_ken, + params.token_num, + params.batch_size, + dim_, + yarn_ramp_inv_factor_div_2_, + yarn_ramp_inv_factor_mul_min_, + yarn_inv_scaling_factor_, + attention_factor_, + cos_sin_); + break; + case RotaryScalingType::kMrope: + FT_CHECK(0); + default: + FT_CHECK(0); + } +} + +} // namespace turbomind diff --git a/src/turbomind/models/llama/rotary_emb.h b/src/turbomind/models/llama/rotary_emb.h new file mode 100644 index 000000000..ffe81752e --- /dev/null +++ b/src/turbomind/models/llama/rotary_emb.h @@ -0,0 +1,65 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#pragma once +#include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/utils/allocator.h" + +namespace turbomind { + +enum class RotaryScalingType +{ + kDefault, + kLinear, + kDynamic, + kYarn, + kLlama3, + kMrope +}; + +struct RotaryEmbeddingV2Params { + float* rope_theta; + int* q_len; + int* k_ken; + int batch_size; + int token_num; +}; + +struct RotaryEmbeddingV2 { + + RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator); + + void freeBuffer(); + + void allocateBuffer(size_t token_num); + + ~RotaryEmbeddingV2() + { + freeBuffer(); + } + + void forward(const RotaryEmbeddingV2Params& params); + + RotaryScalingType type_; + cudaStream_t const stream_; + IAllocator* const allocator_; + + // output + float* cos_sin_; // num_token x dim, (cos, sin, ...) 
+ + int dim_; + // default, linear, dynamic + float attention_factor_; + float rope_scaling_factor_; + float inv_scale_factor_; + // llama3 + float llama3_inv_scaling_factor_; + float llama3_alpha_; + float llama3_beta_; + // yarn + float yarn_ramp_inv_factor_div_2_; + float yarn_ramp_inv_factor_mul_min_; + float yarn_inv_scaling_factor_; + // mrope + int3 mrope_section_; +}; + +}; // namespace turbomind diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 2f99b0c2c..c80c1cb6b 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -187,6 +187,8 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa bool* is_finished = inputs->getPtr("finished"); float* rope_theta = inputs->getPtr("rope_theta"); + float* cos_sin = inputs->at("cos_sin", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr(); + void** block_ptrs = outputs->getPtr("block_ptrs"); int* cu_block_count = inputs->getPtr("cu_block_counts"); @@ -338,6 +340,8 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa } } + params.cos_sin = cos_sin; + params.use_logn_attn = param_.use_logn_attn; // Decoding use only for now diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index 68392215f..cc3774387 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -29,6 +29,7 @@ UnifiedDecoder::UnifiedDecoder(const ModelParam& model, attn_layer_ = std::make_unique>(model, attn, lora, tp, ctx); ffn_layer_ = std::make_unique>(model, tp, ctx, true); moe_ffn_layer_ = std::make_unique>(model, moe, tp, ctx); + rotary_emb_ = std::make_unique(attn, ctx.stream, ctx.allocator.get()); check_cuda_error(cudaEventCreateWithFlags(&ev_h_cu_x_, cudaEventDisableTiming)); } @@ -75,6 +76,11 @@ void UnifiedDecoder::forwardSelfAttn(T* attn_io, inputs.insert("h_cu_q_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_q_len_}); inputs.insert("h_cu_k_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_k_len_}); + if (rotary_emb_) { + inputs.insert("cos_sin", + {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_}); + } + TensorMap outputs(*_outputs); outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io}); @@ -152,6 +158,16 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con count_and_fix(decoder_output, token_num * hidden_units_, Concat("norm0", 0), 2); + { + RotaryEmbeddingV2Params params; + params.rope_theta = inputs->getPtr("rope_theta"); + params.q_len = cu_q_len_; + params.k_ken = cu_k_len_; + params.batch_size = batch_size; + params.token_num = token_num; + rotary_emb_->forward(params); + } + for (size_t layer = 0; layer < layer_num_; ++layer) { /// TODO: do not skip the layers when they are heterogeneous diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index f13b4ba84..e18da41c3 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -5,6 +5,7 @@ #include "src/turbomind/models/llama/context.h" #include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/models/llama/moe_ffn_layer.h" +#include "src/turbomind/models/llama/rotary_emb.h" #include "src/turbomind/models/llama/unified_attention_layer.h" #include 
"src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cuda_utils.h" @@ -34,6 +35,7 @@ class UnifiedDecoder { std::unique_ptr> attn_layer_; std::unique_ptr> ffn_layer_; std::unique_ptr> moe_ffn_layer_; + std::unique_ptr rotary_emb_; cudaEvent_t ev_h_cu_x_{}; From 7b74b7235f6402bdaa18816065dd2d2ce1be260a Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 25 Nov 2024 03:27:04 +0000 Subject: [PATCH 2/5] remove unused --- .../kernels/attention/attention_params.h | 19 +- .../kernels/attention/attention_universal.h | 49 ++--- .../kernels/attention/kv_cache_utils_v2.cu | 167 ++---------------- .../kernels/attention/kv_cache_utils_v2.h | 42 ----- .../kernels/attention/test_attention.cu | 36 +--- .../models/llama/unified_attention_layer.cc | 49 +---- 6 files changed, 33 insertions(+), 329 deletions(-) diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index f0f2e1af4..7bc770ab6 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -56,23 +56,10 @@ struct AttentionParams { int size_per_head; float inv_sqrt_dh; - float* cos_sin; // rotary embedding - int rotary_embedding_dim; - float rotary_embedding_base; - float rope_scaling_factor; - float attention_scaling; - int max_position_embeddings; - float rope_ti_scale; // used for linear RoPE scaling - // the following 3 parameters are used by llama3 - float llama3_inv_scaling_factor; - float llama3_alpha; - float llama3_beta; - // the following are use by yarn - float yarn_ramp_inv_factor_div_2; - float yarn_ramp_inv_factor_mul_min; - float yarn_inv_scaling_factor; - + float* cos_sin; + int rotary_embedding_dim; + int max_position_embeddings; // log(n) attention bool use_logn_attn; diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 5fbeb7581..67b72f200 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -221,7 +221,7 @@ struct AttentionUniversal { } if (params.cos_sin) { float* cos_sin = params.cos_sin; - const int64_t index = qi * kHeadDim + di; + const int64_t index = qi * params.rotary_embedding_dim + di; PRAGMA_UNROLL for (int k = 0; k < kVecSize; k += 4) { (float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]); @@ -236,51 +236,22 @@ struct AttentionUniversal { if (params.cos_sin) { PrecomputeFastRoPE rope{}; PRAGMA_UNROLL - for (int c = 0; c < ITER_C; ++c) { - const int di = offset.x + c * Map::kDeltaC; + for (int s = 0; s < ITER_S; ++s) { PRAGMA_UNROLL - for (int s = 0; s < ITER_S; ++s) { - rope.apply(vec_Q[s][c], vec_cs[s][c]); - if constexpr (kProcessKV) { - if (s == 0) { - rope.apply(vec_K[0][c], vec_cs[s][c]); + for (int c = 0; c < ITER_C; ++c) { + const int di = offset.x + c * Map::kDeltaC; + if (di < params.rotary_embedding_dim) { + rope.apply(vec_Q[s][c], vec_cs[s][c]); + if constexpr (kProcessKV) { + if (s == 0) { + rope.apply(vec_K[0][c], vec_cs[s][c]); + } } } } } } -#if 0 - const float rope_base = params.rope_theta ? 
params.rope_theta[batch_idx] : params.rotary_embedding_base; - PRAGMA_UNROLL - for (int c = 0; c < ITER_C; ++c) { - const int di = offset.x + c * Map::kDeltaC; - FastRoPE rope(di, - params.rotary_embedding_dim, - rope_base, - params.rope_ti_scale, - params.rope_scaling_factor, - params.llama3_inv_scaling_factor, - params.llama3_alpha, - params.llama3_beta, - params.yarn_ramp_inv_factor_div_2, - params.yarn_ramp_inv_factor_mul_min, - params.yarn_inv_scaling_factor, - params.attention_scaling, - std::integral_constant{}); - PRAGMA_UNROLL - for (int s = 0; s < ITER_S; ++s) { - const int ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len; - rope.apply(vec_Q[s][c], ti); - if constexpr (kProcessKV) { - if (s == 0) { - rope.apply(vec_K[0][c], ti); - } - } - } - } -#endif - if (params.use_logn_attn) { PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index d552a6380..5f2a5ea6e 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -21,17 +21,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, const int* cu_k_len, const int* cu_block_num, const float* cos_sin, - const float* rope_base, int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -133,7 +123,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { const int di = offset.x + c * Map::kDeltaC; - const int64_t index = (qi_beg + qi) * HeadDim + di; + const int64_t index = (qi_beg + qi) * rope_dim + di; PRAGMA_UNROLL for (int k = 0; k < kVecSize; k += 4) { if (qi < q_len) { @@ -148,38 +138,13 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, for (int s = 0; s < ITER_S; ++s) { PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { - rope.apply(vec_K[s][c], vec_cs[s][c]); - } - } - } - -#if 0 - if (rope_base) { - float base = rope_base[batch_idx]; - PRAGMA_UNROLL - for (int c = 0; c < ITER_C; ++c) { - const int di = offset.x + c * Map::kDeltaC; - FastRoPE rope(di, - rope_dim, - base, - rope_ti_scale, - rope_scaling_factor, - llama3_inv_scaling_factor, - llama3_alpha, - llama3_beta, - yarn_ramp_inv_factor_div_2, - yarn_ramp_inv_factor_mul_min, - yarn_inv_scaling_factor, - attention_scaling, - std::integral_constant{}); - PRAGMA_UNROLL - for (int s = 0; s < ITER_S; ++s) { - const int ti = history_len + offset.y + s * Map::kDeltaS + token_idx; // sequence local - rope.apply(vec_K[s][c], ti); + const int di = offset.x + c * Map::kDeltaC; + if (di < rope_dim) { + rope.apply(vec_K[s][c], vec_cs[s][c]); + } } } } -#endif Array param_K[ITER_S]; Array param_V[ITER_S]; @@ -243,17 +208,7 @@ void invokeProcessKV_v2(char** blocks, const int* cu_k_len, const int* cu_block_num, const float* cos_sin, - const float* rope_base, int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_1_alpha, - float llama3_1_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -290,17 +245,7 @@ void 
invokeProcessKV_v2(char** blocks, cu_k_len, cu_block_num, cos_sin, - rope_base, rope_dim, - rope_ti_scale, - rope_scaling_factor, - llama3_inv_scaling_factor, - llama3_1_alpha, - llama3_1_beta, - yarn_ramp_inv_factor_div_2, - yarn_ramp_inv_factor_mul_min, - yarn_inv_scaling_factor, - attention_scaling, stride_b, stride_c, stride_h, @@ -340,17 +285,7 @@ void invokeProcessKV_v2(char** blocks, const int* cu_k_len, \ const int* cu_block_num, \ const float* cos_sin, \ - const float* rope_base, \ int rope_dim, \ - float rope_ti_scale, \ - float rope_scaling_factor, \ - float llama3_inv_scaling_factor, \ - float llama3_1_alpha, \ - float llama3_1_beta, \ - float yarn_ramp_inv_factor_div_2, \ - float yarn_ramp_inv_factor_mul_min, \ - float yarn_inv_scaling_factor, \ - float attention_scaling, \ int64_t stride_b, \ int64_t stride_c, \ int64_t stride_h, \ @@ -370,28 +305,17 @@ INSTANTIATE_invokeProcessKV_v2(nv_bfloat16); #endif template -__global__ void __launch_bounds__(128) flattenKV_v2(T* k, - T* v, - const Tkv** blocks, - const int* cu_k_len, - const int* cu_block_num, - const float* rope_base, - int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, - int64_t stride_b, - int64_t stride_c, - int64_t stride_h, - int64_t stride_s, - int layer_id, - BlockLayout block_layout) +__global__ void __launch_bounds__(128) flattenKV_v2(T* k, + T* v, + const Tkv** blocks, + const int* cu_k_len, + const int* cu_block_num, + int64_t stride_b, + int64_t stride_c, + int64_t stride_h, + int64_t stride_s, + int layer_id, + BlockLayout block_layout) { constexpr int kVecSize = sizeof(uint4) / sizeof(T); @@ -462,32 +386,6 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, } } - if (rope_base) { - float base = rope_base[batch_idx]; - PRAGMA_UNROLL - for (int c = 0; c < ITER_C; ++c) { - const int di = offset.x + c * Map::kDeltaC; - FastRoPE rope(di, - rope_dim, - base, - rope_ti_scale, - rope_scaling_factor, - llama3_inv_scaling_factor, - llama3_alpha, - llama3_beta, - yarn_ramp_inv_factor_div_2, - yarn_ramp_inv_factor_mul_min, - yarn_inv_scaling_factor, - attention_scaling, - std::integral_constant{}); - PRAGMA_UNROLL - for (int s = 0; s < ITER_S; ++s) { - const int ti = offset.y + s * Map::kDeltaS + token_idx; // sequence local - rope.apply(out_K[s][c], ti); - } - } - } - PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { PRAGMA_UNROLL @@ -510,17 +408,6 @@ void invokeFlattenKV_v2(T* k, char** blocks, const int* cu_k_len, const int* cu_block_num, - const float* rope_base, - int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -553,17 +440,6 @@ void invokeFlattenKV_v2(T* k, (const Tkv**)blocks, cu_k_len, cu_block_num, - rope_base, - rope_dim, - rope_ti_scale, - rope_scaling_factor, - llama3_inv_scaling_factor, - llama3_alpha, - llama3_beta, - yarn_ramp_inv_factor_div_2, - yarn_ramp_inv_factor_mul_min, - yarn_inv_scaling_factor, - attention_scaling, stride_b, stride_c, stride_h, @@ -599,17 +475,6 @@ void invokeFlattenKV_v2(T* k, char** blocks, \ const int* cu_k_len, \ const int* cu_block_num, \ - const float* 
rope_base, \ - int rope_dim, \ - float rope_ti_scale, \ - float rope_scaling_factor, \ - float llama3_inv_scaling_factor, \ - float llama3_alpha, \ - float llama3_beta, \ - float yarn_ramp_inv_factor_div_2, \ - float yarn_ramp_inv_factor_mul_min, \ - float yarn_inv_scaling_factor, \ - float attention_scaling, \ int64_t stride_b, \ int64_t stride_c, \ int64_t stride_h, \ diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h index 408310ba9..0f789fb58 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h @@ -17,17 +17,7 @@ void invokeProcessKV_v2(char** blocks, const int* cu_k_len, const int* cu_block_num, const float* cos_sin, - const float* rope_base, int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -53,17 +43,7 @@ void invokeProcessKV_v2_(const AttentionParams& params) params.cu_k_len, params.block_iter_params.cu_block_nums, params.cos_sin, - params.rope_theta, params.rotary_embedding_dim, - params.rope_ti_scale, - params.rope_scaling_factor, - params.llama3_inv_scaling_factor, - params.llama3_alpha, - params.llama3_beta, - params.yarn_ramp_inv_factor_div_2, - params.yarn_ramp_inv_factor_mul_min, - params.yarn_inv_scaling_factor, - params.attention_scaling, 0, // stride b params.stride / params.size_per_head, // stride c 1, // stride h @@ -84,17 +64,6 @@ void invokeFlattenKV_v2(T* k, char** blocks, const int* cu_k_len, const int* cu_block_num, - const float* rope_base, - int rope_dim, - float rope_ti_scale, - float rope_scaling_factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -118,17 +87,6 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len) (char**)params.block_iter_params.block_ptrs, params.cu_k_len, params.block_iter_params.cu_block_nums, - nullptr, // params.rope_theta, - params.rotary_embedding_dim, - params.rope_ti_scale, - params.rope_scaling_factor, - params.llama3_inv_scaling_factor, - params.llama3_alpha, - params.llama3_beta, - params.yarn_ramp_inv_factor_div_2, - params.yarn_ramp_inv_factor_mul_min, - params.yarn_inv_scaling_factor, - params.attention_scaling, 0, 1, 2 * sum_k_len, diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index 1d8e51103..48012a4c0 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ b/src/turbomind/kernels/attention/test_attention.cu @@ -147,17 +147,7 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, cu_seq_lens.data().get(), cu_block_cnts.data().get(), nullptr, - nullptr, rope_dim, - 1., - 0., - 0., - 1.0, - 1.0, - 0.0, - 0.0, - 0.0, - 1.0, 2 * head_num * seq_len, 0, seq_len, @@ -181,17 +171,6 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, k_ptrs.data().get(), cu_seq_lens.data().get(), cu_block_cnts.data().get(), - nullptr, - rope_dim, - 1., - 0., - 0., - 1.0, - 1.0, - 0.0, - 0.0, - 0.0, - 1.0, 2 * head_num * seq_len, 0, seq_len, @@ -436,9 +415,7 @@ int 
test_attention() params.size_per_head = kHeadDim; params.inv_sqrt_dh = (float)std::log2(expf(1.)) / std::sqrt((float)params.size_per_head); - params.rotary_embedding_dim = kRoPEDim; - params.rotary_embedding_base = kRoPEBase; - params.rope_ti_scale = 1.; + params.rotary_embedding_dim = kRoPEDim; params.split_cnt = split_cnt.data().get(); params.partial_L = partial_L.data().get(); @@ -545,17 +522,6 @@ int test_attention() k_ptrs.data().get(), cu_kv_lens.data().get(), cu_block_cnts.data().get(), - nullptr, // DECODING ? nullptr : params.rope_theta, - kRoPEDim, - 1., - 0., - 0., - 1.0, - 1.0, - 0.0, - 0.0, - 0.0, - 1.0, KvHeadNum * kContextLen, 0, kContextLen, diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index c80c1cb6b..9e3b70413 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -295,54 +295,11 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa // MSVC does not have M_LOG2E params.inv_sqrt_dh = (float)std::log2(expf(1.)) / std::sqrt((float)params.size_per_head); + // rope params.rotary_embedding_dim = param_.rotary_embedding_dim; - params.rotary_embedding_base = param_.rotary_embedding_base; params.max_position_embeddings = param_.max_position_embeddings; - params.rope_scaling_factor = param_.rope_scaling_factor; - params.attention_scaling = 1.0; - params.rope_ti_scale = 1.f; - if (param_.rope_scaling_type == "linear") { - params.rope_ti_scale /= param_.rope_scaling_factor; - } - if (param_.rope_scaling_type == "llama3") { - const double PI = 3.14159265358979323846; - float inv_diff_freq_factor = 1.0 / (param_.high_freq_factor - param_.low_freq_factor); - params.llama3_inv_scaling_factor = 1.0 / param_.rope_scaling_factor; - params.llama3_alpha = param_.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor; - params.llama3_beta = param_.low_freq_factor * inv_diff_freq_factor; - } - if (param_.rope_scaling_type == "yarn") { - const double PI = 3.14159265358979323846; - auto find_correction_dim = [&](float num_rotations) { - return (param_.rotary_embedding_dim - * std::log(param_.max_position_embeddings / (num_rotations * 2 * PI))) - / (2 * std::log(param_.rotary_embedding_base)); - }; - auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) { - low = std::floor(find_correction_dim(low_rot)); - high = std::ceil(find_correction_dim(high_rot)); - low = std::max(low, 0.f); - high = std::min(high, param_.rotary_embedding_dim - 1.f); - }; - float low, high; - find_correction_range(param_.beta_fast, param_.beta_slow, low, high); - if (low == high) { - high += 0.01f; - } - params.yarn_ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0; - params.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low; - params.yarn_inv_scaling_factor = (1 - 1.0 / param_.rope_scaling_factor); - if (param_.attention_factor < 0) { - params.attention_scaling = 0.1 * std::log(param_.rope_scaling_factor) + 1.0; - } - else { - params.attention_scaling = param_.attention_factor; - } - } - - params.cos_sin = cos_sin; - - params.use_logn_attn = param_.use_logn_attn; + params.cos_sin = cos_sin; + params.use_logn_attn = param_.use_logn_attn; // Decoding use only for now FT_CHECK(barriers_); From 0e4c315b1f4075835ad0ff23a91ae68173092a42 Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 2 Dec 2024 10:52:58 +0000 Subject: [PATCH 3/5] split rope params --- src/turbomind/models/llama/LlamaBatch.cc | 10 +- 
src/turbomind/models/llama/llama_params.h | 53 +++++++--- src/turbomind/models/llama/rotary_emb.cu | 99 ++++++++++--------- src/turbomind/models/llama/rotary_emb.h | 49 +++++---- .../models/llama/unified_attention_layer.cc | 4 +- src/turbomind/models/llama/unified_decoder.cc | 6 +- .../triton_backend/llama/LlamaTritonModel.cc | 41 ++++---- 7 files changed, 144 insertions(+), 118 deletions(-) diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index ea321d06a..9ea918787 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -368,15 +368,15 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests) // compute rope scaling factor if (r->start_flag) { - seq.rope_theta = model_->attn_param_.rotary_embedding_base; - if (model_->attn_param_.use_dynamic_ntk) { - auto scaling_factor = model_->attn_param_.rope_scaling_factor; + seq.rope_theta = model_->attn_param_.rope.base; + if (model_->attn_param_.rope.type == RotaryScalingType::kDynamic) { + auto scaling_factor = model_->attn_param_.rope.factor; if (scaling_factor >= 1.f) { // infer by current context length auto max_seq_len = state.h_context_length[idx]; - auto max_pos_emb = model_->attn_param_.max_position_embeddings; + auto max_pos_emb = model_->attn_param_.rope.max_position_embeddings; if (max_seq_len > max_pos_emb) { scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1); - float rope_dim = model_->attn_param_.rotary_embedding_dim; + float rope_dim = model_->attn_param_.rope.dim; seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f)); TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f", (long)seq.id, diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 0a505b11a..000ef82ef 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -59,22 +59,45 @@ struct MoeParam { std::vector expert_num; }; +enum class RotaryScalingType +{ + kDefault, + kLinear, + kDynamic, + kYarn, + kLlama3, +}; + +struct YarnRopeParam { + float attention_factor; + float beta_fast; + float beta_slow; +}; + +struct Llama3RopeParam { + float low_freq_factor; + float high_freq_factor; + int original_max_position_embeddings; +}; + struct AttentionParam { - int rotary_embedding_dim; - float rotary_embedding_base; - int max_position_embeddings; - float softmax_scale; - std::string rope_scaling_type; - int original_max_position_embeddings; - float rope_scaling_factor; - float low_freq_factor; - float high_freq_factor; - float attention_factor; - float beta_fast; - float beta_slow; - bool use_dynamic_ntk; - bool use_logn_attn; - int cache_block_seq_len; + float softmax_scale; + int cache_block_seq_len; + bool use_logn_attn; + // rope + struct { + // common + RotaryScalingType type; + int dim; + float base; + float factor; + int max_position_embeddings; + // special + union { + YarnRopeParam yarn; + Llama3RopeParam llama3; + }; + } rope; }; struct EngineParam { diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu index 2ecec40a7..a0e119062 100644 --- a/src/turbomind/models/llama/rotary_emb.cu +++ b/src/turbomind/models/llama/rotary_emb.cu @@ -114,8 +114,7 @@ RotaryScalingType GetRoPEType(const std::string& type) {"linear", RotaryScalingType::kLinear}, {"dynamic", RotaryScalingType::kDynamic}, {"yarn", RotaryScalingType::kYarn}, - {"llama3", RotaryScalingType::kLlama3}, 
- {"mrope", RotaryScalingType::kMrope}}; + {"llama3", RotaryScalingType::kLlama3}}; return lookup.at(type); } @@ -132,42 +131,52 @@ void RotaryEmbeddingV2::allocateBuffer(size_t token_num) RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator): stream_(stream), allocator_(allocator) { - type_ = GetRoPEType(param.rope_scaling_type); - dim_ = param.rotary_embedding_dim; - rope_scaling_factor_ = 1.0f; - attention_factor_ = 1.0f; + type_ = param.rope.type; + dim_ = param.rope.dim; - if (type_ == RotaryScalingType::kLinear) { - rope_scaling_factor_ /= param.rope_scaling_factor; - } - else if (type_ == RotaryScalingType::kLlama3) { - const double PI = 3.14159265358979323846; - float inv_diff_freq_factor = 1.0 / (param.high_freq_factor - param.low_freq_factor); - llama3_inv_scaling_factor_ = 1.0 / param.rope_scaling_factor; - llama3_alpha_ = param.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor; - llama3_beta_ = param.low_freq_factor * inv_diff_freq_factor; - } - else if (type_ == RotaryScalingType::kYarn) { - const double PI = 3.14159265358979323846; - auto find_correction_dim = [&](float num_rotations) { - return (param.rotary_embedding_dim * std::log(param.max_position_embeddings / (num_rotations * 2 * PI))) - / (2 * std::log(param.rotary_embedding_base)); - }; - auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) { - low = std::floor(find_correction_dim(low_rot)); - high = std::ceil(find_correction_dim(high_rot)); - low = std::max(low, 0.f); - high = std::min(high, param.rotary_embedding_dim - 1.f); - }; - float low, high; - find_correction_range(param.beta_fast, param.beta_slow, low, high); - if (low == high) { - high += 0.01f; + switch (type_) { + case RotaryScalingType::kDefault: + break; + case RotaryScalingType::kLinear: + inv_factor_ = 1.0f / param.rope.factor; + break; + case RotaryScalingType::kDynamic: + inv_factor_ = param.rope.factor; + break; + case RotaryScalingType::kYarn: { + const double PI = 3.14159265358979323846; + auto find_correction_dim = [&](float num_rotations) { + return (param.rope.dim * std::log(param.rope.max_position_embeddings / (num_rotations * 2 * PI))) + / (2 * std::log(param.rope.base)); + }; + auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) { + low = std::floor(find_correction_dim(low_rot)); + high = std::ceil(find_correction_dim(high_rot)); + low = std::max(low, 0.f); + high = std::min(high, param.rope.dim - 1.f); + }; + float low, high; + find_correction_range(param.rope.yarn.beta_fast, param.rope.yarn.beta_slow, low, high); + if (low == high) { + high += 0.01f; + } + yarn_.yarn_ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0; + yarn_.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low; + yarn_.yarn_inv_scaling_factor = (1 - 1.0 / param.rope.factor); + yarn_.attention_factor = param.rope.yarn.attention_factor; + break; + } + case RotaryScalingType::kLlama3: { + const double PI = 3.14159265358979323846; + float inv_diff_freq_factor = 1.0 / (param.rope.llama3.high_freq_factor - param.rope.llama3.low_freq_factor); + llama3_.llama3_inv_scaling_factor = 1.0 / param.rope.factor; + llama3_.llama3_alpha = param.rope.llama3.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor; + llama3_.llama3_beta = param.rope.llama3.low_freq_factor * inv_diff_freq_factor; + break; } - yarn_ramp_inv_factor_div_2_ = 1.0 / (high - low) / 2.0; - yarn_ramp_inv_factor_mul_min_ = 1.0 / (high - low) * low; - 
yarn_inv_scaling_factor_ = (1 - 1.0 / param.rope_scaling_factor); - attention_factor_ = param.attention_factor; + default: + FT_CHECK(0); + break; } } @@ -188,7 +197,7 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params) params.token_num, params.batch_size, dim_, - rope_scaling_factor_, + inv_factor_, cos_sin_); break; case RotaryScalingType::kLlama3: @@ -198,9 +207,9 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params) params.token_num, params.batch_size, dim_, - llama3_inv_scaling_factor_, - llama3_alpha_, - llama3_beta_, + llama3_.llama3_inv_scaling_factor, + llama3_.llama3_alpha, + llama3_.llama3_beta, cos_sin_); break; case RotaryScalingType::kYarn: @@ -210,14 +219,12 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params) params.token_num, params.batch_size, dim_, - yarn_ramp_inv_factor_div_2_, - yarn_ramp_inv_factor_mul_min_, - yarn_inv_scaling_factor_, - attention_factor_, + yarn_.yarn_ramp_inv_factor_div_2, + yarn_.yarn_ramp_inv_factor_mul_min, + yarn_.yarn_inv_scaling_factor, + yarn_.attention_factor, cos_sin_); break; - case RotaryScalingType::kMrope: - FT_CHECK(0); default: FT_CHECK(0); } diff --git a/src/turbomind/models/llama/rotary_emb.h b/src/turbomind/models/llama/rotary_emb.h index ffe81752e..66a830b01 100644 --- a/src/turbomind/models/llama/rotary_emb.h +++ b/src/turbomind/models/llama/rotary_emb.h @@ -5,15 +5,7 @@ namespace turbomind { -enum class RotaryScalingType -{ - kDefault, - kLinear, - kDynamic, - kYarn, - kLlama3, - kMrope -}; +RotaryScalingType GetRoPEType(const std::string& type); struct RotaryEmbeddingV2Params { float* rope_theta; @@ -23,6 +15,19 @@ struct RotaryEmbeddingV2Params { int token_num; }; +struct InnerYarnRopeParam { + float attention_factor; + float yarn_ramp_inv_factor_div_2; + float yarn_ramp_inv_factor_mul_min; + float yarn_inv_scaling_factor; +}; + +struct InnerLlama3RopeParam { + float llama3_inv_scaling_factor; + float llama3_alpha; + float llama3_beta; +}; + struct RotaryEmbeddingV2 { RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator); @@ -38,28 +43,20 @@ struct RotaryEmbeddingV2 { void forward(const RotaryEmbeddingV2Params& params); - RotaryScalingType type_; cudaStream_t const stream_; IAllocator* const allocator_; + int dim_; + RotaryScalingType type_; + float inv_factor_{1.0}; + + union { + InnerYarnRopeParam yarn_; + InnerLlama3RopeParam llama3_; + }; + // output float* cos_sin_; // num_token x dim, (cos, sin, ...) 
- - int dim_; - // default, linear, dynamic - float attention_factor_; - float rope_scaling_factor_; - float inv_scale_factor_; - // llama3 - float llama3_inv_scaling_factor_; - float llama3_alpha_; - float llama3_beta_; - // yarn - float yarn_ramp_inv_factor_div_2_; - float yarn_ramp_inv_factor_mul_min_; - float yarn_inv_scaling_factor_; - // mrope - int3 mrope_section_; }; }; // namespace turbomind diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 89224e853..77d53afd5 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -313,8 +313,8 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa } // rope - params.rotary_embedding_dim = param_.rotary_embedding_dim; - params.max_position_embeddings = param_.max_position_embeddings; + params.rotary_embedding_dim = param_.rope.dim; + params.max_position_embeddings = param_.rope.max_position_embeddings; params.cos_sin = cos_sin; params.use_logn_attn = param_.use_logn_attn; diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index e40d7af22..c37658fd3 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -88,11 +88,7 @@ void UnifiedDecoder::forwardSelfAttn(T* attn_io, inputs.insert("cu_k_len", {MEMORY_GPU, TYPE_INT32, {batch_size + 1}, cu_k_len_}); inputs.insert("h_cu_q_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_q_len_}); inputs.insert("h_cu_k_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_k_len_}); - - if (rotary_emb_) { - inputs.insert("cos_sin", - {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_}); - } + inputs.insert("cos_sin", {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_}); TensorMap outputs(*_outputs); outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io}); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 40c5ac890..f2fa583c9 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -140,10 +140,10 @@ void LlamaTritonModel::handleMissingParams() (int)model_param_.vocab_size); } - if (!attn_param_.max_position_embeddings) { - attn_param_.max_position_embeddings = 2048; + if (!attn_param_.rope.max_position_embeddings) { + attn_param_.rope.max_position_embeddings = 2048; TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to %d.", - (int)attn_param_.max_position_embeddings); + (int)attn_param_.rope.max_position_embeddings); } if (!engine_param_.max_batch_size) { @@ -153,7 +153,7 @@ void LlamaTritonModel::handleMissingParams() } if (!engine_param_.session_len) { - engine_param_.session_len = attn_param_.max_position_embeddings; + engine_param_.session_len = attn_param_.rope.max_position_embeddings; TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)engine_param_.session_len); } @@ -277,22 +277,25 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, model_param_.attn_bias = model_reader["attn_bias"].as(0); model_param_.group_size = model_reader["group_size"].as(0); + attn_param_.softmax_scale = attention_reader["softmax_scale"].as(0); + attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as(0); 
// rotary embedding parameters - attn_param_.rotary_embedding_dim = attention_reader["rotary_embedding"].as(); - attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as(10000.0f); - attn_param_.softmax_scale = attention_reader["softmax_scale"].as(0); - attn_param_.attention_factor = attention_reader["attention_factor"].as(-1.f); - attn_param_.beta_fast = attention_reader["beta_fast"].as(32.f); - attn_param_.beta_slow = attention_reader["beta_slow"].as(1.f); - attn_param_.rope_scaling_type = attention_reader["rope_scaling_type"].as(""); - attn_param_.rope_scaling_factor = attention_reader["rope_scaling_factor"].as(0.f); - attn_param_.low_freq_factor = attention_reader["low_freq_factor"].as(1.0); - attn_param_.high_freq_factor = attention_reader["high_freq_factor"].as(1.0); - attn_param_.max_position_embeddings = attention_reader["max_position_embeddings"].as(0); - attn_param_.use_dynamic_ntk = attention_reader["use_dynamic_ntk"].as(0); - attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as(0); - - attn_param_.original_max_position_embeddings = attention_reader["original_max_position_embeddings"].as(0); + attn_param_.rope.type = GetRoPEType(attention_reader["rope_scaling_type"].as("")); + attn_param_.rope.dim = attention_reader["rotary_embedding"].as(); + attn_param_.rope.base = attention_reader["rope_theta"].as(10000.0f); + attn_param_.rope.max_position_embeddings = attention_reader["max_position_embeddings"].as(0); + attn_param_.rope.factor = attention_reader["rope_scaling_factor"].as(0.f); + if (attn_param_.rope.type == RotaryScalingType::kYarn) { + attn_param_.rope.yarn.attention_factor = attention_reader["attention_factor"].as(-1.f); + attn_param_.rope.yarn.beta_fast = attention_reader["beta_fast"].as(32.f); + attn_param_.rope.yarn.beta_slow = attention_reader["beta_slow"].as(1.f); + } + else if (attn_param_.rope.type == RotaryScalingType::kLlama3) { + attn_param_.rope.llama3.low_freq_factor = attention_reader["low_freq_factor"].as(1.0); + attn_param_.rope.llama3.high_freq_factor = attention_reader["high_freq_factor"].as(1.0); + attn_param_.rope.llama3.original_max_position_embeddings = + attention_reader["original_max_position_embeddings"].as(0); + } engine_param_.max_batch_size = engine_reader["max_batch_size"].as(0); engine_param_.max_prefill_token_num = engine_reader["max_prefill_token_num"].as(0); From ea6112e400f7dd0fa1fb68af528a510ca5048a15 Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 2 Dec 2024 12:07:51 +0000 Subject: [PATCH 4/5] remove prefix yarn_, llama3_ --- .../kernels/attention/test_attention.cu | 4 +- src/turbomind/models/llama/LlamaBatch.cc | 2 +- src/turbomind/models/llama/llama_params.h | 12 +-- src/turbomind/models/llama/rotary_emb.cu | 78 +++++++++---------- src/turbomind/models/llama/rotary_emb.h | 20 ++--- .../triton_backend/llama/LlamaTritonModel.cc | 4 +- 6 files changed, 59 insertions(+), 61 deletions(-) diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index fc060cfec..16b7846c7 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ b/src/turbomind/kernels/attention/test_attention.cu @@ -429,9 +429,7 @@ int test_attention() params.qk = qk_buf.data().get(); params.pr = pr_buf.data().get(); - params.attention_scaling = 1.f; - params.llama3_inv_scaling_factor = 0; - params.yarn_ramp_inv_factor_div_2 = 0; + params.cos_sin = nullptr; Reference reference(kDump ? 
Reference::kUNFUSED : Reference::kFLASH_ATTENTION, {});
     // Reference reference(Reference::kUNFUSED, {});
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 9ea918787..9ef669f8d 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -369,7 +369,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         // compute rope scaling factor
         if (r->start_flag) {
             seq.rope_theta = model_->attn_param_.rope.base;
-            if (model_->attn_param_.rope.type == RotaryScalingType::kDynamic) {
+            if (model_->attn_param_.rope.type == RopeType::kDynamic) {
                 auto scaling_factor = model_->attn_param_.rope.factor;
                 if (scaling_factor >= 1.f) {  // infer by current context length
                     auto max_seq_len = state.h_context_length[idx];
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index 000ef82ef..ff0482eef 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -59,7 +59,7 @@ struct MoeParam {
     std::vector<int> expert_num;
 };
 
-enum class RotaryScalingType
+enum class RopeType
 {
     kDefault,
     kLinear,
@@ -87,11 +87,11 @@ struct AttentionParam {
     // rope
     struct {
         // common
-        RotaryScalingType type;
-        int               dim;
-        float             base;
-        float             factor;
-        int               max_position_embeddings;
+        RopeType type;
+        int      dim;
+        float    base;
+        float    factor;
+        int      max_position_embeddings;
         // special
         union {
             YarnRopeParam   yarn;
diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu
index a0e119062..72362bcb7 100644
--- a/src/turbomind/models/llama/rotary_emb.cu
+++ b/src/turbomind/models/llama/rotary_emb.cu
@@ -55,9 +55,9 @@ __global__ void computeCosSinLlama3(const float* rope_base,
                                     int          token_num,
                                     int          batch_size,
                                     int          dim,
-                                    float        llama3_inv_scaling_factor,
-                                    float        llama3_alpha,
-                                    float        llama3_beta,
+                                    float        inv_scaling_factor,
+                                    float        alpha,
+                                    float        beta,
                                     float*       cos_sin)
 {
     int qi = blockIdx.x;
@@ -69,8 +69,8 @@ __global__ void computeCosSinLlama3(const float* rope_base,
     float ti = history_len + qi - q_len[bid - 1];
 
     float inv_freq = compute_default_parameters(base, dim, di * 2, 1.0f);
-    auto  smooth   = fmaxf(0.f, fminf(1.f, llama3_alpha * inv_freq - llama3_beta));
-    inv_freq       = (1 - smooth) * inv_freq * llama3_inv_scaling_factor + smooth * inv_freq;
+    auto  smooth   = fmaxf(0.f, fminf(1.f, alpha * inv_freq - beta));
+    inv_freq       = (1 - smooth) * inv_freq * inv_scaling_factor + smooth * inv_freq;
     float c, s;
     sincosf(ti * inv_freq, &s, &c);
     (float2&)cos_sin[dim * qi + 2 * di] = {c, s};
@@ -82,9 +82,9 @@ __global__ void computeCosSinYarn(const float* rope_base,
                                   int          token_num,
                                   int          batch_size,
                                   int          dim,
-                                  float        yarn_ramp_inv_factor_div_2,
-                                  float        yarn_ramp_inv_factor_mul_min,
-                                  float        yarn_inv_scaling_factor,
+                                  float        ramp_inv_factor_div_2,
+                                  float        ramp_inv_factor_mul_min,
+                                  float        inv_scaling_factor,
                                   float        attention_scaling,
                                   float*       cos_sin)
 {
@@ -97,9 +97,9 @@ __global__ void computeCosSinYarn(const float* rope_base,
     float ti = history_len + qi - q_len[bid - 1];
 
     float inv_freq = compute_default_parameters(base, dim, di * 2, 1.0f);
-    float alpha    = 2 * di * yarn_ramp_inv_factor_div_2 - yarn_ramp_inv_factor_mul_min;
+    float alpha    = 2 * di * ramp_inv_factor_div_2 - ramp_inv_factor_mul_min;
     alpha          = fmaxf(0.f, fminf(1.f, alpha));
-    inv_freq       = inv_freq - inv_freq * alpha * yarn_inv_scaling_factor;
+    inv_freq       = inv_freq - inv_freq * alpha * inv_scaling_factor;
 
     float c, s;
     sincosf(ti * inv_freq, &s, &c);
@@ -108,13 +108,13 @@ __global__ void computeCosSinYarn(const float* rope_base,
     (float2&)cos_sin[dim * qi + 2 * di] = {c, s};
 }
 
-RotaryScalingType GetRoPEType(const std::string& type)
+RopeType GetRoPEType(const std::string& type)
 {
-    std::map<std::string, RotaryScalingType> lookup = {{"", RotaryScalingType::kDefault},
-                                                       {"linear", RotaryScalingType::kLinear},
-                                                       {"dynamic", RotaryScalingType::kDynamic},
-                                                       {"yarn", RotaryScalingType::kYarn},
-                                                       {"llama3", RotaryScalingType::kLlama3}};
+    std::map<std::string, RopeType> lookup = {{"", RopeType::kDefault},
+                                              {"linear", RopeType::kLinear},
+                                              {"dynamic", RopeType::kDynamic},
+                                              {"yarn", RopeType::kYarn},
+                                              {"llama3", RopeType::kLlama3}};
     return lookup.at(type);
 }
 
@@ -135,15 +135,15 @@ RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t s
     dim_ = param.rope.dim;
 
     switch (type_) {
-        case RotaryScalingType::kDefault:
+        case RopeType::kDefault:
             break;
-        case RotaryScalingType::kLinear:
+        case RopeType::kLinear:
             inv_factor_ = 1.0f / param.rope.factor;
             break;
-        case RotaryScalingType::kDynamic:
+        case RopeType::kDynamic:
             inv_factor_ = param.rope.factor;
             break;
-        case RotaryScalingType::kYarn: {
+        case RopeType::kYarn: {
             const double PI = 3.14159265358979323846;
             auto find_correction_dim = [&](float num_rotations) {
                 return (param.rope.dim * std::log(param.rope.max_position_embeddings / (num_rotations * 2 * PI)))
@@ -160,18 +160,18 @@ RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t s
             if (low == high) {
                 high += 0.01f;
             }
-            yarn_.yarn_ramp_inv_factor_div_2   = 1.0 / (high - low) / 2.0;
-            yarn_.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low;
-            yarn_.yarn_inv_scaling_factor      = (1 - 1.0 / param.rope.factor);
-            yarn_.attention_factor             = param.rope.yarn.attention_factor;
+            yarn_.ramp_inv_factor_div_2   = 1.0 / (high - low) / 2.0;
+            yarn_.ramp_inv_factor_mul_min = 1.0 / (high - low) * low;
+            yarn_.inv_scaling_factor      = (1 - 1.0 / param.rope.factor);
+            yarn_.attention_factor        = param.rope.yarn.attention_factor;
             break;
         }
-        case RotaryScalingType::kLlama3: {
+        case RopeType::kLlama3: {
             const double PI = 3.14159265358979323846;
             float inv_diff_freq_factor = 1.0 / (param.rope.llama3.high_freq_factor - param.rope.llama3.low_freq_factor);
-            llama3_.llama3_inv_scaling_factor = 1.0 / param.rope.factor;
-            llama3_.llama3_alpha = param.rope.llama3.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
-            llama3_.llama3_beta  = param.rope.llama3.low_freq_factor * inv_diff_freq_factor;
+            llama3_.inv_scaling_factor = 1.0 / param.rope.factor;
+            llama3_.alpha = param.rope.llama3.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
+            llama3_.beta  = param.rope.llama3.low_freq_factor * inv_diff_freq_factor;
             break;
         }
         default:
@@ -188,9 +188,9 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
     const int block = dim_ / 2;
 
     switch (type_) {
-        case RotaryScalingType::kDefault:
-        case RotaryScalingType::kLinear:
-        case RotaryScalingType::kDynamic:
+        case RopeType::kDefault:
+        case RopeType::kLinear:
+        case RopeType::kDynamic:
            computeCosSinDefault<<<grid, block, 0, stream_>>>(params.rope_theta,
                                                              params.q_len,
                                                              params.k_ken,
@@ -200,28 +200,28 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
                                                              inv_factor_,
                                                              cos_sin_);
            break;
-        case RotaryScalingType::kLlama3:
+        case RopeType::kLlama3:
            computeCosSinLlama3<<<grid, block, 0, stream_>>>(params.rope_theta,
                                                             params.q_len,
                                                             params.k_ken,
                                                             params.token_num,
                                                             params.batch_size,
                                                             dim_,
-                                                            llama3_.llama3_inv_scaling_factor,
-                                                            llama3_.llama3_alpha,
-                                                            llama3_.llama3_beta,
+                                                            llama3_.inv_scaling_factor,
+                                                            llama3_.alpha,
+                                                            llama3_.beta,
                                                             cos_sin_);
            break;
-        case RotaryScalingType::kYarn:
+        case RopeType::kYarn:
            computeCosSinYarn<<<grid, block, 0, stream_>>>(params.rope_theta,
                                                           params.q_len,
                                                           params.k_ken,
                                                           params.token_num,
                                                           params.batch_size,
                                                           dim_,
-                                                          yarn_.yarn_ramp_inv_factor_div_2,
-                                                          yarn_.yarn_ramp_inv_factor_mul_min,
-                                                          yarn_.yarn_inv_scaling_factor,
+                                                          yarn_.ramp_inv_factor_div_2,
+                                                          yarn_.ramp_inv_factor_mul_min,
+                                                          yarn_.inv_scaling_factor,
                                                           yarn_.attention_factor,
                                                           cos_sin_);
            break;
diff --git a/src/turbomind/models/llama/rotary_emb.h b/src/turbomind/models/llama/rotary_emb.h
index 66a830b01..43fa24714 100644
--- a/src/turbomind/models/llama/rotary_emb.h
+++ b/src/turbomind/models/llama/rotary_emb.h
@@ -5,7 +5,7 @@
 
 namespace turbomind {
 
-RotaryScalingType GetRoPEType(const std::string& type);
+RopeType GetRoPEType(const std::string& type);
 
 struct RotaryEmbeddingV2Params {
     float* rope_theta;
@@ -17,15 +17,15 @@ struct RotaryEmbeddingV2Params {
 
 struct InnerYarnRopeParam {
     float attention_factor;
-    float yarn_ramp_inv_factor_div_2;
-    float yarn_ramp_inv_factor_mul_min;
-    float yarn_inv_scaling_factor;
+    float ramp_inv_factor_div_2;
+    float ramp_inv_factor_mul_min;
+    float inv_scaling_factor;
 };
 
 struct InnerLlama3RopeParam {
-    float llama3_inv_scaling_factor;
-    float llama3_alpha;
-    float llama3_beta;
+    float inv_scaling_factor;
+    float alpha;
+    float beta;
 };
 
 struct RotaryEmbeddingV2 {
@@ -46,9 +46,9 @@ struct RotaryEmbeddingV2 {
     cudaStream_t const stream_;
     IAllocator* const  allocator_;
 
-    int               dim_;
-    RotaryScalingType type_;
-    float             inv_factor_{1.0};
+    int      dim_;
+    RopeType type_;
+    float    inv_factor_{1.0};
 
     union {
         InnerYarnRopeParam   yarn_;
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index f2fa583c9..f95ee8410 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -285,12 +285,12 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     attn_param_.rope.base                    = attention_reader["rope_theta"].as<float>(10000.0f);
     attn_param_.rope.max_position_embeddings = attention_reader["max_position_embeddings"].as<int>(0);
     attn_param_.rope.factor                  = attention_reader["rope_scaling_factor"].as<float>(0.f);
-    if (attn_param_.rope.type == RotaryScalingType::kYarn) {
+    if (attn_param_.rope.type == RopeType::kYarn) {
         attn_param_.rope.yarn.attention_factor = attention_reader["attention_factor"].as<float>(-1.f);
         attn_param_.rope.yarn.beta_fast        = attention_reader["beta_fast"].as<float>(32.f);
         attn_param_.rope.yarn.beta_slow        = attention_reader["beta_slow"].as<float>(1.f);
     }
-    else if (attn_param_.rope.type == RotaryScalingType::kLlama3) {
+    else if (attn_param_.rope.type == RopeType::kLlama3) {
         attn_param_.rope.llama3.low_freq_factor  = attention_reader["low_freq_factor"].as<float>(1.0);
         attn_param_.rope.llama3.high_freq_factor = attention_reader["high_freq_factor"].as<float>(1.0);
         attn_param_.rope.llama3.original_max_position_embeddings =

From 0513e12ba7d77128f296318bf53e3c0de336d33f Mon Sep 17 00:00:00 2001
From: irexyc
Date: Mon, 2 Dec 2024 13:06:24 +0000
Subject: [PATCH 5/5] fix test_attention

---
 .../kernels/attention/test_attention.cu       | 23 +++++++++++++++++--
 src/turbomind/models/llama/rotary_emb.cu      |  2 +-
 src/turbomind/models/llama/rotary_emb.h       |  4 ++--
 src/turbomind/models/llama/unified_decoder.cc |  2 +-
 4 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu
index 16b7846c7..788b5d489 100644
--- a/src/turbomind/kernels/attention/test_attention.cu
+++ b/src/turbomind/kernels/attention/test_attention.cu
@@ -7,6 +7,7 @@
 #include "src/turbomind/kernels/attention/attention_params.h"
 #include "src/turbomind/kernels/attention/reference.h"
 #include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/models/llama/rotary_emb.h"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "test_utils.h"
 #include
@@ -376,6 +377,26 @@ int test_attention()
         rope_base[i] = kRoPEBase;
     }
 
+    // precompute cos/sin
+    const int device_id = 0;
+    auto      allocator = std::make_unique<Allocator<AllocatorType::CUDA>>(device_id, false);
+    allocator->setStream(nullptr);
+    AttentionParam attn_param;
+    attn_param.rope.type   = RopeType::kDefault;
+    attn_param.rope.base   = kRoPEBase;
+    attn_param.rope.dim    = kRoPEDim;
+    attn_param.rope.factor = 1.0f;
+    auto rotary_emb = std::make_unique<RotaryEmbeddingV2>(attn_param, nullptr, allocator.get());
+
+    RotaryEmbeddingV2Param rotary_param;
+    rotary_param.rope_theta = rope_base.data().get();
+    rotary_param.q_len      = cu_seqlens.data().get();
+    rotary_param.k_ken      = cu_kv_lens.data().get();
+    rotary_param.batch_size = kBatchSize;
+    rotary_param.token_num  = kTokenNum;
+    rotary_emb->forward(rotary_param);
+    params.cos_sin = rotary_emb->cos_sin_;
+
     // getchar();
 
     params.out = output_ref.data().get();
@@ -429,8 +450,6 @@ int test_attention()
     params.qk = qk_buf.data().get();
     params.pr = pr_buf.data().get();
 
-    params.cos_sin = nullptr;
-
     Reference reference(kDump ? Reference::kUNFUSED : Reference::kFLASH_ATTENTION, {});
     // Reference reference(Reference::kUNFUSED, {});
     reference.Reshape(kInputLen, kContextLen, kHeadNum, kHeadDim, KvHeadNum, kBatchSize);
diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu
index 72362bcb7..efb98bf85 100644
--- a/src/turbomind/models/llama/rotary_emb.cu
+++ b/src/turbomind/models/llama/rotary_emb.cu
@@ -180,7 +180,7 @@ RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t s
     }
 }
 
-void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
+void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Param& params)
 {
     allocateBuffer(params.token_num);
 
diff --git a/src/turbomind/models/llama/rotary_emb.h b/src/turbomind/models/llama/rotary_emb.h
index 43fa24714..db250402a 100644
--- a/src/turbomind/models/llama/rotary_emb.h
+++ b/src/turbomind/models/llama/rotary_emb.h
@@ -7,7 +7,7 @@ namespace turbomind {
 
 RopeType GetRoPEType(const std::string& type);
 
-struct RotaryEmbeddingV2Params {
+struct RotaryEmbeddingV2Param {
     float* rope_theta;
     int*   q_len;
     int*   k_ken;
@@ -41,7 +41,7 @@ struct RotaryEmbeddingV2 {
         freeBuffer();
     }
 
-    void forward(const RotaryEmbeddingV2Params& params);
+    void forward(const RotaryEmbeddingV2Param& params);
 
     cudaStream_t const stream_;
     IAllocator* const  allocator_;
diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc
index c37658fd3..792c4d3b5 100644
--- a/src/turbomind/models/llama/unified_decoder.cc
+++ b/src/turbomind/models/llama/unified_decoder.cc
@@ -164,7 +164,7 @@ void UnifiedDecoder<T>::forward(TensorMap* outputs, const TensorMap* inputs, con
     count_and_fix(decoder_output, token_num * hidden_units_, Concat("norm0", 0), 2);
 
     {
-        RotaryEmbeddingV2Params params;
+        RotaryEmbeddingV2Param params;
         params.rope_theta = inputs->getPtr<float>("rope_theta");
         params.q_len      = cu_q_len_;
        params.k_ken      = cu_k_len_;
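
Note on the precomputed table (illustrative only, not part of the patches): the computeCosSin* kernels above store, for token t and rotary pair d, cos_sin[t * dim + 2 * d] = cos(t * freq_d) and cos_sin[t * dim + 2 * d + 1] = sin(t * freq_d). A minimal host-side sketch of the pairwise rotation that the attention path then applies from this table; the helper name apply_rope_pairs is hypothetical and stands in for the device-side PrecomputeFastRoPE::apply.

    // Sketch only: assumes the interleaved (cos, sin) per-token layout written by the kernels above.
    void apply_rope_pairs(float* x, const float* cos_sin, int token, int dim)
    {
        const float* cs = cos_sin + token * dim;  // dim floats per token: cos/sin pairs
        for (int d = 0; d < dim; d += 2) {
            const float c = cs[d], s = cs[d + 1];
            const float x0 = x[d], x1 = x[d + 1];
            x[d]     = c * x0 - s * x1;  // rotate each adjacent pair by the precomputed angle
            x[d + 1] = c * x1 + s * x0;
        }
    }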