Refactor turbomind attention by precomputing cos/sin #2801

Open · wants to merge 7 commits into base: main
Changes from 1 commit
10 changes: 5 additions & 5 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -368,15 +368,15 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)

// compute rope scaling factor
if (r->start_flag) {
seq.rope_theta = model_->attn_param_.rotary_embedding_base;
if (model_->attn_param_.use_dynamic_ntk) {
auto scaling_factor = model_->attn_param_.rope_scaling_factor;
seq.rope_theta = model_->attn_param_.rope.base;
if (model_->attn_param_.rope.type == RotaryScalingType::kDynamic) {
auto scaling_factor = model_->attn_param_.rope.factor;
if (scaling_factor >= 1.f) { // infer by current context length
auto max_seq_len = state.h_context_length[idx];
auto max_pos_emb = model_->attn_param_.max_position_embeddings;
auto max_pos_emb = model_->attn_param_.rope.max_position_embeddings;
if (max_seq_len > max_pos_emb) {
scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
float rope_dim = model_->attn_param_.rotary_embedding_dim;
float rope_dim = model_->attn_param_.rope.dim;
seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
(long)seq.id,
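For reference, the dynamic NTK branch above rescales the per-sequence rope_theta once the current context grows past the trained window. A minimal standalone sketch of that adjustment (hypothetical helper, not part of this PR):

#include <cmath>

// Dynamic NTK base adjustment, mirroring the logic in ProcessInferRequests:
//   factor' = factor * seq_len / max_pos - (factor - 1)
//   theta'  = theta * factor' ^ (dim / (dim - 2))
// Only applied when scaling is enabled (factor >= 1) and the context exceeds
// the trained max_position_embeddings.
inline float DynamicNtkTheta(float base, float factor, int seq_len, int max_pos, int dim)
{
    if (factor >= 1.f && seq_len > max_pos) {
        const float adjusted = factor * seq_len / max_pos - (factor - 1.f);
        const float rope_dim = static_cast<float>(dim);
        return base * std::pow(adjusted, rope_dim / (rope_dim - 2.f));
    }
    return base;  // within the trained window: keep the configured base
}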
53 changes: 38 additions & 15 deletions src/turbomind/models/llama/llama_params.h
@@ -59,22 +59,45 @@ struct MoeParam {
std::vector<int> expert_num;
};

enum class RotaryScalingType
Collaborator (review comment on the line above): RotaryScalingType -> RopeType

{
kDefault,
kLinear,
kDynamic,
kYarn,
kLlama3,
};

struct YarnRopeParam {
float attention_factor;
float beta_fast;
float beta_slow;
};

struct Llama3RopeParam {
float low_freq_factor;
float high_freq_factor;
int original_max_position_embeddings;
};

struct AttentionParam {
int rotary_embedding_dim;
float rotary_embedding_base;
int max_position_embeddings;
float softmax_scale;
std::string rope_scaling_type;
int original_max_position_embeddings;
float rope_scaling_factor;
float low_freq_factor;
float high_freq_factor;
float attention_factor;
float beta_fast;
float beta_slow;
bool use_dynamic_ntk;
bool use_logn_attn;
int cache_block_seq_len;
float softmax_scale;
int cache_block_seq_len;
bool use_logn_attn;
// rope
struct {
// common
RotaryScalingType type;
int dim;
float base;
float factor;
int max_position_embeddings;
// special
union {
YarnRopeParam yarn;
Llama3RopeParam llama3;
};
} rope;
};

struct EngineParam {
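To show how the refactored rope block is meant to be filled, here is a hedged sketch for a YaRN-scaled model; the numeric values are placeholders, and only the union member matching `type` is meaningful:

#include "src/turbomind/models/llama/llama_params.h"  // include path assumed; adjust to the build setup

using namespace turbomind;  // AttentionParam / RotaryScalingType assumed to live in this namespace

// Illustrative only -- the values below are made up, not taken from any real config.
inline AttentionParam MakeYarnRopeExample()
{
    AttentionParam attn{};
    attn.rope.type                    = RotaryScalingType::kYarn;
    attn.rope.dim                     = 128;
    attn.rope.base                    = 10000.f;
    attn.rope.factor                  = 4.f;
    attn.rope.max_position_embeddings = 32768;
    attn.rope.yarn.attention_factor   = 1.f;
    attn.rope.yarn.beta_fast          = 32.f;
    attn.rope.yarn.beta_slow          = 1.f;
    // For kLlama3, fill attn.rope.llama3 instead; the two members share storage,
    // so reading the one that does not match `type` is not meaningful.
    return attn;
}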
99 changes: 53 additions & 46 deletions src/turbomind/models/llama/rotary_emb.cu
@@ -114,8 +114,7 @@ RotaryScalingType GetRoPEType(const std::string& type)
{"linear", RotaryScalingType::kLinear},
{"dynamic", RotaryScalingType::kDynamic},
{"yarn", RotaryScalingType::kYarn},
{"llama3", RotaryScalingType::kLlama3},
{"mrope", RotaryScalingType::kMrope}};
{"llama3", RotaryScalingType::kLlama3}};
return lookup.at(type);
}

@@ -132,42 +131,52 @@ void RotaryEmbeddingV2::allocateBuffer(size_t token_num)
RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator):
stream_(stream), allocator_(allocator)
{
type_ = GetRoPEType(param.rope_scaling_type);
dim_ = param.rotary_embedding_dim;
rope_scaling_factor_ = 1.0f;
attention_factor_ = 1.0f;
type_ = param.rope.type;
dim_ = param.rope.dim;

if (type_ == RotaryScalingType::kLinear) {
rope_scaling_factor_ /= param.rope_scaling_factor;
}
else if (type_ == RotaryScalingType::kLlama3) {
const double PI = 3.14159265358979323846;
float inv_diff_freq_factor = 1.0 / (param.high_freq_factor - param.low_freq_factor);
llama3_inv_scaling_factor_ = 1.0 / param.rope_scaling_factor;
llama3_alpha_ = param.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
llama3_beta_ = param.low_freq_factor * inv_diff_freq_factor;
}
else if (type_ == RotaryScalingType::kYarn) {
const double PI = 3.14159265358979323846;
auto find_correction_dim = [&](float num_rotations) {
return (param.rotary_embedding_dim * std::log(param.max_position_embeddings / (num_rotations * 2 * PI)))
/ (2 * std::log(param.rotary_embedding_base));
};
auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
low = std::floor(find_correction_dim(low_rot));
high = std::ceil(find_correction_dim(high_rot));
low = std::max(low, 0.f);
high = std::min(high, param.rotary_embedding_dim - 1.f);
};
float low, high;
find_correction_range(param.beta_fast, param.beta_slow, low, high);
if (low == high) {
high += 0.01f;
switch (type_) {
case RotaryScalingType::kDefault:
break;
case RotaryScalingType::kLinear:
inv_factor_ = 1.0f / param.rope.factor;
break;
case RotaryScalingType::kDynamic:
inv_factor_ = param.rope.factor;
break;
case RotaryScalingType::kYarn: {
const double PI = 3.14159265358979323846;
auto find_correction_dim = [&](float num_rotations) {
return (param.rope.dim * std::log(param.rope.max_position_embeddings / (num_rotations * 2 * PI)))
/ (2 * std::log(param.rope.base));
};
auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
low = std::floor(find_correction_dim(low_rot));
high = std::ceil(find_correction_dim(high_rot));
low = std::max(low, 0.f);
high = std::min(high, param.rope.dim - 1.f);
};
float low, high;
find_correction_range(param.rope.yarn.beta_fast, param.rope.yarn.beta_slow, low, high);
if (low == high) {
high += 0.01f;
}
yarn_.yarn_ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0;
yarn_.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low;
yarn_.yarn_inv_scaling_factor = (1 - 1.0 / param.rope.factor);
yarn_.attention_factor = param.rope.yarn.attention_factor;
break;
}
case RotaryScalingType::kLlama3: {
const double PI = 3.14159265358979323846;
float inv_diff_freq_factor = 1.0 / (param.rope.llama3.high_freq_factor - param.rope.llama3.low_freq_factor);
llama3_.llama3_inv_scaling_factor = 1.0 / param.rope.factor;
llama3_.llama3_alpha = param.rope.llama3.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
llama3_.llama3_beta = param.rope.llama3.low_freq_factor * inv_diff_freq_factor;
break;
}
yarn_ramp_inv_factor_div_2_ = 1.0 / (high - low) / 2.0;
yarn_ramp_inv_factor_mul_min_ = 1.0 / (high - low) * low;
yarn_inv_scaling_factor_ = (1 - 1.0 / param.rope_scaling_factor);
attention_factor_ = param.attention_factor;
default:
FT_CHECK(0);
break;
}
}

@@ -188,7 +197,7 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
params.token_num,
params.batch_size,
dim_,
rope_scaling_factor_,
inv_factor_,
cos_sin_);
break;
case RotaryScalingType::kLlama3:
@@ -198,9 +207,9 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
params.token_num,
params.batch_size,
dim_,
llama3_inv_scaling_factor_,
llama3_alpha_,
llama3_beta_,
llama3_.llama3_inv_scaling_factor,
llama3_.llama3_alpha,
llama3_.llama3_beta,
cos_sin_);
break;
case RotaryScalingType::kYarn:
@@ -210,14 +219,12 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
params.token_num,
params.batch_size,
dim_,
yarn_ramp_inv_factor_div_2_,
yarn_ramp_inv_factor_mul_min_,
yarn_inv_scaling_factor_,
attention_factor_,
yarn_.yarn_ramp_inv_factor_div_2,
yarn_.yarn_ramp_inv_factor_mul_min,
yarn_.yarn_inv_scaling_factor,
yarn_.attention_factor,
cos_sin_);
break;
case RotaryScalingType::kMrope:
FT_CHECK(0);
default:
FT_CHECK(0);
}
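The kYarn branch of the constructor derives its ramp constants from a "correction dimension": the rotary dimension whose frequency completes a given number of rotations over the trained context. A standalone sketch of that math (same formulas as the lambdas above; the wrapper names are illustrative):

#include <algorithm>
#include <cmath>

// d(r) = dim * ln(max_pos / (2*pi*r)) / (2 * ln(base)),
// i.e. the (continuous) rotary dimension whose frequency completes r full
// rotations over max_pos positions.
inline float FindCorrectionDim(float num_rotations, int dim, int max_pos, float base)
{
    const double PI = 3.14159265358979323846;
    return (dim * std::log(max_pos / (num_rotations * 2 * PI))) / (2 * std::log(base));
}

// Map the beta_fast / beta_slow rotation counts to a dimension range clamped to
// [0, dim - 1]; the ramp between `low` and `high` blends interpolated and
// extrapolated frequencies, exactly as in the kYarn case above.
inline void FindCorrectionRange(
    float beta_fast, float beta_slow, int dim, int max_pos, float base, float& low, float& high)
{
    low  = std::max(std::floor(FindCorrectionDim(beta_fast, dim, max_pos, base)), 0.f);
    high = std::min(std::ceil(FindCorrectionDim(beta_slow, dim, max_pos, base)), dim - 1.f);
    if (low == high) {
        high += 0.01f;  // avoid a zero-width ramp
    }
    // From these, the constructor precomputes:
    //   ramp_inv_factor_div_2   = 1 / (high - low) / 2
    //   ramp_inv_factor_mul_min = low / (high - low)
    //   inv_scaling_factor      = 1 - 1 / factor
}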
49 changes: 23 additions & 26 deletions src/turbomind/models/llama/rotary_emb.h
@@ -5,15 +5,7 @@

namespace turbomind {

enum class RotaryScalingType
{
kDefault,
kLinear,
kDynamic,
kYarn,
kLlama3,
kMrope
};
RotaryScalingType GetRoPEType(const std::string& type);

struct RotaryEmbeddingV2Params {
float* rope_theta;
@@ -23,6 +15,19 @@ struct RotaryEmbeddingV2Params {
int token_num;
};

struct InnerYarnRopeParam {
float attention_factor;
float yarn_ramp_inv_factor_div_2;
Collaborator (review comment on the line above): I think we can remove the prefix "yarn_"
float yarn_ramp_inv_factor_mul_min;
float yarn_inv_scaling_factor;
};

struct InnerLlama3RopeParam {
float llama3_inv_scaling_factor;
Collaborator (review comment on the line above): the prefix "llama3_" can be removed
float llama3_alpha;
float llama3_beta;
};

struct RotaryEmbeddingV2 {

RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator);
@@ -38,28 +43,20 @@ struct RotaryEmbeddingV2 {

void forward(const RotaryEmbeddingV2Params& params);

RotaryScalingType type_;
cudaStream_t const stream_;
IAllocator* const allocator_;

int dim_;
RotaryScalingType type_;
float inv_factor_{1.0};

union {
InnerYarnRopeParam yarn_;
InnerLlama3RopeParam llama3_;
};

// output
float* cos_sin_; // num_token x dim, (cos, sin, ...)

int dim_;
// default, linear, dynamic
float attention_factor_;
float rope_scaling_factor_;
float inv_scale_factor_;
// llama3
float llama3_inv_scaling_factor_;
float llama3_alpha_;
float llama3_beta_;
// yarn
float yarn_ramp_inv_factor_div_2_;
float yarn_ramp_inv_factor_mul_min_;
float yarn_inv_scaling_factor_;
// mrope
int3 mrope_section_;
};

}; // namespace turbomind
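Because `yarn_` and `llama3_` share storage in the union, `type_` acts as the tag that says which member is valid, and consumers should branch on it before reading either one. A hedged usage sketch (hypothetical free function, not part of this PR):

// Illustrative only: read the YaRN attention factor when it applies.
inline float GetAttentionFactor(const turbomind::RotaryEmbeddingV2& rope)
{
    switch (rope.type_) {
        case turbomind::RotaryScalingType::kYarn:
            return rope.yarn_.attention_factor;  // valid only while type_ == kYarn
        default:
            return 1.f;  // other rope types carry no extra attention scaling here
    }
}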
4 changes: 2 additions & 2 deletions src/turbomind/models/llama/unified_attention_layer.cc
@@ -313,8 +313,8 @@ inline void UnifiedAttentionLayer<T>::forward(TensorMap* outputs, const TensorMa
}

// rope
params.rotary_embedding_dim = param_.rotary_embedding_dim;
params.max_position_embeddings = param_.max_position_embeddings;
params.rotary_embedding_dim = param_.rope.dim;
params.max_position_embeddings = param_.rope.max_position_embeddings;
params.cos_sin = cos_sin;
params.use_logn_attn = param_.use_logn_attn;

6 changes: 1 addition & 5 deletions src/turbomind/models/llama/unified_decoder.cc
@@ -88,11 +88,7 @@ void UnifiedDecoder<T>::forwardSelfAttn(T* attn_io,
inputs.insert("cu_k_len", {MEMORY_GPU, TYPE_INT32, {batch_size + 1}, cu_k_len_});
inputs.insert("h_cu_q_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_q_len_});
inputs.insert("h_cu_k_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_k_len_});

if (rotary_emb_) {
inputs.insert("cos_sin",
{MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_});
}
inputs.insert("cos_sin", {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_});

TensorMap outputs(*_outputs);
outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io});
41 changes: 22 additions & 19 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -140,10 +140,10 @@ void LlamaTritonModel<T>::handleMissingParams()
(int)model_param_.vocab_size);
}

if (!attn_param_.max_position_embeddings) {
attn_param_.max_position_embeddings = 2048;
if (!attn_param_.rope.max_position_embeddings) {
attn_param_.rope.max_position_embeddings = 2048;
TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to %d.",
(int)attn_param_.max_position_embeddings);
(int)attn_param_.rope.max_position_embeddings);
}

if (!engine_param_.max_batch_size) {
@@ -153,7 +153,7 @@
}

if (!engine_param_.session_len) {
engine_param_.session_len = attn_param_.max_position_embeddings;
engine_param_.session_len = attn_param_.rope.max_position_embeddings;
TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)engine_param_.session_len);
}

@@ -277,22 +277,25 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
model_param_.attn_bias = model_reader["attn_bias"].as<int>(0);
model_param_.group_size = model_reader["group_size"].as<int>(0);

attn_param_.softmax_scale = attention_reader["softmax_scale"].as<float>(0);
attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as<int>(0);
// rotary embedding parameters
attn_param_.rotary_embedding_dim = attention_reader["rotary_embedding"].as<int>();
attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as<float>(10000.0f);
attn_param_.softmax_scale = attention_reader["softmax_scale"].as<float>(0);
attn_param_.attention_factor = attention_reader["attention_factor"].as<float>(-1.f);
attn_param_.beta_fast = attention_reader["beta_fast"].as<float>(32.f);
attn_param_.beta_slow = attention_reader["beta_slow"].as<float>(1.f);
attn_param_.rope_scaling_type = attention_reader["rope_scaling_type"].as<std::string>("");
attn_param_.rope_scaling_factor = attention_reader["rope_scaling_factor"].as<float>(0.f);
attn_param_.low_freq_factor = attention_reader["low_freq_factor"].as<float>(1.0);
attn_param_.high_freq_factor = attention_reader["high_freq_factor"].as<float>(1.0);
attn_param_.max_position_embeddings = attention_reader["max_position_embeddings"].as<int>(0);
attn_param_.use_dynamic_ntk = attention_reader["use_dynamic_ntk"].as<int>(0);
attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as<int>(0);

attn_param_.original_max_position_embeddings = attention_reader["original_max_position_embeddings"].as<int>(0);
attn_param_.rope.type = GetRoPEType(attention_reader["rope_scaling_type"].as<std::string>(""));
attn_param_.rope.dim = attention_reader["rotary_embedding"].as<int>();
attn_param_.rope.base = attention_reader["rope_theta"].as<float>(10000.0f);
attn_param_.rope.max_position_embeddings = attention_reader["max_position_embeddings"].as<int>(0);
attn_param_.rope.factor = attention_reader["rope_scaling_factor"].as<float>(0.f);
if (attn_param_.rope.type == RotaryScalingType::kYarn) {
attn_param_.rope.yarn.attention_factor = attention_reader["attention_factor"].as<float>(-1.f);
attn_param_.rope.yarn.beta_fast = attention_reader["beta_fast"].as<float>(32.f);
attn_param_.rope.yarn.beta_slow = attention_reader["beta_slow"].as<float>(1.f);
}
else if (attn_param_.rope.type == RotaryScalingType::kLlama3) {
attn_param_.rope.llama3.low_freq_factor = attention_reader["low_freq_factor"].as<float>(1.0);
attn_param_.rope.llama3.high_freq_factor = attention_reader["high_freq_factor"].as<float>(1.0);
attn_param_.rope.llama3.original_max_position_embeddings =
attention_reader["original_max_position_embeddings"].as<int>(0);
}

engine_param_.max_batch_size = engine_reader["max_batch_size"].as<int>(0);
engine_param_.max_prefill_token_num = engine_reader["max_prefill_token_num"].as<int>(0);
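As a concrete illustration of what this parsing yields: for a Llama-3.1-style checkpoint whose config sets rope_scaling_type to "llama3", the rope block would come out roughly as below. The numbers are illustrative placeholders, not read from this PR or from any particular config.

// Illustrative parse result only; real values come from the model's config file.
attn_param_.rope.type                    = RotaryScalingType::kLlama3;
attn_param_.rope.dim                     = 128;       // "rotary_embedding"
attn_param_.rope.base                    = 500000.f;  // "rope_theta"
attn_param_.rope.max_position_embeddings = 131072;    // "max_position_embeddings"
attn_param_.rope.factor                  = 8.f;       // "rope_scaling_factor"
attn_param_.rope.llama3.low_freq_factor  = 1.f;       // "low_freq_factor"
attn_param_.rope.llama3.high_freq_factor = 4.f;       // "high_freq_factor"
attn_param_.rope.llama3.original_max_position_embeddings = 8192;  // "original_max_position_embeddings"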