Support deepspeed sequence parallel #31525
base: main
Conversation
Raise exception when sdpa
Great, can you provide an example of data processing based on sequence parallelism? Thanks.
The dataset and sampler are handled in the Trainer. The data collator example was accidentally deleted while editing; here it is:

```python
seq_parallel_world_size = mpu.get_sequence_parallel_world_size()
seq_parallel_world_rank = mpu.get_sequence_parallel_rank()

# Each rank keeps only its own contiguous slice of the sequence.
seq_length = input_ids.size(1)
sub_seq_length = seq_length // seq_parallel_world_size
sub_seq_start = seq_parallel_world_rank * sub_seq_length
sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length

# There is no kv cache when training
past_key_values_length = 0
position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

batch = dict(
    input_ids=input_ids[:, sub_seq_start:sub_seq_end],
    labels=labels[:, sub_seq_start:sub_seq_end],
    position_ids=position_ids[:, sub_seq_start:sub_seq_end],
    # Note: the attention mask is built from the full-length input_ids.
    attention_mask=(input_ids != self.tokenizer.pad_token_id),
)
```
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
How long will it take for this PR to be merged? When can it be finished? ...
cc @SunMarc if you have the bandwidth to take a look!
@zeyugao I carefully read your pull requests for transformers and accelerate, and pulled your code to try training. Now I have encountered a problem: when entering DistributedAttention, the q, k, v before _SeqAllToAll.apply are not [b, s/p, n, h], but still [b, s, n, h]. I checked the modified parts of the data processing, such as accelerate/data_loader.py and transformers/trainer.py, but did not find any relevant processing code. So, may I ask where the sequence splitting is done?
@glowwormX It is in the PR description.
@zeyugao My God, I missed it; I thought this code was in the PR. Thank you for replying.
@zeyugao Have you compared the loss with sequence parallelism? After adding a fixed seed to DistributedSampler, the training data is the same. I modified trainer.py:
However, when the same data is used, the average loss with sequence parallelism differs from the loss without it. In addition, why does starcoder not support sdpa? I am trying to modify qwen2 and do not know whether it supports sdpa.
@glowwormX The main reason should be that it needs a custom loss calculation; otherwise some tokens (at the head and tail of each subsequence) do not contribute to the final loss: https://github.com/microsoft/DeepSpeed/pull/5774/files#diff-13f25bb51b0f4019d8cb09c07204a33510dca5dccfae736baf10134f893704d5
I did not have much spare time at the time to get the shapes right when using sdpa for starcoder2.
@zeyugao: Your implementation does not use this loss function, right? It still works OK even so?
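To make the head/tail point from the discussion above concrete: with the standard causal-LM loss, each rank shifts logits and labels within its own sub-sequence, so the prediction at every sub-sequence boundary loses its target. The sketch below shows one way to keep boundary targets on the right rank by shifting labels before splitting. It is neither this PR's code nor the loss from the linked DeepSpeed PR; the function names are made up, and it assumes the model is then trained with a loss that does not shift again.

```python
import torch
import torch.nn.functional as F

def split_with_preshifted_labels(input_ids, labels, rank, world_size, ignore_index=-100):
    # Shift labels by one position globally BEFORE slicing, so the target of the
    # last token of every sub-sequence stays on the same rank as that token.
    shifted = labels.new_full(labels.shape, ignore_index)
    shifted[:, :-1] = labels[:, 1:]

    sub = input_ids.size(1) // world_size
    start, end = rank * sub, (rank + 1) * sub
    return input_ids[:, start:end], shifted[:, start:end]

def unshifted_causal_loss(logits, preshifted_labels, ignore_index=-100):
    # No extra shift here, because the labels were already shifted in the collator.
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        preshifted_labels.reshape(-1),
        ignore_index=ignore_index,
    )
```

This only illustrates why the default shift-inside-each-chunk loss drops boundary tokens; the linked DeepSpeed PR handles it differently.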
What does this PR do?
Support sequence parallelism with DeepSpeed-Ulysses.
I have tested the training on starcoder2-3b. The loss decreases normally.
Requires huggingface/accelerate#2877
I have made massive modifications to the original implementation of DeepSpeed-Ulysses in `layers.py` to support the batch size dim. It uses `all_to_all_single` instead of `all_to_all`, like https://github.com/InternLM/InternEvo/blob/a61d391df96c5f5c243cdea32a5044b70d6fe33e/internlm/core/parallel/comm/isp.py#L628, for better performance, although with `all_to_all_single` it is too complex to support other scatter idx and gather idx. I have left some comments to help future understanding.

Currently, flash attn and sdpa for llama and mistral are tested. flash attn for starcoder is also tested; sdpa for starcoder is not supported.
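As context for the `all_to_all_single` choice, here is a minimal sketch of redistributing a `[b, s/p, n, h]` tensor to `[b, s, n/p, h]` with a single `dist.all_to_all_single` call. This is not the PR's `layers.py` code; the helper name and the exact reshape/permute layout are illustrative, and it assumes an already-initialized sequence-parallel process group.

```python
import torch
import torch.distributed as dist

def ulysses_all_to_all(x: torch.Tensor, group=None) -> torch.Tensor:
    """[b, s/p, n, h] -> [b, s, n/p, h] with one all_to_all_single (illustrative)."""
    p = dist.get_world_size(group)
    b, s_local, n, h = x.shape
    n_local = n // p

    # Lay the tensor out so dim 0 indexes the destination rank: each rank
    # receives one chunk of n_local heads for our local sequence slice.
    x = x.reshape(b, s_local, p, n_local, h).permute(2, 0, 1, 3, 4).contiguous()

    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)

    # After the exchange, dim 0 indexes the source rank, i.e. the order of the
    # sequence chunks; fold it back into the sequence dimension.
    return out.permute(1, 0, 2, 3, 4).reshape(b, p * s_local, n_local, h)
```

The reverse exchange after attention (heads back to sequence) is the same pattern with the roles of the sequence and head dimensions swapped.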
It requires a special dataloader (which I have added in the Trainer) and a data collator (example below). In the data collator, the sequence should be divided into multiple sub-sequences. The following is an example of sub-sequence processing in the data collator.
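Here is that sub-sequence processing wrapped as a standalone collator for readability. This is a sketch rather than the PR's exact code: the class name is made up, the stacking assumes features that are already padded to the same length, and `mpu` stands for whatever utility exposes `get_sequence_parallel_world_size`/`get_sequence_parallel_rank`.

```python
import torch

class SequenceParallelCollator:
    """Hypothetical collator that gives each sequence-parallel rank its slice."""

    def __init__(self, tokenizer, mpu):
        self.tokenizer = tokenizer
        self.mpu = mpu  # must expose get_sequence_parallel_world_size/rank

    def __call__(self, features):
        # Assumes every feature is already padded to the same length.
        input_ids = torch.stack([torch.as_tensor(f["input_ids"]) for f in features])
        labels = torch.stack([torch.as_tensor(f["labels"]) for f in features])

        world_size = self.mpu.get_sequence_parallel_world_size()
        rank = self.mpu.get_sequence_parallel_rank()

        seq_length = input_ids.size(1)
        sub_seq_length = seq_length // world_size
        start = rank * sub_seq_length
        end = (rank + 1) * sub_seq_length

        # No kv cache during training, so positions simply run over the full sequence.
        position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0)

        return dict(
            input_ids=input_ids[:, start:end],
            labels=labels[:, start:end],
            position_ids=position_ids[:, start:end],
            # As in the snippet above, the mask covers the full-length input_ids.
            attention_mask=(input_ids != self.tokenizer.pad_token_id),
        )
```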
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@muellerzr and @SunMarc