diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu
index 16b7846c7..788b5d489 100644
--- a/src/turbomind/kernels/attention/test_attention.cu
+++ b/src/turbomind/kernels/attention/test_attention.cu
@@ -7,6 +7,7 @@
 #include "src/turbomind/kernels/attention/attention_params.h"
 #include "src/turbomind/kernels/attention/reference.h"
 #include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/models/llama/rotary_emb.h"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "test_utils.h"
 #include
@@ -376,6 +377,26 @@ int test_attention()
         rope_base[i] = kRoPEBase;
     }
 
+    // precompute cos/sin
+    const int device_id = 0;
+    auto allocator = std::make_unique<Allocator<AllocatorType::CUDA>>(device_id, false);
+    allocator->setStream(nullptr);
+    AttentionParam attn_param;
+    attn_param.rope.type   = RopeType::kDefault;
+    attn_param.rope.base   = kRoPEBase;
+    attn_param.rope.dim    = kRoPEDim;
+    attn_param.rope.factor = 1.0f;
+    auto rotary_emb = std::make_unique<RotaryEmbeddingV2>(attn_param, nullptr, allocator.get());
+
+    RotaryEmbeddingV2Param rotary_param;
+    rotary_param.rope_theta = rope_base.data().get();
+    rotary_param.q_len      = cu_seqlens.data().get();
+    rotary_param.k_ken      = cu_kv_lens.data().get();
+    rotary_param.batch_size = kBatchSize;
+    rotary_param.token_num  = kTokenNum;
+    rotary_emb->forward(rotary_param);
+    params.cos_sin = rotary_emb->cos_sin_;
+
     // getchar();
 
     params.out = output_ref.data().get();
@@ -429,8 +450,6 @@ int test_attention()
     params.qk = qk_buf.data().get();
     params.pr = pr_buf.data().get();
 
-    params.cos_sin = nullptr;
-
     Reference reference(kDump ? Reference::kUNFUSED : Reference::kFLASH_ATTENTION, {});
     // Reference reference(Reference::kUNFUSED, {});
     reference.Reshape(kInputLen, kContextLen, kHeadNum, kHeadDim, KvHeadNum, kBatchSize);
diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu
index 72362bcb7..efb98bf85 100644
--- a/src/turbomind/models/llama/rotary_emb.cu
+++ b/src/turbomind/models/llama/rotary_emb.cu
@@ -180,7 +180,7 @@ RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t s
     }
 }
 
-void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
+void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Param& params)
 {
     allocateBuffer(params.token_num);
 
diff --git a/src/turbomind/models/llama/rotary_emb.h b/src/turbomind/models/llama/rotary_emb.h
index 43fa24714..db250402a 100644
--- a/src/turbomind/models/llama/rotary_emb.h
+++ b/src/turbomind/models/llama/rotary_emb.h
@@ -7,7 +7,7 @@ namespace turbomind {
 
 RopeType GetRoPEType(const std::string& type);
 
-struct RotaryEmbeddingV2Params {
+struct RotaryEmbeddingV2Param {
     float* rope_theta;
     int*   q_len;
     int*   k_ken;
@@ -41,7 +41,7 @@ struct RotaryEmbeddingV2 {
         freeBuffer();
     }
 
-    void forward(const RotaryEmbeddingV2Params& params);
+    void forward(const RotaryEmbeddingV2Param& params);
 
     cudaStream_t const stream_;
     IAllocator* const  allocator_;
diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc
index c37658fd3..792c4d3b5 100644
--- a/src/turbomind/models/llama/unified_decoder.cc
+++ b/src/turbomind/models/llama/unified_decoder.cc
@@ -164,7 +164,7 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con
     count_and_fix(decoder_output, token_num * hidden_units_, Concat("norm0", 0), 2);
 
     {
-        RotaryEmbeddingV2Params params;
+        RotaryEmbeddingV2Param params;
         params.rope_theta = inputs->getPtr<float>("rope_theta");
         params.q_len      = cu_q_len_;
         params.k_ken      = cu_k_len_;
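
Note on the new cos/sin precompute path: `RotaryEmbeddingV2::forward` fills `cos_sin_` with, for every token position, the cosine and sine of the default (Llama-style) RoPE angles, so the attention kernels can read them through `params.cos_sin` instead of recomputing powf/cosf/sinf per element. Below is a minimal host-side sketch of such a table for a single scalar base; the flat `[token_num, dim]` layout with interleaved cos/sin pairs and the helper name `precompute_cos_sin` are illustrative assumptions, not the actual kernel in rotary_emb.cu (which, per the diff, also takes a per-sequence `rope_theta` base).

    #include <cmath>
    #include <vector>

    // Sketch of a default-RoPE cos/sin table (assumed interleaved layout).
    // For position t and frequency pair i in [0, dim/2):
    //   inv_freq_i             = base^(-2*i/dim)
    //   cos_sin[t*dim + 2*i]   = cos(t * inv_freq_i)
    //   cos_sin[t*dim + 2*i+1] = sin(t * inv_freq_i)
    std::vector<float> precompute_cos_sin(int token_num, int dim, float base)
    {
        std::vector<float> cos_sin(static_cast<size_t>(token_num) * dim);
        for (int t = 0; t < token_num; ++t) {
            for (int i = 0; i < dim / 2; ++i) {
                const float inv_freq = std::pow(base, -2.f * i / dim);
                cos_sin[static_cast<size_t>(t) * dim + 2 * i]     = std::cos(t * inv_freq);
                cos_sin[static_cast<size_t>(t) * dim + 2 * i + 1] = std::sin(t * inv_freq);
            }
        }
        return cos_sin;
    }

Under these assumptions, the table the test wires up would correspond to `precompute_cos_sin(kTokenNum, kRoPEDim, kRoPEBase)` evaluated on the GPU.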