-
Notifications
You must be signed in to change notification settings - Fork 432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor turbomind attention by precomputing cos/sin #2801
Open
irexyc
wants to merge
7
commits into
InternLM:main
Choose a base branch
from
irexyc:rope
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 1 commit
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
05d011c
use precomputed cos sin
irexyc 7b74b72
remove unused
irexyc 589cacb
Merge remote-tracking branch 'origin/main' into rope
irexyc 45f0968
Merge remote-tracking branch 'origin/main' into rope
irexyc 0e4c315
split rope params
irexyc ea6112e
remove prefix yarn_, llama3_
irexyc 0513e12
fix test_attention
irexyc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,15 +5,7 @@ | |
|
||
namespace turbomind { | ||
|
||
enum class RotaryScalingType | ||
{ | ||
kDefault, | ||
kLinear, | ||
kDynamic, | ||
kYarn, | ||
kLlama3, | ||
kMrope | ||
}; | ||
RotaryScalingType GetRoPEType(const std::string& type); | ||
|
||
struct RotaryEmbeddingV2Params { | ||
float* rope_theta; | ||
|
@@ -23,6 +15,19 @@ struct RotaryEmbeddingV2Params { | |
int token_num; | ||
}; | ||
|
||
struct InnerYarnRopeParam { | ||
float attention_factor; | ||
float yarn_ramp_inv_factor_div_2; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can remove the prefix "yarn_" |
||
float yarn_ramp_inv_factor_mul_min; | ||
float yarn_inv_scaling_factor; | ||
}; | ||
|
||
struct InnerLlama3RopeParam { | ||
float llama3_inv_scaling_factor; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the prefix "llama3_" can be removed |
||
float llama3_alpha; | ||
float llama3_beta; | ||
}; | ||
|
||
struct RotaryEmbeddingV2 { | ||
|
||
RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator); | ||
|
@@ -38,28 +43,20 @@ struct RotaryEmbeddingV2 { | |
|
||
void forward(const RotaryEmbeddingV2Params& params); | ||
|
||
RotaryScalingType type_; | ||
cudaStream_t const stream_; | ||
IAllocator* const allocator_; | ||
|
||
int dim_; | ||
RotaryScalingType type_; | ||
float inv_factor_{1.0}; | ||
|
||
union { | ||
InnerYarnRopeParam yarn_; | ||
InnerLlama3RopeParam llama3_; | ||
}; | ||
|
||
// output | ||
float* cos_sin_; // num_token x dim, (cos, sin, ...) | ||
|
||
int dim_; | ||
// default, linear, dynamic | ||
float attention_factor_; | ||
float rope_scaling_factor_; | ||
float inv_scale_factor_; | ||
// llama3 | ||
float llama3_inv_scaling_factor_; | ||
float llama3_alpha_; | ||
float llama3_beta_; | ||
// yarn | ||
float yarn_ramp_inv_factor_div_2_; | ||
float yarn_ramp_inv_factor_mul_min_; | ||
float yarn_inv_scaling_factor_; | ||
// mrope | ||
int3 mrope_section_; | ||
}; | ||
|
||
}; // namespace turbomind |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RotaryScalingType -> RopeType