Commit 0513e12

fix test_attention

irexyc committed Dec 2, 2024
1 parent ea6112e commit 0513e12
Showing 4 changed files with 25 additions and 6 deletions.
23 changes: 21 additions & 2 deletions src/turbomind/kernels/attention/test_attention.cu
@@ -7,6 +7,7 @@
#include "src/turbomind/kernels/attention/attention_params.h"
#include "src/turbomind/kernels/attention/reference.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/models/llama/rotary_emb.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "test_utils.h"
#include <algorithm>
@@ -376,6 +377,26 @@ int test_attention()
rope_base[i] = kRoPEBase;
}

+// precompute cos/sin
+const int device_id = 0;
+auto allocator = std::make_unique<Allocator<AllocatorType::CUDA>>(device_id, false);
+allocator->setStream(nullptr);
+AttentionParam attn_param;
+attn_param.rope.type = RopeType::kDefault;
+attn_param.rope.base = kRoPEBase;
+attn_param.rope.dim = kRoPEDim;
+attn_param.rope.factor = 1.0f;
+auto rotary_emb = std::make_unique<RotaryEmbeddingV2>(attn_param, nullptr, allocator.get());
+
+RotaryEmbeddingV2Param rotary_param;
+rotary_param.rope_theta = rope_base.data().get();
+rotary_param.q_len = cu_seqlens.data().get();
+rotary_param.k_ken = cu_kv_lens.data().get();
+rotary_param.batch_size = kBatchSize;
+rotary_param.token_num = kTokenNum;
+rotary_emb->forward(rotary_param);
+params.cos_sin = rotary_emb->cos_sin_;
+
// getchar();

params.out = output_ref.data().get();
@@ -429,8 +450,6 @@ int test_attention()
params.qk = qk_buf.data().get();
params.pr = pr_buf.data().get();

-params.cos_sin = nullptr;
-
Reference<T> reference(kDump ? Reference<T>::kUNFUSED : Reference<T>::kFLASH_ATTENTION, {});
// Reference<T> reference(Reference<T>::kUNFUSED, {});
reference.Reshape(kInputLen, kContextLen, kHeadNum, kHeadDim, KvHeadNum, kBatchSize);
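For context on what the new test setup computes: default RoPE rotates each query/key dimension pair i of token position t by the angle t * base^(-2i/dim), and this patch has the test precompute the cos/sin of those angles through RotaryEmbeddingV2 and hand them to the kernel via params.cos_sin, which was previously left as nullptr (removed below). A minimal host-side sketch of such a table, assuming an interleaved (cos, sin) layout; the actual device-side layout produced by rotary_emb.cu may differ:

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical illustration only: one (cos, sin) pair per token per rotary
// dimension pair, following the default RoPE angle t * base^(-2i/dim).
std::vector<float> make_cos_sin_table(int token_num, int dim, float base)
{
    std::vector<float> table(static_cast<std::size_t>(token_num) * dim);
    for (int t = 0; t < token_num; ++t) {
        for (int i = 0; i < dim / 2; ++i) {
            const float inv_freq = std::pow(base, -2.0f * i / dim);
            const float theta    = t * inv_freq;
            table[t * dim + 2 * i]     = std::cos(theta);
            table[t * dim + 2 * i + 1] = std::sin(theta);
        }
    }
    return table;
}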
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/rotary_emb.cu
@@ -180,7 +180,7 @@ RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t s
}
}

-void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
+void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Param& params)
{
allocateBuffer(params.token_num);

4 changes: 2 additions & 2 deletions src/turbomind/models/llama/rotary_emb.h
@@ -7,7 +7,7 @@ namespace turbomind {

RopeType GetRoPEType(const std::string& type);

-struct RotaryEmbeddingV2Params {
+struct RotaryEmbeddingV2Param {
float* rope_theta;
int* q_len;
int* k_ken;
@@ -41,7 +41,7 @@ struct RotaryEmbeddingV2 {
freeBuffer();
}

-void forward(const RotaryEmbeddingV2Params& params);
+void forward(const RotaryEmbeddingV2Param& params);

cudaStream_t const stream_;
IAllocator* const allocator_;
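The q_len and k_ken fields declared in this header (and filled from cu_seqlens / cu_kv_lens in the test above) appear to follow the usual flattened-batch convention: a [batch_size + 1] cumulative-length prefix on device. A small self-contained sketch of that convention; cumulative_lengths is a hypothetical helper for illustration, not part of the repository:

#include <cstddef>
#include <vector>

// Hypothetical helper: build the [batch_size + 1] cumulative prefix held by
// cu_q_len / cu_k_len style arrays, e.g. lengths {3, 4, 5} -> {0, 3, 7, 12}.
std::vector<int> cumulative_lengths(const std::vector<int>& lens)
{
    std::vector<int> cu(lens.size() + 1, 0);
    for (std::size_t i = 0; i < lens.size(); ++i) {
        cu[i + 1] = cu[i] + lens[i];
    }
    return cu;
}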
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/unified_decoder.cc
@@ -164,7 +164,7 @@ void UnifiedDecoder<T>::forward(TensorMap* outputs, const TensorMap* inputs, con
count_and_fix(decoder_output, token_num * hidden_units_, Concat("norm0", 0), 2);

{
-RotaryEmbeddingV2Params params;
+RotaryEmbeddingV2Param params;
params.rope_theta = inputs->getPtr<float>("rope_theta");
params.q_len = cu_q_len_;
params.k_ken = cu_k_len_;
