diff --git a/docs/en/multi_modal/qwen2_vl.md b/docs/en/multi_modal/qwen2_vl.md
index 8b59f8454..d26c89fd9 100644
--- a/docs/en/multi_modal/qwen2_vl.md
+++ b/docs/en/multi_modal/qwen2_vl.md
@@ -5,7 +5,7 @@ LMDeploy supports the following Qwen-VL series of models, which are detailed in
 | Model | Size | Supported Inference Engine |
 | :----------: | :----: | :------------------------: |
 | Qwen-VL-Chat | - | TurboMind, Pytorch |
-| Qwen2-VL | 2B, 7B | PyTorch |
+| Qwen2-VL | 2B - 72B | TurboMind, PyTorch |

 The next chapter demonstrates how to deploy an Qwen-VL model using LMDeploy, with [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) as an example.
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index 1f344e78b..22e887ce5 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -19,6 +19,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
 | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
 | Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
 | Qwen2 | 1.5B - 72B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2-VL | 2B - 72B | MLLM | Yes | Yes | Yes | Yes |
 | Mistral | 7B | LLM | Yes | Yes | Yes | Yes |
 | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
 | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
diff --git a/docs/zh_cn/multi_modal/qwen2_vl.md b/docs/zh_cn/multi_modal/qwen2_vl.md
index f62d2de74..6ec2164d5 100644
--- a/docs/zh_cn/multi_modal/qwen2_vl.md
+++ b/docs/zh_cn/multi_modal/qwen2_vl.md
@@ -5,7 +5,7 @@ LMDeploy 支持 Qwen-VL 系列模型，具体如下:
 | Model | Size | Supported Inference Engine |
 | :----------: | :----: | :------------------------: |
 | Qwen-VL-Chat | - | TurboMind, Pytorch |
-| Qwen2-VL | 2B, 7B | PyTorch |
+| Qwen2-VL | 2B - 72B | TurboMind, PyTorch |

 本文将以[Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)为例，演示使用 LMDeploy 部署 Qwen2-VL 系列模型的方法
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index ac061cf1a..26ef920c9 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -19,6 +19,7 @@
 | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
 | Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
 | Qwen2 | 1.5B - 72B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2-VL | 2B - 72B | MLLM | Yes | Yes | Yes | Yes |
 | Mistral | 7B | LLM | Yes | Yes | Yes | Yes |
 | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
 | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py
index 7e8ebf7b4..2839c5359 100644
--- a/lmdeploy/turbomind/deploy/config.py
+++ b/lmdeploy/turbomind/deploy/config.py
@@ -2,6 +2,7 @@
 import inspect
 import json
 from dataclasses import asdict, fields
+from typing import List

 # use pydantic.dataclasses.dataclass to check data type
 from pydantic.dataclasses import dataclass
@@ -73,6 +74,7 @@ class AttentionConfig:
     high_freq_factor: float = 1.0
     beta_fast: float = 32.0
     beta_slow: float = 1.0
+    mrope_section: List[int] = None
     use_logn_attn: int = 0
     cache_block_seq_len: int = 64
diff --git a/lmdeploy/turbomind/deploy/source_model/qwen.py b/lmdeploy/turbomind/deploy/source_model/qwen.py
index 0ec0586a3..d330fe800 100644
--- a/lmdeploy/turbomind/deploy/source_model/qwen.py
+++ b/lmdeploy/turbomind/deploy/source_model/qwen.py
@@ -119,4 +119,13 @@ def tokenizer_info(self):
     def model_info(self):
         cfg = super().model_info()
         cfg['attn_bias'] = 1
+        params_path = osp.join(self.model_path, 'config.json')
+        with open(params_path) as f:
+            config = json.load(f)
+        rope_scaling = config.get('rope_scaling')
+        if rope_scaling is not None:
+            if rope_scaling.get('type', '') == 'mrope':
+                selection = rope_scaling['mrope_section']
+                cfg['rope_scaling_type'] = 'mrope'
+                cfg['mrope_section'] = selection
         return cfg
diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py
index 8a1f5e731..db99c44ce 100644
--- a/lmdeploy/turbomind/supported_models.py
+++ b/lmdeploy/turbomind/supported_models.py
@@ -20,6 +20,8 @@
     QWenLMHeadModel='qwen',
     # Qwen2
     Qwen2ForCausalLM='qwen2',
+    # Qwen2-VL
+    Qwen2VLForConditionalGeneration='qwen2',
     # mistral
     MistralForCausalLM='llama',
     # llava
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 05bc3e400..f74707f5d 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -515,6 +515,8 @@ def prepare_inputs(self,
                        gen_config: GenerationConfig,
                        input_embeddings=None,
                        input_embedding_ranges=None,
+                       mrope_position_ids=None,
+                       mrope_position_delta=None,
                        sequence_start: bool = True,
                        sequence_end: bool = False,
                        step=0,
@@ -572,6 +574,18 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
             inputs['input_embeddings'] = input_embeddings
             inputs['input_embedding_ranges'] = input_embedding_ranges

+        if mrope_position_ids is not None:
+            assert isinstance(mrope_position_ids, torch.Tensor)
+            assert isinstance(mrope_position_delta, torch.Tensor)
+            assert input_lengths.size(0) == 1
+            assert mrope_position_ids.size(-1) == input_ids.size(-1)
+            mrope_position_ids = pad_sequence([mrope_position_ids],
+                                              batch_first=True,
+                                              padding_value=-1).transpose(
+                                                  1, 2).int().reshape(1, -1)
+            inputs['mrope_position_ids'] = mrope_position_ids
+            inputs['mrope_position_delta'] = mrope_position_delta
+
         if gen_config.min_new_tokens is not None:
             inputs['min_length'] = _broadcast_np(gen_config.min_new_tokens,
                                                  np.int32)
@@ -611,6 +625,8 @@ async def async_stream_infer(self,
                                  input_ids,
                                  input_embeddings=None,
                                  input_embedding_ranges=None,
+                                 mrope_position_ids=None,
+                                 mrope_position_delta=None,
                                  sequence_start: bool = True,
                                  sequence_end: bool = False,
                                  step=0,
@@ -648,6 +664,8 @@ async def async_stream_infer(self,
             input_ids=input_ids,
             input_embeddings=input_embeddings,
             input_embedding_ranges=input_embedding_ranges,
+            mrope_position_ids=mrope_position_ids,
+            mrope_position_delta=mrope_position_delta,
             sequence_start=sequence_start,
             sequence_end=sequence_end,
             step=step,
@@ -734,6 +752,8 @@ def stream_infer(self,
                     input_ids,
                     input_embeddings=None,
                     input_embedding_ranges=None,
+                    mrope_position_ids=None,
+                    mrope_position_delta=None,
                     sequence_start: bool = True,
                     sequence_end: bool = False,
                     step=0,
@@ -766,6 +786,8 @@ def stream_infer(self,
             input_ids=input_ids,
             input_embeddings=input_embeddings,
             input_embedding_ranges=input_embedding_ranges,
+            mrope_position_ids=mrope_position_ids,
+            mrope_position_delta=mrope_position_delta,
             sequence_start=sequence_start,
             sequence_end=sequence_end,
             step=step,
diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h
index b6dfaa596..90997f25f 100644
--- a/src/turbomind/kernels/attention/attention_params.h
+++ b/src/turbomind/kernels/attention/attention_params.h
@@ -67,10 +67,16 @@ struct AttentionParams {
     float llama3_inv_scaling_factor;
     float llama3_alpha;
     float llama3_beta;
-    // the following are use by yarn
+    // the following are used by yarn
float yarn_ramp_inv_factor_div_2; float yarn_ramp_inv_factor_mul_min; float yarn_inv_scaling_factor; + // the following are used by qwen2-vl + int3 mrope_section; + int* mrope_position_ids; // 3 x session_len_ + int mrope_offset; // session_len_ + int* mrope_position_delta; + int* mrope_position_length; // log(n) attention bool use_logn_attn; diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 5fb583bd1..eb67b3390 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -223,6 +223,14 @@ struct AttentionUniversal { ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset); + int* mrope_ids = nullptr; + int mrope_length = 0; + int mrope_delta = 0; + if (params.mrope_position_ids != nullptr) { + mrope_ids = params.mrope_position_ids + batch_idx * 3 * params.mrope_offset; + mrope_length = params.mrope_position_length[batch_idx]; + mrope_delta = params.mrope_position_delta[batch_idx]; + } const float rope_base = params.rope_theta ? params.rope_theta[batch_idx] : params.rotary_embedding_base; PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { @@ -239,6 +247,10 @@ struct AttentionUniversal { params.yarn_ramp_inv_factor_mul_min, params.yarn_inv_scaling_factor, params.attention_scaling, + params.mrope_section, + mrope_ids, + mrope_length, + mrope_delta, std::integral_constant{}); PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index 9f28a17b8..1b585cd74 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -31,6 +31,11 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, float yarn_ramp_inv_factor_mul_min, float yarn_inv_scaling_factor, float attention_scaling, + int3 mrope_section, + int* mrope_position_ids, + int mrope_offset, + int* mrope_position_delta, + int* mrope_position_length, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -125,6 +130,14 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, } if (rope_base) { + int* mrope_ids = nullptr; + int mrope_length = 0; + int mrope_delta = 0; + if (mrope_position_ids != nullptr) { + mrope_ids = mrope_position_ids + batch_idx * 3 * mrope_offset; + mrope_length = mrope_position_length[batch_idx]; + mrope_delta = mrope_position_delta[batch_idx]; + } float base = rope_base[batch_idx]; PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { @@ -141,6 +154,10 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, yarn_ramp_inv_factor_mul_min, yarn_inv_scaling_factor, attention_scaling, + mrope_section, + mrope_ids, + mrope_length, + mrope_delta, std::integral_constant{}); PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { @@ -222,6 +239,11 @@ void invokeProcessKV_v2(char** blocks, float yarn_ramp_inv_factor_mul_min, float yarn_inv_scaling_factor, float attention_scaling, + int3 mrope_section, + int* mrope_position_ids, + int mrope_offset, + int* mrope_position_delta, + int* mrope_position_length, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -268,6 +290,11 @@ void invokeProcessKV_v2(char** blocks, yarn_ramp_inv_factor_mul_min, yarn_inv_scaling_factor, attention_scaling, + mrope_section, + mrope_position_ids, + mrope_offset, + mrope_position_delta, + mrope_position_length, stride_b, stride_c, stride_h, @@ -307,6 +334,11 @@ void 
invokeProcessKV_v2(char** blocks, float yarn_ramp_inv_factor_mul_min, \ float yarn_inv_scaling_factor, \ float attention_scaling, \ + int3 mrope_section, \ + int* mrope_position_ids, \ + int mrope_offset, \ + int* mrope_position_delta, \ + int* mrope_position_length, \ int64_t stride_b, \ int64_t stride_c, \ int64_t stride_h, \ @@ -342,6 +374,11 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, float yarn_ramp_inv_factor_mul_min, float yarn_inv_scaling_factor, float attention_scaling, + int3 mrope_section, + int* mrope_position_ids, + int mrope_offset, + int* mrope_position_delta, + int* mrope_position_length, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -419,6 +456,14 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, } if (rope_base) { + int* mrope_ids = nullptr; + int mrope_length = 0; + int mrope_delta = 0; + if (mrope_position_ids != nullptr) { + mrope_ids = mrope_position_ids + batch_idx * 3 * mrope_offset; + mrope_length = mrope_position_length[batch_idx]; + mrope_delta = mrope_position_delta[batch_idx]; + } float base = rope_base[batch_idx]; PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { @@ -435,6 +480,10 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, yarn_ramp_inv_factor_mul_min, yarn_inv_scaling_factor, attention_scaling, + mrope_section, + mrope_ids, + mrope_length, + mrope_delta, std::integral_constant{}); PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { @@ -477,6 +526,11 @@ void invokeFlattenKV_v2(T* k, float yarn_ramp_inv_factor_mul_min, float yarn_inv_scaling_factor, float attention_scaling, + int3 mrope_section, + int* mrope_position_ids, + int mrope_offset, + int* mrope_position_delta, + int* mrope_position_length, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -520,6 +574,11 @@ void invokeFlattenKV_v2(T* k, yarn_ramp_inv_factor_mul_min, yarn_inv_scaling_factor, attention_scaling, + mrope_section, + mrope_position_ids, + mrope_offset, + mrope_position_delta, + mrope_position_length, stride_b, stride_c, stride_h, @@ -556,6 +615,11 @@ void invokeFlattenKV_v2(T* k, float yarn_ramp_inv_factor_mul_min, \ float yarn_inv_scaling_factor, \ float attention_scaling, \ + int3 mrope_section, \ + int* mrope_position_ids, \ + int mrope_offset, \ + int* mrope_position_delta, \ + int* mrope_position_length, \ int64_t stride_b, \ int64_t stride_c, \ int64_t stride_h, \ diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h index fe45ad7be..63beb36cf 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h @@ -27,6 +27,11 @@ void invokeProcessKV_v2(char** blocks, float yarn_ramp_inv_factor_mul_min, float yarn_inv_scaling_factor, float attention_scaling, + int3 mrope_section, + int* mrope_position_ids, + int mrope_offset, + int* mrope_position_delta, + int* mrope_position_length, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -62,6 +67,11 @@ void invokeProcessKV_v2_(const AttentionParams& params) params.yarn_ramp_inv_factor_mul_min, params.yarn_inv_scaling_factor, params.attention_scaling, + params.mrope_section, + params.mrope_position_ids, + params.mrope_offset, + params.mrope_position_delta, + params.mrope_position_length, 0, // stride b params.stride / params.size_per_head, // stride c 1, // stride h @@ -93,6 +103,11 @@ void invokeFlattenKV_v2(T* k, float yarn_ramp_inv_factor_mul_min, float yarn_inv_scaling_factor, float attention_scaling, + int3 mrope_section, + int* mrope_position_ids, + 
int mrope_offset, + int* mrope_position_delta, + int* mrope_position_length, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -127,6 +142,11 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len) params.yarn_ramp_inv_factor_mul_min, params.yarn_inv_scaling_factor, params.attention_scaling, + params.mrope_section, + params.mrope_position_ids, + params.mrope_offset, + params.mrope_position_delta, + params.mrope_position_length, 0, 1, 2 * sum_k_len, diff --git a/src/turbomind/kernels/attention/rotary_embedding.h b/src/turbomind/kernels/attention/rotary_embedding.h index 8e09da22c..51e74a111 100644 --- a/src/turbomind/kernels/attention/rotary_embedding.h +++ b/src/turbomind/kernels/attention/rotary_embedding.h @@ -76,19 +76,32 @@ struct FastRoPE { bool is_valid_; float attention_scaling_; - __device__ FastRoPE(int idx, - D dims, - float base, - float ti_scale, - float factor, - float llama3_inv_scaling_factor, - float llama3_alpha, - float llama3_beta, - float yarn_ramp_inv_factor_div_2, - float yarn_ramp_inv_factor_mul_min, - float yarn_inv_scaling_factor, - float attention_scaling, - std::integral_constant) + int3 mrope_selection_; + const int* mrope_position_ids_{}; + const int mrope_position_length_; + const int mrope_position_delta_; + + __device__ FastRoPE(int idx, + D dims, + float base, + float ti_scale, + float factor, + float llama3_inv_scaling_factor, + float llama3_alpha, + float llama3_beta, + float yarn_ramp_inv_factor_div_2, + float yarn_ramp_inv_factor_mul_min, + float yarn_inv_scaling_factor, + float attention_scaling, + int3 mrope_section, + const int* mrope_ids, + int mrope_length, + int mrope_delta, + std::integral_constant): + mrope_selection_(mrope_section), + mrope_position_ids_(mrope_ids), + mrope_position_length_(mrope_length), + mrope_position_delta_(mrope_delta) { is_valid_ = idx < dims; attention_scaling_ = attention_scaling; @@ -126,13 +139,60 @@ struct FastRoPE { inv_freq_[i / 2] = freq - freq * alpha * yarn_inv_scaling_factor; } } + if (mrope_position_ids_ != nullptr) { + mrope_selection_.x = mrope_selection_.x * 2; + mrope_selection_.y = mrope_selection_.y * 2 + mrope_selection_.x; + mrope_selection_.z = mrope_selection_.z * 2 + mrope_selection_.y; + } } template __device__ void apply(Array& x, float timestep) { + if (mrope_position_ids_) { + return apply_mrope(x, timestep); + } + + PRAGMA_UNROLL + for (int i = 0; i < N; i += 2) { + float c, s; + sincosf(timestep * inv_freq_[i / 2], &s, &c); + s *= attention_scaling_; + c *= attention_scaling_; + float tmp0 = c * (float)x[i] - s * (float)x[i + 1]; + float tmp1 = c * (float)x[i + 1] + s * (float)x[i]; + if (is_valid_) { + x[i] = (T)tmp0; + x[i + 1] = (T)tmp1; + } + } + } + + template + __device__ void apply_mrope(Array& x, float timestep) + { + int p1, p2, p3; + if (timestep < mrope_position_length_) { + const int* p = mrope_position_ids_ + 3 * (int)timestep; + p1 = *p; + p2 = *(p + 1); + p3 = *(p + 2); + } + else { + p1 = p2 = p3 = (int)timestep - mrope_position_delta_; + } + PRAGMA_UNROLL for (int i = 0; i < N; i += 2) { + if (i < mrope_selection_.x) { + timestep = (float)p1; + } + else if (i < mrope_selection_.y) { + timestep = (float)p2; + } + else { + timestep = (float)p3; + } float c, s; sincosf(timestep * inv_freq_[i / 2], &s, &c); s *= attention_scaling_; diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index c6d7b4063..7a6e3bbe1 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ 
b/src/turbomind/kernels/attention/test_attention.cu @@ -153,10 +153,15 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, 0., 1.0, 1.0, - 0.0, - 0.0, - 0.0, - 1.0, + 0.0, // yarn_ramp_inv_factor_div_2 + 0.0, // yarn_ramp_inv_factor_mul_min + 0.0, // yarn_inv_scaling_factor + 1.0, // attention_scaling + {}, // mrope_section + nullptr, // mrope_position_ids + 0, // mrope_offset + nullptr, // mrope_position_delta + nullptr, // mrope_position_length 2 * head_num * seq_len, 0, seq_len, @@ -191,6 +196,11 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, 0.0, 0.0, 1.0, + {}, // mrope_section + nullptr, // mrope_position_ids + 0, // mrope_offset + nullptr, // mrope_position_delta + nullptr, // mrope_position_length 2 * head_num * seq_len, 0, seq_len, @@ -555,6 +565,11 @@ int test_attention() 0.0, 0.0, 1.0, + {}, // mrope_section + nullptr, // mrope_position_ids + 0, // mrope_offset + nullptr, // mrope_position_delta + nullptr, // mrope_position_length KvHeadNum * kContextLen, 0, kContextLen, diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 4138174e5..2de5836cd 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -347,6 +347,18 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests) } } + if (attn_param_.rope_scaling_type == "mrope") { + if (!r->start_flag) { + TM_LOG_ERROR("M-ROPE doesn't support interactive chat."); + } + const int* position_ids_ptr = r->inputs.getPtr("mrope_position_ids"); + const int* position_delta_ptr = r->inputs.getPtr("mrope_position_delta"); + const int* length_ptr = r->inputs.getPtr("input_lengths"); + Copy(position_ids_ptr, input_length * 3, state.mrope.position_ids + session_len_ * idx * 3); + Copy(position_delta_ptr, 1, state.mrope.position_delta + idx); + Copy(length_ptr, 1, state.mrope.length + idx); + } + const int request_output_len = state.requests[idx]->inputs.getVal("request_output_len"); state.seq_len_limit[idx] = state.h_context_length[idx] + request_output_len; // `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len @@ -687,6 +699,14 @@ void LlamaBatch::CopyState(const std::vectoroutput_ids, d->output_ids, session_len_}, std::tuple{s->curand_state, d->curand_state, 1}); + + if (attn_param_.rope_scaling_type == "mrope") { + IndexedCopy(s_idx, + d_idx, + std::tuple{s->mrope.position_ids, d->mrope.position_ids, session_len_ * 3}, + std::tuple{s->mrope.position_delta, d->mrope.position_delta, 1}, + std::tuple{s->mrope.length, d->mrope.length, 1}); + } } for (const auto& [s, d, si, di] : desc) { @@ -810,6 +830,14 @@ void LlamaBatch::AllocatePersistantBuffer(size_t max_batch_size, int cache_bl s.output_ids = (int*)allocator_->reMalloc(s.output_ids, sizeof(int) * max_batch_size * session_len_, true); s.curand_state = (curandState_t*)allocator_->reMalloc(s.curand_state, sizeof(curandState_t) * max_batch_size, true); + + if (attn_param_.rope_scaling_type == "mrope") { + s.mrope.position_ids = + (int*)allocator_->reMalloc(s.mrope.position_ids, sizeof(int) * 3 * max_batch_size * session_len_); + s.mrope.position_delta = (int*)allocator_->reMalloc(s.mrope.position_delta, sizeof(int) * max_batch_size); + s.mrope.length = (int*)allocator_->reMalloc(s.mrope.length, sizeof(int) * max_batch_size); + s.mrope.session_len = session_len_; + } } const size_t max_batch_block_count = @@ -919,6 +947,12 @@ void LlamaBatch::FreeBuffer() allocator_->free((void**)&s.h_rope_theta, 
true); allocator_->free((void**)&s.output_ids); allocator_->free((void**)&s.curand_state); + + if (attn_param_.rope_scaling_type == "mrope") { + allocator_->free((void**)&s.mrope.position_ids); + allocator_->free((void**)&s.mrope.position_delta); + allocator_->free((void**)&s.mrope.length); + } } allocator_->free((void**)&h_cu_block_counts_, true); allocator_->free((void**)&h_block_ptrs_, true); @@ -966,11 +1000,13 @@ LlamaBatch::~LlamaBatch() template LlamaBatch::LlamaBatch(const EngineParam& param, + const AttentionParam& attn_param, std::unique_ptr> model, // ! This is moved std::unique_ptr> ctx, // ! This is moved std::shared_ptr state, int device_id): param_(param), + attn_param_(attn_param), shared_state_(state), max_batch_size_(param.max_batch_size), max_forward_token_num_(param.max_prefill_token_num + param.max_batch_size), @@ -1699,6 +1735,15 @@ bool LlamaBatch::Forward(GenerationState& g) } } + std::shared_ptr mrope_sp; + if (attn_param_.rope_scaling_type == "mrope") { + auto& mrope = state_->mrope; + mrope_sp.reset(new MropeInput{session_len_, + mrope.position_ids + first * 3 * session_len_, + mrope.position_delta + first, + mrope.length + first}); + } + model_->forwardUnified(decoder_output_buf_ + first * model_->hidden_units_, context_decoder_output_buf_, // temp context_decoder_input_buf_, // temp @@ -1713,6 +1758,7 @@ bool LlamaBatch::Forward(GenerationState& g) dc_batch_size, pf_batch_size, lora_mask_buf_, + mrope_sp.get(), sequences.data()); // compute logits of inputs if requested @@ -1988,6 +2034,7 @@ void LlamaBatch::tune() 0, 1, nullptr, + nullptr, nullptr); // implicit barrier for TP check_cuda_error(cudaStreamSynchronize(stream_)); diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 9c6694899..b382af56f 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -36,6 +36,13 @@ struct Control { Request::Callback callback; }; +struct MropeInput { + int session_len; + int* position_ids{}; + int* position_delta{}; + int* length{}; +}; + struct BatchState { int* h_prompt_length; // history + input, ignore generated int* h_context_length; @@ -44,6 +51,8 @@ struct BatchState { curandState_t* curand_state; int* output_ids; // output ids in [B, S] + MropeInput mrope; + float* h_rope_theta; std::vector seq_len_limit; @@ -115,6 +124,7 @@ class LlamaBatch { const std::vector& sequences); explicit LlamaBatch(const EngineParam& param, + const AttentionParam& attn_param, std::unique_ptr> model, std::unique_ptr> ctx, std::shared_ptr state, @@ -211,7 +221,8 @@ class LlamaBatch { } private: - const EngineParam param_; + const EngineParam param_; + const AttentionParam attn_param_; const std::shared_ptr shared_state_; diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 3d50910ad..78e8d80aa 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -213,6 +213,7 @@ void LlamaV2::forwardUnified(T* out, int dc_batch_size, int pf_batch_size, int* lora_mask, + MropeInput* mrope, const Sequence** sequences) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -300,6 +301,13 @@ void LlamaV2::forwardUnified(T* out, inputs.insert({"lora_mask", {MEMORY_GPU, TYPE_INT32, {token_num}, lora_mask}}); } + if (mrope != nullptr && attn_param_.rope_scaling_type == "mrope") { + inputs.insert({"mrope_position_ids", + {MEMORY_GPU, TYPE_INT32, {bsz, (size_t)mrope->session_len, 3}, mrope->position_ids}}); + inputs.insert({"mrope_position_delta", 
+                       {MEMORY_GPU, TYPE_INT32, {bsz}, mrope->position_delta}});
+        inputs.insert({"mrope_position_length", {MEMORY_GPU, TYPE_INT32, {bsz}, mrope->length}});
+    }
+
     unified_decoder_->forward(&outputs, &inputs, &weights_->decoder_layer_weights);
 }
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index 6321d09d7..ed10d1ee5 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -82,6 +82,7 @@ class LlamaV2 {
                         int          dc_batch_size,
                         int          pf_batch_size,
                         int*         lora_mask,
+                        MropeInput*  mrope,
                         const Sequence** sequences);

     void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index 1c039ca66..d9af3d984 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -48,6 +48,7 @@ struct AttentionParam {
     float attention_factor;
     float beta_fast;
     float beta_slow;
+    int3  mrope_section;
     bool  use_dynamic_ntk;
     bool  use_logn_attn;
     int   cache_block_seq_len;
diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc
index 2f99b0c2c..4655a8035 100644
--- a/src/turbomind/models/llama/unified_attention_layer.cc
+++ b/src/turbomind/models/llama/unified_attention_layer.cc
@@ -302,14 +302,14 @@ inline void UnifiedAttentionLayer<T>::forward(TensorMap* outputs, const TensorMa
     if (param_.rope_scaling_type == "linear") {
         params.rope_ti_scale /= param_.rope_scaling_factor;
     }
-    if (param_.rope_scaling_type == "llama3") {
+    else if (param_.rope_scaling_type == "llama3") {
         const double PI = 3.14159265358979323846;
         float inv_diff_freq_factor = 1.0 / (param_.high_freq_factor - param_.low_freq_factor);
         params.llama3_inv_scaling_factor = 1.0 / param_.rope_scaling_factor;
         params.llama3_alpha = param_.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
         params.llama3_beta = param_.low_freq_factor * inv_diff_freq_factor;
     }
-    if (param_.rope_scaling_type == "yarn") {
+    else if (param_.rope_scaling_type == "yarn") {
         const double PI = 3.14159265358979323846;
         auto find_correction_dim = [&](float num_rotations) {
             return (param_.rotary_embedding_dim
@@ -337,6 +337,13 @@ inline void UnifiedAttentionLayer<T>::forward(TensorMap* outputs, const TensorMa
             params.attention_scaling = param_.attention_factor;
         }
     }
+    else if (param_.rope_scaling_type == "mrope") {
+        params.mrope_section         = param_.mrope_section;
+        params.mrope_position_ids    = inputs->getPtr<int>("mrope_position_ids");
+        params.mrope_offset          = inputs->at("mrope_position_ids").shape[1];
+        params.mrope_position_delta  = inputs->getPtr<int>("mrope_position_delta");
+        params.mrope_position_length = inputs->getPtr<int>("mrope_position_length");
+    }

     params.use_logn_attn = param_.use_logn_attn;
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 8db13652f..fd1ee1b00 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -276,6 +276,12 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     attn_param_.use_dynamic_ntk = attention_reader["use_dynamic_ntk"].as<int>(0);
     attn_param_.use_logn_attn   = attention_reader["use_logn_attn"].as<int>(0);

+    if (attn_param_.rope_scaling_type == "mrope") {
+        std::vector<int> mrope_section = attention_reader["mrope_section"].as<std::vector<int>>();
+        ft::FT_CHECK(mrope_section.size() == 3);
+        attn_param_.mrope_section = {mrope_section[0], mrope_section[1], mrope_section[2]};
+    }
+
     attn_param_.original_max_position_embeddings = attention_reader["original_max_position_embeddings"].as<int>(0);

     engine_param_.max_batch_size = engine_reader["max_batch_size"].as<int>(0);
@@ -369,6 +375,7 @@ std::unique_ptr> LlamaTritonModel<T>::createSharedModelInstance(
                                              weights_[device_id]);
     auto engine = std::make_unique>(engine_param_,  //
+                                    attn_param_,
                                     std::move(model),
                                     std::move(ctx),
                                     shared_state_,
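With this change Qwen2-VL runs on the TurboMind engine, so the deployment flow in the updated docs applies unchanged. A minimal offline-pipeline sketch (model path, image URL and `session_len` are illustrative; adjust to your setup):

```python
from lmdeploy import TurbomindEngineConfig, pipeline
from lmdeploy.vl import load_image

# Illustrative settings; Qwen2-VL needs a session long enough for the image tokens.
pipe = pipeline('Qwen/Qwen2-VL-7B-Instruct',
                backend_config=TurbomindEngineConfig(session_len=8192))

image = load_image(
    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('Describe this image.', image))
print(response.text)
```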
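The converter path (`config.py`, `qwen.py`, `LlamaTritonModel.cc`) hinges on the `rope_scaling` block of the Hugging Face `config.json`: a `type` of `"mrope"` switches TurboMind to M-RoPE and the three-element `mrope_section` is forwarded to the attention kernels. A sketch of the mapping the reader performs, with illustrative values (verify against the checkpoint you convert):

```python
import json

# A Qwen2-VL-style rope_scaling block; the values here are only an example.
config = json.loads("""
{
  "rope_scaling": {
    "type": "mrope",
    "mrope_section": [16, 24, 24]
  }
}
""")

rope_scaling = config.get('rope_scaling')
if rope_scaling is not None and rope_scaling.get('type', '') == 'mrope':
    # These two keys land in the TurboMind AttentionConfig and, via
    # LlamaTritonModel, in AttentionParam::mrope_section (must have 3 entries).
    turbomind_cfg = {
        'rope_scaling_type': 'mrope',
        'mrope_section': rope_scaling['mrope_section'],
    }
    assert len(turbomind_cfg['mrope_section']) == 3
    print(turbomind_cfg)
```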
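On the Python side, each request passes `mrope_position_ids` as a `[3, seq_len]` int tensor (temporal, height and width positions) plus a scalar `mrope_position_delta`; `prepare_inputs` packs the ids token-major so the kernels can read three consecutive ints per token (`mrope_position_ids + 3 * timestep`). A small sketch of that packing with made-up positions:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seq_len = 5
# Hypothetical (t, h, w) positions for a 5-token prompt, shape [3, seq_len].
mrope_position_ids = torch.arange(3 * seq_len).reshape(3, seq_len)

# Same packing as TurboMind's prepare_inputs in this patch:
# [1, 3, seq_len] -> [1, seq_len, 3] -> flat int32 row.
packed = pad_sequence([mrope_position_ids], batch_first=True,
                      padding_value=-1).transpose(1, 2).int().reshape(1, -1)

# Token t occupies three consecutive ints, matching the kernel-side indexing.
t = 2
assert packed[0, 3 * t:3 * t + 3].tolist() == mrope_position_ids[:, t].tolist()
```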
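For reference, the M-RoPE branch added to `FastRoPE` splits the rotary channel pairs into three sections and rotates each group with its own position; past the prefilled length it falls back to `timestep - mrope_position_delta` for all three groups. A NumPy sketch of that per-token rotation (a simplified, hypothetical helper, not the kernel; the attention-scaling factor is omitted):

```python
import numpy as np

def mrope_rotate(x, timestep, inv_freq, mrope_section,
                 position_ids, position_length, position_delta):
    """Reference for the mrope branch of FastRoPE::apply (sketch only).

    x:             (head_dim,) values of one head, rotated in pairs [0,1], [2,3], ...
    inv_freq:      (head_dim // 2,) rotary inverse frequencies
    mrope_section: (3,) channel-pair counts for the t/h/w groups, e.g. [16, 24, 24]
    position_ids:  (seq_len, 3) per-token (t, h, w) positions from the Python side
    """
    if timestep < position_length:
        p = position_ids[timestep]                      # prefill: use (t, h, w)
    else:
        p = np.full(3, timestep - position_delta)       # decode: fall back to delta

    # Section boundaries in channel units (each entry covers a (cos, sin) pair).
    x_end = mrope_section[0] * 2
    y_end = x_end + mrope_section[1] * 2

    out = x.astype(np.float64).copy()
    for i in range(0, x.shape[0], 2):
        t = p[0] if i < x_end else (p[1] if i < y_end else p[2])
        c, s = np.cos(t * inv_freq[i // 2]), np.sin(t * inv_freq[i // 2])
        out[i], out[i + 1] = c * x[i] - s * x[i + 1], c * x[i + 1] + s * x[i]
    return out
```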