diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index aae1204acf488c..3a81881365ad0f 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -137,9 +137,24 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
         Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
         creation.
         """
+
+        if getattr(args, "sequence_parallel", 1) > 1:
+            assert (
+                is_accelerate_available()
+            ), "DeepSpeed sequence parallelism requires Accelerate, install it with 'pip install accelerate'"
+
+            from accelerate.utils import parallel_state as mpu
+
+            mpu.initialize_model_parallel(
+                sequence_parallel_size=args.sequence_parallel,
+            )
+            world_size = mpu.get_data_parallel_world_size()
+        else:
+            world_size = args.world_size
+
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
-        train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
+        train_batch_size = world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
         self.fill_match(
             "train_micro_batch_size_per_gpu",
             args.per_device_train_batch_size,
diff --git a/src/transformers/layer.py b/src/transformers/layer.py
new file mode 100644
index 00000000000000..cf12ea600279d0
--- /dev/null
+++ b/src/transformers/layer.py
@@ -0,0 +1,91 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Any, Tuple
+
+import torch
+from torch import Tensor
+from torch import distributed as dist
+from torch.nn import Module
+
+
+def single_all_to_all(input, scatter_idx, gather_idx, group):
+    if dist.get_world_size(group) <= 1:
+        return input
+    seq_world_size = dist.get_world_size(group)
+
+    input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
+    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
+    # TODO: use all_to_all_single instead
+    dist.all_to_all(output_list, input_list, group=group)
+    return torch.cat(output_list, dim=gather_idx).contiguous()
+
+
+class _SeqAllToAll(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
+        ctx.group = group
+        ctx.scatter_idx = scatter_idx
+        ctx.gather_idx = gather_idx
+
+        return single_all_to_all(input, scatter_idx, gather_idx, group)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
+        return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
+
+
+class DistributedAttention(torch.nn.Module):
+    """Initialization.
+
+    Arguments:
+        local_attention (Module): local attention with q, k, v
+        sequence_process_group (ProcessGroup): sequence parallel process group
+        scatter_idx (int): scatter_idx for the all-to-all comm
+        gather_idx (int): gather_idx for the all-to-all comm
+
+    Supports, and only supports, inputs of shape [b, s, h, d].
+    """
+
+    def __init__(
+        self,
+        local_attention: Module,
+        sequence_process_group: dist.ProcessGroup,
+        scatter_idx: int = 2,  # scatter dim, commonly the head dim
+        gather_idx: int = 1,  # gather dim, commonly the sequence dim
+    ) -> None:
+        super(DistributedAttention, self).__init__()
+
+        self.local_attn = local_attention
+        self.spg = sequence_process_group
+        self.scatter_idx = scatter_idx
+        self.gather_idx = gather_idx
+
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
+        """forward
+
+        Arguments:
+            query (Tensor): query input to the layer
+            key (Tensor): key input to the layer
+            value (Tensor): value input to the layer
+            args: other args
+
+        Returns:
+            * output (Tensor): context output
+        """
+        # TODO: merge the three all-to-all calls into one
+        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q, k, and v) together!
+        # in shape: e.g., [b, s/p, h, d]
+        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
+        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
+        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
+
+        # out shape: e.g., [b, s, h/p, d]
+        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
+
+        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
+
+        # out shape: e.g., [b, s/p, h, d]
+        return output
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index e831ba36130de2..67ac1cb152e139 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -121,6 +121,7 @@
         save_offload_index,
         set_module_tensor_to_device,
     )
+    from accelerate.utils import parallel_state as mpu
 
     accelerate_version = version.parse(importlib.metadata.version("accelerate"))
     if accelerate_version >= version.parse("0.31"):
@@ -3705,7 +3706,19 @@ def from_pretrained(
             import deepspeed
 
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-            init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
+            if is_accelerate_available() and mpu.model_parallel_is_initialized():
+                mpu_module = mpu
+                sequence_data_parallel_group = mpu.get_sequence_data_parallel_group()
+            else:
+                mpu_module = None
+                sequence_data_parallel_group = None
+            init_contexts = [
+                deepspeed.zero.Init(
+                    sequence_data_parallel_group=sequence_data_parallel_group,
+                    config_dict_or_path=deepspeed_config(),
+                    mpu=mpu_module,
+                )
+            ] + init_contexts
         elif low_cpu_mem_usage:
             init_contexts.append(init_empty_weights())
 
@@ -3892,6 +3905,10 @@ def from_pretrained(
                     )
                 pass
 
+        if is_accelerate_available() and mpu.get_sequence_parallel_world_size_or_one() > 1:
+            if not getattr(model, "supports_sequence_parallel", False):
+                raise ValueError("The model does not support sequence parallelism.")
+
         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
             device_map_kwargs = {
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5c0c57f3effe86..f354e5c8be8996 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -42,6 +42,7 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -49,6 +50,11 @@
 from .configuration_llama import LlamaConfig
 
 
+if is_accelerate_available():
+    from accelerate.utils import parallel_state as mpu
+
+    from ...layer import DistributedAttention
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
@@ -374,6 +380,13 @@ def __init__(self, *args, **kwargs):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.attn_func = DistributedAttention(_flash_attention_forward, mpu.get_sequence_parallel_group())
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.attn_func = _flash_attention_forward
+            self.q_len_multiplier = 1
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -447,12 +460,12 @@ def forward(
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
 
-        attn_output = _flash_attention_forward(
+        attn_output = self.attn_func(
             query_states,
             key_states,
             value_states,
             attention_mask,
-            q_len,
+            q_len * self.q_len_multiplier,
             dropout=dropout_rate,
             sliding_window=getattr(self, "sliding_window", None),
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
@@ -475,6 +488,21 @@ class LlamaSdpaAttention(LlamaAttention):
     SDPA API.
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.attn_func = DistributedAttention(
+                torch.nn.functional.scaled_dot_product_attention,
+                mpu.get_sequence_parallel_group(),
+                scatter_idx=1,
+                gather_idx=2,
+            )
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.attn_func = torch.nn.functional.scaled_dot_product_attention
+            self.q_len_multiplier = 1
+
     # Adapted from LlamaAttention.forward
     def forward(
         self,
@@ -526,7 +554,7 @@ def forward(
 
         causal_mask = attention_mask
         if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2] * self.q_len_multiplier]
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -539,7 +567,7 @@ def forward(
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
 
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
+        attn_output = self.attn_func(
             query_states,
             key_states,
             value_states,
@@ -663,6 +691,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     config_class = LlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    supports_sequence_parallel = True
     _no_split_modules = ["LlamaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
@@ -781,6 +810,11 @@ def __init__(self, config: LlamaConfig):
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
 
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.q_len_multiplier = 1
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -837,7 +871,9 @@ def forward(
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1] * self.q_len_multiplier,
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
@@ -943,7 +979,7 @@ def _update_causal_mask(
 
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
-        sequence_length = input_tensor.shape[1]
+        sequence_length = input_tensor.shape[1] * self.q_len_multiplier
         if using_static_cache:
             target_length = past_key_values.get_max_length()
         else:
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 5bd74a71e772e5..85cfd25c497afa 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -40,6 +40,7 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
@@ -48,6 +49,11 @@
 from .configuration_mistral import MistralConfig
 
 
+if is_accelerate_available():
+    from accelerate.utils import parallel_state as mpu
+
+    from ...layer import DistributedAttention
+
 if is_flash_attn_2_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
@@ -283,6 +289,13 @@ def __init__(self, *args, **kwargs):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.attn_func = DistributedAttention(_flash_attention_forward, mpu.get_sequence_parallel_group())
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.attn_func = _flash_attention_forward
+            self.q_len_multiplier = 1
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -380,12 +393,12 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        attn_output = _flash_attention_forward(
+        attn_output = self.attn_func(
             query_states,
             key_states,
             value_states,
             attention_mask,
-            q_len,
+            q_len * self.q_len_multiplier,
             dropout=dropout_rate,
             sliding_window=getattr(self.config, "sliding_window", None),
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
@@ -409,6 +422,21 @@ class MistralSdpaAttention(MistralAttention):
     SDPA API.
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.attn_func = DistributedAttention(
+                torch.nn.functional.scaled_dot_product_attention,
+                mpu.get_sequence_parallel_group(),
+                scatter_idx=1,
+                gather_idx=2,
+            )
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.attn_func = torch.nn.functional.scaled_dot_product_attention
+            self.q_len_multiplier = 1
+
     # Adapted from MistralAttention.forward
     def forward(
         self,
@@ -460,7 +488,7 @@ def forward(
 
         causal_mask = attention_mask
         if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2] * self.q_len_multiplier]
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -473,7 +501,7 @@ def forward(
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
 
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
+        attn_output = self.attn_func(
             query_states,
             key_states,
             value_states,
@@ -598,6 +626,7 @@ class MistralPreTrainedModel(PreTrainedModel):
     config_class = MistralConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    supports_sequence_parallel = True
     _no_split_modules = ["MistralDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
@@ -711,6 +740,11 @@ def __init__(self, config: MistralConfig):
         self._attn_implementation = config._attn_implementation
         self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.q_len_multiplier = 1
+
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
@@ -770,7 +804,9 @@ def forward(
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1] * self.q_len_multiplier,
+                device=inputs_embeds.device,
             )
 
         if position_ids is None:
@@ -892,7 +928,7 @@ def _update_causal_mask(
 
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
-        sequence_length = input_tensor.shape[1]
+        sequence_length = input_tensor.shape[1] * self.q_len_multiplier
         # SlidingWindowCache
         if using_sliding_window_cache:
             target_length = max(sequence_length, self.config.sliding_window)
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 430befd24ae364..940bd1d6e01ad0 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -42,6 +42,7 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
@@ -50,6 +51,11 @@
 from .configuration_starcoder2 import Starcoder2Config
 
 
+if is_accelerate_available():
+    from accelerate.utils import parallel_state as mpu
+
+    from ...layer import DistributedAttention
+
 if is_flash_attn_2_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
@@ -299,6 +305,13 @@ def __init__(self, *args, **kwargs):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.attn_func = DistributedAttention(_flash_attention_forward, mpu.get_sequence_parallel_group())
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.attn_func = _flash_attention_forward
+            self.q_len_multiplier = 1
+
     # Ignore copy
     def forward(
         self,
@@ -398,12 +411,12 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        attn_output = _flash_attention_forward(
+        attn_output = self.attn_func(
             query_states,
             key_states,
             value_states,
             attention_mask,
-            q_len,
+            q_len * self.q_len_multiplier,
             dropout=dropout_rate,
             sliding_window=getattr(self.config, "sliding_window", None),
             is_causal=self.is_causal,
@@ -428,6 +441,15 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
     SDPA API.
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: add SDPA support for sequence parallelism
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            raise ValueError(
+                "SDPA is not supported with sequence parallelism. Please use the `flash_attention_2` implementation instead."
+            )
+
     # Ignore copy
     def forward(
         self,
@@ -622,6 +644,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
     config_class = Starcoder2Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    supports_sequence_parallel = True
     _no_split_modules = ["Starcoder2DecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 485f6cd61e0e7b..a46113932bb5fa 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -53,7 +53,7 @@
 from huggingface_hub import ModelCard, create_repo, upload_folder
 from packaging import version
 from torch import nn
-from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset, RandomSampler, SequentialSampler
 
 from . import __version__
 from .configuration_utils import PretrainedConfig
@@ -242,6 +242,8 @@
 if is_deepspeed_available():
     from accelerate.utils import DeepSpeedSchedulerWrapper
 
+    from accelerate.utils import parallel_state as mpu
+
 if is_accelerate_available("0.28.0"):
     from accelerate.utils import DataLoaderConfiguration
 
@@ -854,7 +856,15 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
             return None
 
         # Build the sampler.
-        if self.args.group_by_length:
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            assert self.args.group_by_length is False, "Group by length is not supported with sequence parallelism."
+            return DistributedSampler(
+                dataset=self.train_dataset,
+                num_replicas=mpu.get_data_parallel_world_size(),
+                rank=mpu.get_data_parallel_rank(),
+                shuffle=True,
+            )
+        elif self.args.group_by_length:
             if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                 lengths = (
                     self.train_dataset[self.args.length_column_name]
@@ -910,6 +920,14 @@ def get_train_dataloader(self) -> DataLoader:
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            return DistributedSampler(
+                dataset=eval_dataset,
+                num_replicas=mpu.get_data_parallel_world_size(),
+                rank=mpu.get_data_parallel_rank(),
+                shuffle=False,
+            )
+
         # Deprecated code
         if self.args.use_legacy_prediction_loop:
             if is_torch_xla_available():
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index b87a3d9d0554d2..fad12ef87b5fb5 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1245,6 +1245,15 @@ class TrainingArguments:
             )
         },
     )
+    sequence_parallel: int = field(
+        default=1,
+        metadata={
+            "help": (
+                "Enable sequence parallelism as provided by DeepSpeed-Ulysses. Requires DeepSpeed to be enabled."
+                " This requires handling the dataset loader manually with a `DistributedSampler`."
+            )
+        },
+    )
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
@@ -1990,6 +1999,10 @@ def __post_init__(self):
             self.deepspeed_plugin.set_mixed_precision(mixed_precision)
             self.deepspeed_plugin.set_deepspeed_weakref()
 
+        if self.sequence_parallel > 1:
+            if self.deepspeed_plugin is None:
+                raise ValueError("sequence_parallel requires DeepSpeed to be enabled")
+
         if self.use_cpu:
             self.dataloader_pin_memory = False
 
@@ -2069,6 +2082,7 @@ def train_batch_size(self) -> int:
                 "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
                 "version. Using `--per_device_train_batch_size` is preferred."
             )
+
         per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
         train_batch_size = per_device_batch_size * max(1, self.n_gpu)
         return train_batch_size
@@ -2275,7 +2289,13 @@ def world_size(self):
         """
         requires_backends(self, ["torch"])
         if self.distributed_state is not None:
-            return self.distributed_state.num_processes
+            world_size = self.distributed_state.num_processes
+            if is_accelerate_available():
+                from accelerate.utils import parallel_state as mpu
+
+                if mpu.model_parallel_is_initialized():
+                    world_size = mpu.get_data_parallel_world_size()
+            return world_size
         elif is_sagemaker_mp_enabled():
             return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
         return 1
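
Illustrative usage (not part of the patch): a minimal sketch of how the `sequence_parallel` training argument introduced above is expected to be driven, assuming a multi-GPU DeepSpeed launch (e.g. `deepspeed` or `torchrun` across 8 ranks) and an Accelerate build that ships `accelerate.utils.parallel_state`. The checkpoint name, DeepSpeed config path, and dataset are placeholders, not values taken from this diff.

from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Hypothetical launch: 8 ranks = 2 data-parallel replicas x 4-way Ulysses sequence parallelism.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    deepspeed="ds_zero3.json",  # placeholder DeepSpeed config; sequence_parallel > 1 requires DeepSpeed
    sequence_parallel=4,  # shard each sequence across 4 ranks; attention is exchanged via all-to-all
)

# `supports_sequence_parallel` is set for Llama, Mistral and Starcoder2 in this diff;
# flash_attention_2 is supported for all three, SDPA only for Llama/Mistral.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    attn_implementation="flash_attention_2",
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # train_dataset: your tokenized dataset
trainer.train()  # with sequence parallelism enabled, the Trainer builds a DistributedSampler over the data-parallel group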