From 6d6fb4bf07fed9eb448b61a67b4c1c9f664edeac Mon Sep 17 00:00:00 2001 From: Elsa Date: Thu, 20 Jun 2024 23:50:41 +0800 Subject: [PATCH 01/14] Add original deepspeed --- src/transformers/layer.py | 109 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 src/transformers/layer.py diff --git a/src/transformers/layer.py b/src/transformers/layer.py new file mode 100644 index 00000000000000..a876596fb7bbea --- /dev/null +++ b/src/transformers/layer.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from typing import Any, Tuple +from torch import Tensor +from torch.nn import Module + +import deepspeed.comm as dist + + +def single_all_to_all(input, scatter_idx, gather_idx, group): + seq_world_size = dist.get_world_size(group) + inp_shape = list(input.shape) + inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size + if scatter_idx < 2: + input_t = input.reshape( + [seq_world_size, inp_shape[scatter_idx]] + \ + inp_shape[scatter_idx + 1:] + ).contiguous() + else: + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + input_t = input.reshape( + [-1, seq_world_size, inp_shape[scatter_idx]] + \ + inp_shape[scatter_idx + 1:] + ).transpose(0, 1).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + # if scattering the seq-dim, transpose the heads back to the original dimension + if scatter_idx < 2: + output = output.transpose(0, 1).contiguous() + + return output.reshape( + inp_shape[: gather_idx] + \ + [inp_shape[gather_idx] * seq_world_size,] + \ + inp_shape[gather_idx + 1:]).contiguous() + + +class _SeqAllToAll(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor: + + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + return single_all_to_all(input, scatter_idx, gather_idx, group) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) + + +class DistributedAttention(torch.nn.Module): + """Initialization. + + Arguments: + local_attention (Module): local attention with q,k,v + sequence_process_group (ProcessGroup): sequence parallel process group + scatter_idx (int): scatter_idx for all2all comm + gather_idx (int): gather_idx for all2all comm + """ + + def __init__( + self, + local_attention: Module, + sequence_process_group: dist.ProcessGroup, + scatter_idx: int = 2, + gather_idx: int = 0, + ) -> None: + + super(DistributedAttention, self).__init__() + self.local_attn = local_attention + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + + def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor: + """ forward + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args + + Returns: + * output (Tensor): context output + """ + # TODO Merge three alltoall calls into one + # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! 
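+        # Illustrative shape walk-through (assuming 4-D inputs laid out as [s/p, b, h, d],
+        # which this file does not spell out): with a sequence-parallel group of size p and
+        # the default scatter_idx=2, gather_idx=0, each all-to-all below scatters the h
+        # attention heads across ranks and gathers the sequence, so the local attention
+        # runs on [s, b, h/p, d]; the final all-to-all reverses this back to [s/p, b, h, d].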
+ #in shape : e.g., [s/p:h:] + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) + + #out shape : e.g., [s:h/p:] + context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) + + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx) + + #out e.g., [s/p::h] + return output From b5f054ced36b38ee33c001e1b8bf43157f490a08 Mon Sep 17 00:00:00 2001 From: Elsa Date: Fri, 21 Jun 2024 00:17:26 +0800 Subject: [PATCH 02/14] Support override the seed_worker in Trainer --- src/transformers/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 34cf5aa490467d..34598999f48541 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -861,6 +861,9 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: else: return RandomSampler(self.train_dataset) + + def seed_worker(self, *args, **kwargs): + return seed_worker(*args, **kwargs) def get_train_dataloader(self) -> DataLoader: """ @@ -892,7 +895,7 @@ def get_train_dataloader(self) -> DataLoader: if not isinstance(train_dataset, torch.utils.data.IterableDataset): dataloader_params["sampler"] = self._get_train_sampler() dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["worker_init_fn"] = self.seed_worker dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) From 6f72b8a4602c230f2cbab590127d7e05b67ced78 Mon Sep 17 00:00:00 2001 From: Elsa Date: Fri, 21 Jun 2024 00:23:41 +0800 Subject: [PATCH 03/14] Add some necessary check on sequence parallel argument --- src/transformers/integrations/deepspeed.py | 14 +++++++++++++- src/transformers/modeling_utils.py | 21 ++++++++++++++++++++- src/transformers/training_args.py | 22 +++++++++++++++++++++- 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py index aae1204acf488c..be781974a1c67d 100644 --- a/src/transformers/integrations/deepspeed.py +++ b/src/transformers/integrations/deepspeed.py @@ -137,9 +137,21 @@ def trainer_config_process(self, args, auto_find_batch_size=False): Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object creation. 
""" + + if getattr(args, 'sequence_parallel', 1) > 1: + assert is_accelerate_available(), "DeepSpeed sequence parallelism requires Accelerate, install it with 'pip install accelerate'" + + from accelerate.utils import parallel_state as mpu + mpu.initialize_model_parallel( + sequence_parallel_size=args.sequence_parallel, + ) + world_size = mpu.get_data_parallel_world_size() + else: + world_size = args.world_size + # DeepSpeed does: # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps - train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps + train_batch_size = world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps self.fill_match( "train_micro_batch_size_per_gpu", args.per_device_train_batch_size, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1b0456eff92229..4d26d0741826a8 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -119,6 +119,7 @@ save_offload_index, set_module_tensor_to_device, ) + from accelerate.utils import parallel_state as mpu accelerate_version = version.parse(importlib.metadata.version("accelerate")) if accelerate_version >= version.parse("0.31"): @@ -3691,7 +3692,19 @@ def from_pretrained( import deepspeed logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") - init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts + if is_accelerate_available() and mpu.model_parallel_is_initialized(): + mpu_module = mpu + sequence_data_parallel_group = mpu.get_sequence_data_parallel_group() + else: + mpu_module = None + sequence_data_parallel_group = None + init_contexts = [ + deepspeed.zero.Init( + sequence_data_parallel_group=sequence_data_parallel_group, + config_dict_or_path=deepspeed_config(), + mpu=mpu_module, + ) + ] + init_contexts elif low_cpu_mem_usage: init_contexts.append(init_empty_weights()) @@ -3878,6 +3891,12 @@ def from_pretrained( ) pass + if is_accelerate_available() and mpu.get_sequence_parallel_world_size_or_one() > 1: + if not getattr(model, "supports_sequence_parallel", False): + raise ValueError( + "The model does not support sequence parallelism." + ) + # Dispatch model with hooks on all devices if necessary if device_map is not None: device_map_kwargs = { diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 98f35501928965..c53243b7d179bd 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1221,6 +1221,15 @@ class TrainingArguments: ) }, ) + sequence_parallel: int = field( + default=1, + metadata={ + "help": ( + "Enable sequence parallel provided by Deepspeed-Ulysses. Requires deepspeed to be enabled." + " This require handling dataset loader manually with DistributedSampler." + ) + } + ) label_smoothing_factor: float = field( default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} ) @@ -1944,6 +1953,10 @@ def __post_init__(self): self.deepspeed_plugin.set_mixed_precision(mixed_precision) self.deepspeed_plugin.set_deepspeed_weakref() + if self.sequence_parallel > 1: + if self.deepspeed_plugin is None: + raise ValueError("sequence_parallel requires deepspeed enabled") + if self.use_cpu: self.dataloader_pin_memory = False @@ -2017,8 +2030,15 @@ def train_batch_size(self) -> int: "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future " "version. 
Using `--per_device_train_batch_size` is preferred." ) + + world_size = self.n_gpu + if is_accelerate_available(): + from accelerate.utils import parallel_state as mpu + if mpu.model_parallel_is_initialized(): + world_size = mpu.get_data_parallel_world_size() + per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size - train_batch_size = per_device_batch_size * max(1, self.n_gpu) + train_batch_size = per_device_batch_size * max(1, world_size) return train_batch_size @property From e311c0b9f1c7e46a054ced842212e297c84ee045 Mon Sep 17 00:00:00 2001 From: Elsa Date: Fri, 21 Jun 2024 00:25:00 +0800 Subject: [PATCH 04/14] Add DistributedAttention --- src/transformers/layer.py | 58 +++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/src/transformers/layer.py b/src/transformers/layer.py index a876596fb7bbea..a42db1c9b7efbb 100644 --- a/src/transformers/layer.py +++ b/src/transformers/layer.py @@ -9,36 +9,55 @@ from torch import Tensor from torch.nn import Module -import deepspeed.comm as dist +from torch import distributed as dist def single_all_to_all(input, scatter_idx, gather_idx, group): seq_world_size = dist.get_world_size(group) inp_shape = list(input.shape) inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size - if scatter_idx < 2: + if scatter_idx < 2: # scatter_idx == 1, scatter sequence dim input_t = input.reshape( - [seq_world_size, inp_shape[scatter_idx]] + \ + inp_shape[: scatter_idx] + # the batch size dim + [seq_world_size, inp_shape[scatter_idx]] + # scatter the sequence dim inp_shape[scatter_idx + 1:] - ).contiguous() - else: - # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + ).transpose(0, 1).contiguous() + else: # scatter_idx == 2, scatter heads dim input_t = input.reshape( - [-1, seq_world_size, inp_shape[scatter_idx]] + \ + [-1] + # flatten batch size and sequence dim + [seq_world_size, inp_shape[scatter_idx]] + # scatter the heads inp_shape[scatter_idx + 1:] ).transpose(0, 1).contiguous() output = torch.empty_like(input_t) dist.all_to_all_single(output, input_t, group=group) - # if scattering the seq-dim, transpose the heads back to the original dimension if scatter_idx < 2: - output = output.transpose(0, 1).contiguous() + output = output.reshape( + [ + seq_world_size, + -1, # batch size dim + inp_shape[scatter_idx] # sequence dim (scattered) + ] + inp_shape[gather_idx:] # heads dim, and the rest + ).permute( + 1, 2, 0, *list(range(3, len(output.shape))) + ) + else: + output = output.reshape( + [ + seq_world_size, + -1, # batch size dim + inp_shape[gather_idx] # sequence dim + ] + inp_shape[scatter_idx:] # heads dim (scattered), and the rest + ).transpose( + 0, 1 + ) return output.reshape( - inp_shape[: gather_idx] + \ - [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]).contiguous() + inp_shape[:gather_idx] + + [seq_world_size * inp_shape[gather_idx]] + + inp_shape[gather_idx + 1:] + ).contiguous() class _SeqAllToAll(torch.autograd.Function): @@ -65,17 +84,22 @@ class DistributedAttention(torch.nn.Module): sequence_process_group (ProcessGroup): sequence parallel process group scatter_idx (int): scatter_idx for all2all comm gather_idx (int): gather_idx for all2all comm + + support and only support shape [b, s, h, d] """ def __init__( self, local_attention: Module, sequence_process_group: dist.ProcessGroup, - scatter_idx: int = 2, - gather_idx: int = 0, + scatter_idx: int = 2, # scatter dim, commonly the head 
dim + gather_idx: int = 1, # gather dim, commonly the sequence dim ) -> None: super(DistributedAttention, self).__init__() + + assert scatter_idx == 2 and gather_idx == 1, 'Only support shape [b, s, h, ...]' + self.local_attn = local_attention self.spg = sequence_process_group self.scatter_idx = scatter_idx @@ -95,15 +119,15 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwarg """ # TODO Merge three alltoall calls into one # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! - #in shape : e.g., [s/p:h:] + # in shape : e.g., [b,s/p:h:] query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) - #out shape : e.g., [s:h/p:] + # out shape : e.g., [b,s:h/p:] context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx) - #out e.g., [s/p::h] + # out e.g., [b,s/p::h] return output From f2a6cc9b2547f7bd0cab111f362795723af4e4f5 Mon Sep 17 00:00:00 2001 From: Elsa Date: Fri, 21 Jun 2024 00:25:39 +0800 Subject: [PATCH 05/14] Add starcoder2 as sequence parallel supported --- .../models/starcoder2/modeling_starcoder2.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 97ea7f9509ebea..342404d86dc89c 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -46,9 +46,13 @@ is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, + is_accelerate_available, ) from .configuration_starcoder2 import Starcoder2Config +if is_accelerate_available(): + from accelerate.utils import parallel_state as mpu + from ...layer import DistributedAttention if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -319,6 +323,13 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
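+        # The sequence-parallel wiring a few lines below swaps in DeepSpeed-Ulysses
+        # DistributedAttention when a sequence-parallel group is active. q_len_multiplier
+        # records the group size so that forward() can pass flash attention the global
+        # sequence length (q_len * q_len_multiplier): each rank holds only q_len tokens
+        # locally, and the all-to-all inside DistributedAttention gathers the full sequence.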
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + self.attn_func = DistributedAttention(self._flash_attention_forward, mpu.get_sequence_parallel_group()) + self.q_len_multiplier = mpu.get_sequence_parallel_world_size() + else: + self.attn_func = self._flash_attention_forward + self.q_len_multiplier = 1 + # Ignore copy def forward( self, @@ -429,12 +440,12 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = self._flash_attention_forward( + attn_output = self.attn_func( query_states, key_states, value_states, attention_mask, - q_len, + q_len * self.q_len_multiplier, dropout=dropout_rate, use_sliding_windows=use_sliding_windows, ) @@ -786,6 +797,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): config_class = Starcoder2Config base_model_prefix = "model" supports_gradient_checkpointing = True + supports_sequence_parallel = True _no_split_modules = ["Starcoder2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True From 42fd905dfa1446f73999606e72e49272b6049c76 Mon Sep 17 00:00:00 2001 From: Elsa Date: Fri, 21 Jun 2024 00:44:54 +0800 Subject: [PATCH 06/14] Add llama, mistral Raise exception when sdpa --- .../models/llama/modeling_llama.py | 24 +++++++++++++++-- .../models/mistral/modeling_mistral.py | 26 ++++++++++++++++--- .../models/starcoder2/modeling_starcoder2.py | 9 +++++++ 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7cef177fef99f5..36a80d613c6f6b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -45,9 +45,13 @@ is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, + is_accelerate_available, ) from .configuration_llama import LlamaConfig +if is_accelerate_available(): + from accelerate.utils import parallel_state as mpu + from ...layer import DistributedAttention if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -390,6 +394,13 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + self.attn_func = DistributedAttention(self._flash_attention_forward, mpu.get_sequence_parallel_group()) + self.q_len_multiplier = mpu.get_sequence_parallel_world_size() + else: + self.attn_func = self._flash_attention_forward + self.q_len_multiplier = 1 + def forward( self, hidden_states: torch.Tensor, @@ -463,8 +474,8 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + attn_output = self.attn_func( + query_states, key_states, value_states, attention_mask, q_len * self.q_len_multiplier, dropout=dropout_rate ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -580,6 +591,15 @@ class LlamaSdpaAttention(LlamaAttention): SDPA API. 
""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: We can make it support sdpa + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + raise ValueError( + "SDPA is not supported with sequence parallelism. Please use the `flash_attention_2` implementation instead." + ) + # Adapted from LlamaAttention.forward def forward( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d266c6b4f47216..af1b16b23ac0c9 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -46,9 +46,12 @@ is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, + is_accelerate_available, ) from .configuration_mistral import MistralConfig - +if is_accelerate_available(): + from accelerate.utils import parallel_state as mpu + from ...layer import DistributedAttention if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -303,6 +306,13 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + self.attn_func = DistributedAttention(self._flash_attention_forward, mpu.get_sequence_parallel_group()) + self.q_len_multiplier = mpu.get_sequence_parallel_world_size() + else: + self.attn_func = self._flash_attention_forward + self.q_len_multiplier = 1 + def forward( self, hidden_states: torch.Tensor, @@ -412,12 +422,12 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = self._flash_attention_forward( + attn_output = self.attn_func( query_states, key_states, value_states, attention_mask, - q_len, + q_len * self.q_len_multiplier, dropout=dropout_rate, use_sliding_windows=use_sliding_windows, ) @@ -581,6 +591,15 @@ class MistralSdpaAttention(MistralAttention): SDPA API. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: We can make it support sdpa + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + raise ValueError( + "SDPA is not supported with sequence parallelism. Please use the `flash_attention_2` implementation instead." + ) + # Adapted from MistralAttention.forward def forward( self, @@ -763,6 +782,7 @@ class MistralPreTrainedModel(PreTrainedModel): config_class = MistralConfig base_model_prefix = "model" supports_gradient_checkpointing = True + supports_sequence_parallel = True _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 342404d86dc89c..18895bf5f77a02 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -610,6 +610,15 @@ class Starcoder2SdpaAttention(Starcoder2Attention): SDPA API. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: We can make it support sdpa + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + raise ValueError( + "SDPA is not supported with sequence parallelism. 
Please use the `flash_attention_2` implementation instead." + ) + # Ignore copy def forward( self, From 1c1eed2c8f1edb91692f0498b048659d73c380e0 Mon Sep 17 00:00:00 2001 From: Elsa Date: Fri, 21 Jun 2024 10:13:14 +0800 Subject: [PATCH 07/14] Move DistributedSampler initialization into trainer --- src/transformers/trainer.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 34598999f48541..2d41cd858de7d8 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -52,7 +52,7 @@ from huggingface_hub import ModelCard, create_repo, upload_folder from packaging import version from torch import nn -from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler +from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler, DistributedSampler from . import __version__ from .configuration_utils import PretrainedConfig @@ -235,6 +235,8 @@ if is_deepspeed_available(): from accelerate.utils import DeepSpeedSchedulerWrapper + + from accelerate.utils import parallel_state as mpu if is_accelerate_available("0.28.0"): from accelerate.utils import DataLoaderConfiguration @@ -842,7 +844,15 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: return None # Build the sampler. - if self.args.group_by_length: + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + assert self.args.group_by_length is False, "Group by length is not supported with sequence parallelism." + return DistributedSampler( + dataset=self.train_dataset, + num_replicas=mpu.get_data_parallel_world_size(), + rank=mpu.get_data_parallel_rank(), + shuffle=True, + ) + elif self.args.group_by_length: if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): lengths = ( self.train_dataset[self.args.length_column_name] @@ -861,9 +871,6 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: else: return RandomSampler(self.train_dataset) - - def seed_worker(self, *args, **kwargs): - return seed_worker(*args, **kwargs) def get_train_dataloader(self) -> DataLoader: """ @@ -895,12 +902,20 @@ def get_train_dataloader(self) -> DataLoader: if not isinstance(train_dataset, torch.utils.data.IterableDataset): dataloader_params["sampler"] = self._get_train_sampler() dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = self.seed_worker + dataloader_params["worker_init_fn"] = seed_worker dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + return DistributedSampler( + dataset=self.eval_dataset, + num_replicas=mpu.get_data_parallel_world_size(), + rank=mpu.get_data_parallel_rank(), + shuffle=False, + ) + # Deprecated code if self.args.use_legacy_prediction_loop: if is_torch_xla_available(): From d3b0ce0e6dbcd166e8e24af9401710b9968b5c05 Mon Sep 17 00:00:00 2001 From: Elsa Date: Fri, 21 Jun 2024 11:10:16 +0800 Subject: [PATCH 08/14] Fix llama query shape when _upad_input --- src/transformers/models/llama/modeling_llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py 
b/src/transformers/models/llama/modeling_llama.py index 36a80d613c6f6b..976c30f7d9ac39 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -556,8 +556,9 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) if query_length == kv_seq_len: + _batch_size, _kv_seq_len, _num_heads, _head_dim = query_layer.shape query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + query_layer.reshape(_batch_size * _kv_seq_len, _num_heads, _head_dim), indices_k ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k From 8d565f8d41d6f716cb77794b8b7642db2c593ce3 Mon Sep 17 00:00:00 2001 From: Elsa Date: Fri, 21 Jun 2024 18:14:42 +0800 Subject: [PATCH 09/14] Use all_to_all for flexiablity --- src/transformers/layer.py | 79 ++++++++++++++------------------------- 1 file changed, 28 insertions(+), 51 deletions(-) diff --git a/src/transformers/layer.py b/src/transformers/layer.py index a42db1c9b7efbb..f0df627399d3cb 100644 --- a/src/transformers/layer.py +++ b/src/transformers/layer.py @@ -12,52 +12,26 @@ from torch import distributed as dist +def rank_print(*args, **kwargs): + from torch.distributed import get_rank + print(f'[Rank {get_rank()}]', *args, **kwargs) + + def single_all_to_all(input, scatter_idx, gather_idx, group): + if dist.get_world_size(group) <= 1: + return input seq_world_size = dist.get_world_size(group) - inp_shape = list(input.shape) - inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size - if scatter_idx < 2: # scatter_idx == 1, scatter sequence dim - input_t = input.reshape( - inp_shape[: scatter_idx] + # the batch size dim - [seq_world_size, inp_shape[scatter_idx]] + # scatter the sequence dim - inp_shape[scatter_idx + 1:] - ).transpose(0, 1).contiguous() - else: # scatter_idx == 2, scatter heads dim - input_t = input.reshape( - [-1] + # flatten batch size and sequence dim - [seq_world_size, inp_shape[scatter_idx]] + # scatter the heads - inp_shape[scatter_idx + 1:] - ).transpose(0, 1).contiguous() - - output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=group) - - if scatter_idx < 2: - output = output.reshape( - [ - seq_world_size, - -1, # batch size dim - inp_shape[scatter_idx] # sequence dim (scattered) - ] + inp_shape[gather_idx:] # heads dim, and the rest - ).permute( - 1, 2, 0, *list(range(3, len(output.shape))) - ) - else: - output = output.reshape( - [ - seq_world_size, - -1, # batch size dim - inp_shape[gather_idx] # sequence dim - ] + inp_shape[scatter_idx:] # heads dim (scattered), and the rest - ).transpose( - 0, 1 - ) - - return output.reshape( - inp_shape[:gather_idx] + - [seq_world_size * inp_shape[gather_idx]] + - inp_shape[gather_idx + 1:] - ).contiguous() + + input_list = [ + t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx) + ] + output_list = [ + torch.empty_like(input_list[0]) + for _ in range(seq_world_size) + ] + # TODO: use all_to_all_single instead + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_idx).contiguous() class _SeqAllToAll(torch.autograd.Function): @@ -98,8 +72,6 @@ def __init__( super(DistributedAttention, self).__init__() - assert scatter_idx == 2 and gather_idx == 1, 'Only support shape [b, s, h, ...]' - self.local_attn = local_attention self.spg = sequence_process_group 
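+        # scatter_idx and gather_idx are no longer constrained to the [b, s, h, ...] layout:
+        # the tensor_split/cat based single_all_to_all handles arbitrary scatter and gather
+        # dimensions.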
self.scatter_idx = scatter_idx @@ -120,14 +92,19 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwarg # TODO Merge three alltoall calls into one # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! # in shape : e.g., [b,s/p:h:] - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) - value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) + query_layer = _SeqAllToAll.apply( + self.spg, query, self.scatter_idx, self.gather_idx) + key_layer = _SeqAllToAll.apply( + self.spg, key, self.scatter_idx, self.gather_idx) + value_layer = _SeqAllToAll.apply( + self.spg, value, self.scatter_idx, self.gather_idx) # out shape : e.g., [b,s:h/p:] - context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) + context_layer = self.local_attn( + query_layer, key_layer, value_layer, *args, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx) + output = _SeqAllToAll.apply( + self.spg, context_layer, self.gather_idx, self.scatter_idx) # out e.g., [b,s/p::h] return output From 272769171a5eca23b36d877f98c63b30c83a1fdc Mon Sep 17 00:00:00 2001 From: Elsa Date: Fri, 21 Jun 2024 18:16:58 +0800 Subject: [PATCH 10/14] Support sdpa for llama and mistral --- .../models/llama/modeling_llama.py | 23 ++++++++++++------- .../models/mistral/modeling_mistral.py | 22 +++++++++++------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 976c30f7d9ac39..0fa9382d886f12 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -595,11 +595,12 @@ class LlamaSdpaAttention(LlamaAttention): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # TODO: We can make it support sdpa if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): - raise ValueError( - "SDPA is not supported with sequence parallelism. Please use the `flash_attention_2` implementation instead." - ) + self.attn_func = DistributedAttention(torch.nn.functional.scaled_dot_product_attention, mpu.get_sequence_parallel_group(), scatter_idx=1, gather_idx=2) + self.q_len_multiplier = mpu.get_sequence_parallel_world_size() + else: + self.attn_func = torch.nn.functional.scaled_dot_product_attention + self.q_len_multiplier = 1 # Adapted from LlamaAttention.forward def forward( @@ -651,7 +652,7 @@ def forward( causal_mask = attention_mask if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + causal_mask = causal_mask[:, :, :, : key_states.shape[-2] * self.q_len_multiplier] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. @@ -664,7 +665,7 @@ def forward( # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
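+        # Under sequence parallelism, self.attn_func is DistributedAttention wrapping SDPA
+        # (scatter_idx=1 over heads, gather_idx=2 over sequence), so SDPA sees the full
+        # sequence; this is why the causal mask above is sliced to
+        # key_states.shape[-2] * self.q_len_multiplier rather than the local key length.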
is_causal = True if causal_mask is None and q_len > 1 else False - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output = self.attn_func( query_states, key_states, value_states, @@ -781,6 +782,7 @@ class LlamaPreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True + supports_sequence_parallel = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True @@ -899,6 +901,11 @@ def __init__(self, config: LlamaConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + self.q_len_multiplier = mpu.get_sequence_parallel_world_size() + else: + self.q_len_multiplier = 1 + # Initialize weights and apply final processing self.post_init() @@ -951,7 +958,7 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] * self.q_len_multiplier, device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1057,7 +1064,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] + sequence_length = input_tensor.shape[1] * self.q_len_multiplier if using_static_cache: target_length = past_key_values.get_max_length() else: diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index af1b16b23ac0c9..8e5589cfcc479b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -594,11 +594,12 @@ class MistralSdpaAttention(MistralAttention): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # TODO: We can make it support sdpa if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): - raise ValueError( - "SDPA is not supported with sequence parallelism. Please use the `flash_attention_2` implementation instead." - ) + self.attn_func = DistributedAttention(torch.nn.functional.scaled_dot_product_attention, mpu.get_sequence_parallel_group(), scatter_idx=1, gather_idx=2) + self.q_len_multiplier = mpu.get_sequence_parallel_world_size() + else: + self.attn_func = torch.nn.functional.scaled_dot_product_attention + self.q_len_multiplier = 1 # Adapted from MistralAttention.forward def forward( @@ -650,7 +651,7 @@ def forward( causal_mask = attention_mask if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + causal_mask = causal_mask[:, :, :, : key_states.shape[-2] * self.q_len_multiplier] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. @@ -663,7 +664,7 @@ def forward( # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
is_causal = True if causal_mask is None and q_len > 1 else False - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output = self.attn_func( query_states, key_states, value_states, @@ -896,6 +897,11 @@ def __init__(self, config: MistralConfig): self._attn_implementation = config._attn_implementation self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): + self.q_len_multiplier = mpu.get_sequence_parallel_world_size() + else: + self.q_len_multiplier = 1 + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -951,7 +957,7 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] * self.q_len_multiplier, device=inputs_embeds.device ) if position_ids is None: @@ -1073,7 +1079,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] + sequence_length = input_tensor.shape[1] * self.q_len_multiplier # SlidingWindowCache if using_sliding_window_cache: target_length = max(sequence_length, self.config.sliding_window) From fbb7e0b30aeb84d9029e998e0128647702f38ebb Mon Sep 17 00:00:00 2001 From: Elsa Date: Tue, 16 Jul 2024 20:00:17 +0800 Subject: [PATCH 11/14] Fix miss understood train_batch_size calcuation --- src/transformers/training_args.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c53243b7d179bd..1ffa82fefa3597 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2031,14 +2031,8 @@ def train_batch_size(self) -> int: "version. Using `--per_device_train_batch_size` is preferred." 
) - world_size = self.n_gpu - if is_accelerate_available(): - from accelerate.utils import parallel_state as mpu - if mpu.model_parallel_is_initialized(): - world_size = mpu.get_data_parallel_world_size() - per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size - train_batch_size = per_device_batch_size * max(1, world_size) + train_batch_size = per_device_batch_size * max(1, self.n_gpu) return train_batch_size @property From cf29d6d61fa976bde232e15ff6668daeaa097a65 Mon Sep 17 00:00:00 2001 From: Elsa Date: Tue, 16 Jul 2024 20:07:45 +0800 Subject: [PATCH 12/14] Fix args.world_size calcuation in model parallel --- src/transformers/training_args.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 1ffa82fefa3597..e537e118b3c74b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2233,7 +2233,12 @@ def world_size(self): """ requires_backends(self, ["torch"]) if self.distributed_state is not None: - return self.distributed_state.num_processes + world_size = self.distributed_state.num_processes + if is_accelerate_available(): + from accelerate.utils import parallel_state as mpu + if mpu.model_parallel_is_initialized(): + world_size = mpu.get_data_parallel_world_size() + return world_size elif is_sagemaker_mp_enabled(): return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size() return 1 From 57488b8fde7d40d15833bd0d2a0f764a54aa0f20 Mon Sep 17 00:00:00 2001 From: Elsa Date: Tue, 16 Jul 2024 22:40:16 +0800 Subject: [PATCH 13/14] Run ruff check --- src/transformers/layer.py | 7 +++---- src/transformers/models/llama/modeling_llama.py | 4 +++- src/transformers/models/mistral/modeling_mistral.py | 5 ++++- src/transformers/models/starcoder2/modeling_starcoder2.py | 4 +++- src/transformers/trainer.py | 4 ++-- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/transformers/layer.py b/src/transformers/layer.py index f0df627399d3cb..b29eedb5c8cf1b 100644 --- a/src/transformers/layer.py +++ b/src/transformers/layer.py @@ -3,13 +3,12 @@ # DeepSpeed Team -import torch - from typing import Any, Tuple -from torch import Tensor -from torch.nn import Module +import torch +from torch import Tensor from torch import distributed as dist +from torch.nn import Module def rank_print(*args, **kwargs): diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index ce5b1e323724a5..d2abd2eeff7b20 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -42,15 +42,17 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, - is_accelerate_available, ) from .configuration_llama import LlamaConfig + if is_accelerate_available(): from accelerate.utils import parallel_state as mpu + from ...layer import DistributedAttention logger = logging.get_logger(__name__) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 2bcae6e34514ad..be95801dc6d95a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -40,15 +40,18 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, 
is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, - is_accelerate_available, ) from .configuration_mistral import MistralConfig + + if is_accelerate_available(): from accelerate.utils import parallel_state as mpu + from ...layer import DistributedAttention if is_flash_attn_2_available(): diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 304738a77dd318..940bd1d6e01ad0 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -42,16 +42,18 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, - is_accelerate_available, ) from .configuration_starcoder2 import Starcoder2Config + if is_accelerate_available(): from accelerate.utils import parallel_state as mpu + from ...layer import DistributedAttention if is_flash_attn_2_available(): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9ab06c5b1c7c24..a46113932bb5fa 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -53,7 +53,7 @@ from huggingface_hub import ModelCard, create_repo, upload_folder from packaging import version from torch import nn -from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler, DistributedSampler +from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset, RandomSampler, SequentialSampler from . import __version__ from .configuration_utils import PretrainedConfig @@ -241,7 +241,7 @@ if is_deepspeed_available(): from accelerate.utils import DeepSpeedSchedulerWrapper - + from accelerate.utils import parallel_state as mpu if is_accelerate_available("0.28.0"): From 278873c582678b5088094ce8ec02b57c1c5c831f Mon Sep 17 00:00:00 2001 From: Elsa Date: Tue, 16 Jul 2024 22:43:04 +0800 Subject: [PATCH 14/14] Run ruff format --- src/transformers/integrations/deepspeed.py | 7 ++-- src/transformers/layer.py | 34 +++++-------------- src/transformers/modeling_utils.py | 4 +-- .../models/llama/modeling_llama.py | 11 ++++-- .../models/mistral/modeling_mistral.py | 11 ++++-- src/transformers/training_args.py | 3 +- 6 files changed, 34 insertions(+), 36 deletions(-) diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py index be781974a1c67d..3a81881365ad0f 100644 --- a/src/transformers/integrations/deepspeed.py +++ b/src/transformers/integrations/deepspeed.py @@ -138,10 +138,13 @@ def trainer_config_process(self, args, auto_find_batch_size=False): creation. 
""" - if getattr(args, 'sequence_parallel', 1) > 1: - assert is_accelerate_available(), "DeepSpeed sequence parallelism requires Accelerate, install it with 'pip install accelerate'" + if getattr(args, "sequence_parallel", 1) > 1: + assert ( + is_accelerate_available() + ), "DeepSpeed sequence parallelism requires Accelerate, install it with 'pip install accelerate'" from accelerate.utils import parallel_state as mpu + mpu.initialize_model_parallel( sequence_parallel_size=args.sequence_parallel, ) diff --git a/src/transformers/layer.py b/src/transformers/layer.py index b29eedb5c8cf1b..cf12ea600279d0 100644 --- a/src/transformers/layer.py +++ b/src/transformers/layer.py @@ -11,33 +11,21 @@ from torch.nn import Module -def rank_print(*args, **kwargs): - from torch.distributed import get_rank - print(f'[Rank {get_rank()}]', *args, **kwargs) - - def single_all_to_all(input, scatter_idx, gather_idx, group): if dist.get_world_size(group) <= 1: return input seq_world_size = dist.get_world_size(group) - input_list = [ - t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx) - ] - output_list = [ - torch.empty_like(input_list[0]) - for _ in range(seq_world_size) - ] + input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)] + output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] # TODO: use all_to_all_single instead dist.all_to_all(output_list, input_list, group=group) return torch.cat(output_list, dim=gather_idx).contiguous() class _SeqAllToAll(torch.autograd.Function): - @staticmethod def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor: - ctx.group = group ctx.scatter_idx = scatter_idx ctx.gather_idx = gather_idx @@ -68,7 +56,6 @@ def __init__( scatter_idx: int = 2, # scatter dim, commonly the head dim gather_idx: int = 1, # gather dim, commonly the sequence dim ) -> None: - super(DistributedAttention, self).__init__() self.local_attn = local_attention @@ -77,7 +64,7 @@ def __init__( self.gather_idx = gather_idx def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor: - """ forward + """forward Arguments: query (Tensor): query input to the layer @@ -91,19 +78,14 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwarg # TODO Merge three alltoall calls into one # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! 
# in shape : e.g., [b,s/p:h:] - query_layer = _SeqAllToAll.apply( - self.spg, query, self.scatter_idx, self.gather_idx) - key_layer = _SeqAllToAll.apply( - self.spg, key, self.scatter_idx, self.gather_idx) - value_layer = _SeqAllToAll.apply( - self.spg, value, self.scatter_idx, self.gather_idx) + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) # out shape : e.g., [b,s:h/p:] - context_layer = self.local_attn( - query_layer, key_layer, value_layer, *args, **kwargs) + context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) - output = _SeqAllToAll.apply( - self.spg, context_layer, self.gather_idx, self.scatter_idx) + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx) # out e.g., [b,s/p::h] return output diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 650cb63c31371b..67ac1cb152e139 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3907,9 +3907,7 @@ def from_pretrained( if is_accelerate_available() and mpu.get_sequence_parallel_world_size_or_one() > 1: if not getattr(model, "supports_sequence_parallel", False): - raise ValueError( - "The model does not support sequence parallelism." - ) + raise ValueError("The model does not support sequence parallelism.") # Dispatch model with hooks on all devices if necessary if device_map is not None: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d2abd2eeff7b20..f354e5c8be8996 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -492,7 +492,12 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): - self.attn_func = DistributedAttention(torch.nn.functional.scaled_dot_product_attention, mpu.get_sequence_parallel_group(), scatter_idx=1, gather_idx=2) + self.attn_func = DistributedAttention( + torch.nn.functional.scaled_dot_product_attention, + mpu.get_sequence_parallel_group(), + scatter_idx=1, + gather_idx=2, + ) self.q_len_multiplier = mpu.get_sequence_parallel_world_size() else: self.attn_func = torch.nn.functional.scaled_dot_product_attention @@ -866,7 +871,9 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] * self.q_len_multiplier, device=inputs_embeds.device + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1] * self.q_len_multiplier, + device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index be95801dc6d95a..85cfd25c497afa 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -426,7 +426,12 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if is_accelerate_available() and mpu.sequence_parallel_is_enabled(): - self.attn_func = DistributedAttention(torch.nn.functional.scaled_dot_product_attention, mpu.get_sequence_parallel_group(), 
scatter_idx=1, gather_idx=2) + self.attn_func = DistributedAttention( + torch.nn.functional.scaled_dot_product_attention, + mpu.get_sequence_parallel_group(), + scatter_idx=1, + gather_idx=2, + ) self.q_len_multiplier = mpu.get_sequence_parallel_world_size() else: self.attn_func = torch.nn.functional.scaled_dot_product_attention @@ -799,7 +804,9 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] * self.q_len_multiplier, device=inputs_embeds.device + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1] * self.q_len_multiplier, + device=inputs_embeds.device, ) if position_ids is None: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 556226ba0cb6a9..fad12ef87b5fb5 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1252,7 +1252,7 @@ class TrainingArguments: "Enable sequence parallel provided by Deepspeed-Ulysses. Requires deepspeed to be enabled." " This require handling dataset loader manually with DistributedSampler." ) - } + }, ) label_smoothing_factor: float = field( default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} @@ -2292,6 +2292,7 @@ def world_size(self): world_size = self.distributed_state.num_processes if is_accelerate_available(): from accelerate.utils import parallel_state as mpu + if mpu.model_parallel_is_initialized(): world_size = mpu.get_data_parallel_world_size() return world_size
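Taken together, the series plumbs DeepSpeed-Ulysses sequence parallelism from `TrainingArguments` (the new `sequence_parallel` field, which requires DeepSpeed) through `Trainer` (which now builds a `DistributedSampler` over the data-parallel ranks) down to the Llama, Mistral and Starcoder2 attention layers. Below is a minimal usage sketch of how the pieces fit once the series is applied; the checkpoint name, DeepSpeed config path and dataset are placeholders, and sharding each input sequence across the sequence-parallel ranks is left to the data pipeline rather than shown here.

# Minimal sketch (assumes this patch series is applied; launch with the `deepspeed`
# launcher so a distributed environment exists).
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Create the TrainingArguments first: trainer_config_process runs during
# TrainingArguments creation and initializes the sequence-parallel process groups.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    deepspeed="ds_config.json",  # sequence_parallel requires DeepSpeed to be enabled
    sequence_parallel=2,         # split every sequence across 2 ranks (DeepSpeed-Ulysses)
)

# from_pretrained rejects models that do not declare supports_sequence_parallel
# (Llama, Mistral and Starcoder2 set it in this series).
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",              # placeholder checkpoint
    attn_implementation="flash_attention_2",  # or "sdpa" after patch 10
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # train_dataset: placeholder
trainer.train()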