From 761820ac0ef692ca54f5735c90185ad1e96723a5 Mon Sep 17 00:00:00 2001 From: Elsa Date: Thu, 20 Jun 2024 22:31:52 +0800 Subject: [PATCH 1/3] Add support passing mpu into deepspeed --- src/accelerate/accelerator.py | 10 +- src/accelerate/data_loader.py | 12 +- src/accelerate/utils/parallel_state.py | 812 +++++++++++++++++++++++++ 3 files changed, 832 insertions(+), 2 deletions(-) create mode 100644 src/accelerate/utils/parallel_state.py diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 308861589d4..4fc8f9821d5 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -106,6 +106,7 @@ from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME from .utils.modeling import get_state_dict_offloaded_model from .utils.other import is_compiled_module +from .utils import parallel_state as mpu if is_deepspeed_available(): @@ -1634,11 +1635,16 @@ def _prepare_deepspeed(self, *args): gradient_accumulation_steps=self.gradient_accumulation_steps, ) + if mpu.model_parallel_is_initialized(): + world_size = mpu.get_data_parallel_world_size() + else: + world_size = self.num_processes + config_kwargs = { "train_micro_batch_size_per_gpu": batch_size_per_device, "train_batch_size": batch_size_per_device * deepspeed_plugin.get_value("gradient_accumulation_steps") - * self.num_processes, + * world_size, "gradient_clipping": 1.0, "zero_optimization.stage3_gather_16bit_weights_on_model_save": False, } @@ -1758,6 +1764,8 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs) self.deepspeed_config = deepspeed_plugin.deepspeed_config kwargs = dict(model=model, config_params=self.deepspeed_config) + if mpu.model_parallel_is_initialized(): + kwargs["mpu"] = mpu if optimizer is not None: if isinstance(optimizer, (DummyOptim)): kwargs["model_parameters"] = optimizer.params diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index fcf6631f162..f9f45ccf04b 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -34,6 +34,7 @@ slice_tensors, synchronize_rng_states, ) +from torch.utils.data import DistributedSampler logger = get_logger(__name__) @@ -941,8 +942,17 @@ def prepare_data_loader( generator = torch.Generator().manual_seed(42) dataloader.generator = generator dataloader.sampler.generator = generator + + is_distributed_sampler = isinstance( + dataloader.sampler.sampler if sampler_is_batch_sampler else dataloader.sampler, + DistributedSampler + ) + + if is_distributed_sampler and split_batches: + raise ValueError("Using `split_batches=True` with a `DistributedSampler` is not supported.") + # No change if no multiprocess - if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches: + if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches and not is_distributed_sampler: if isinstance(new_dataset, IterableDataset): if getattr(dataloader.dataset, "generator", None) is not None: synchronized_generator = dataloader.dataset.generator diff --git a/src/accelerate/utils/parallel_state.py b/src/accelerate/utils/parallel_state.py new file mode 100644 index 00000000000..b7000cba489 --- /dev/null +++ b/src/accelerate/utils/parallel_state.py @@ -0,0 +1,812 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+ +"""Model and data parallel groups.""" + +""" +mpu: Optional: A model parallelism unit object that implements +get_{model,data}_parallel_{rank,group,world_size}() +""" + + +# Intra-layer model parallel group that the current rank belongs to. +import torch +from torch import distributed as dist +from typing import Optional +_TENSOR_MODEL_PARALLEL_GROUP = None +# Inter-layer model parallel group that the current rank belongs to. +_PIPELINE_MODEL_PARALLEL_GROUP = None +# Model parallel group (both intra- and pipeline) that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Embedding group. +_EMBEDDING_GROUP = None +# Position embedding group. +_POSITION_EMBEDDING_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None +_DATA_PARALLEL_GROUP_GLOO = None +_DATA_PARALLEL_ALL_GROUP_RANKS = None +# FP8 amax reduction group. +_AMAX_REDUCTION_GROUP = None + +_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None +_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None + +# These values enable us to change the mpu sizes on the fly. +_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None +_MPU_PIPELINE_MODEL_PARALLEL_RANK = None + +# A list of ranks that have a copy of the embedding. +_EMBEDDING_GLOBAL_RANKS = None + +# A list of ranks that have a copy of the position embedding. +_POSITION_EMBEDDING_GLOBAL_RANKS = None + +# A list of global ranks for each pipeline group to ease calculation of the source +# rank when broadcasting from the first or last pipeline stage. +_PIPELINE_GLOBAL_RANKS = None + +# For DeepSpeed's sequence parallel +_SEQUENCE_PARALLEL_GROUP = None +_SEQUENCE_PARALLEL_WORLD_SIZE = None +_SEQUENCE_PARALLEL_RANK = None + +# This group includes processes for both data and sequence parallelisms. +# We use this group to reduce gradients and shard parameters and optimizer stages for ZeRO. +_SEQUENCE_DATA_PARALLEL_GROUP = None +_SEQUENCE_DATA_PARALLEL_WORLD_SIZE = None +_SEQUENCE_DATA_PARALLEL_RANK = None + +# A list of global ranks for each data parallel group to ease calculation of the source +# rank when broadcasting weights from src to all other data parallel ranks +_DATA_PARALLEL_GLOBAL_RANKS = None + +# # Memory buffers to avoid dynamic memory allocation +# _GLOBAL_MEMORY_BUFFER = None + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + sequence_parallel_size: int = 1, + # virtual_pipeline_model_parallel_size: Optional[int] = None, + # pipeline_model_parallel_split_rank: Optional[int] = None, + # use_fp8: bool = False, + use_distributed_optimizer: bool = False, +) -> None: + """Initialize model data parallel groups. + + Arguments: + tensor_model_parallel_size (int, default = 1): + The number of GPUs to split individual tensors across. + + pipeline_model_parallel_size (int, default = 1): + The number of tensor parallel GPU groups to split the + Transformer layers across. For example, if + tensor_model_parallel_size is 4 and + pipeline_model_parallel_size is 2, the model will be split + into 2 groups of 4 GPUs. + + virtual_pipeline_model_parallel_size (int, optional): + The number of stages that each pipeline group will have, + interleaving as necessary. If None, no interleaving is + performed. 
For example, if tensor_model_parallel_size is 1, + pipeline_model_parallel_size is 4, + virtual_pipeline_model_parallel_size is 2, and there are + 16 transformer layers in the model, the model will be + split into 8 stages with two layers each and each GPU + would get 2 stages as such (layer number starting with 1): + + GPU 0: [1, 2] [9, 10] + GPU 1: [3, 4] [11, 12] + GPU 2: [5, 6] [13, 14] + GPU 3: [7, 8] [15, 16] + + pipeline_model_parallel_split_rank (int, optional): + For models with both an encoder and decoder, the rank in + pipeline to switch between encoder and decoder (i.e. the + first rank of the decoder). This allows the user to set + the pipeline parallel size of the encoder and decoder + independently. For example, if + pipeline_model_parallel_size is 8 and + pipeline_model_parallel_split_rank is 3, then ranks 0-2 + will be the encoder and ranks 3-7 will be the decoder. + + use_fp8 (bool, default = False): + Construct GPU groups needed for FP8 training, namely for + amax reduction across the product of the data-parallel and + tensor-parallel groups. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 8 tensor model-parallel groups, 4 pipeline model-parallel groups + and 8 data-parallel groups as: + 8 data_parallel groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 8 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 4 pipeline model-parallel groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + + """ + # Get world size and rank. Ensure some consistencies. 
+ assert dist.is_initialized() + world_size: int = dist.get_world_size() + + assert tensor_model_parallel_size == 1, 'tensor model parallel size should be 1' + assert pipeline_model_parallel_size == 1, 'pipeline model parallel size should be 1' + + if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by tensor_model_parallel_size " + f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + enable_ds_sequence_parallel = sequence_parallel_size > 1 + if enable_ds_sequence_parallel: + assert tensor_model_parallel_size == 1 and pipeline_model_parallel_size == 1, \ + 'DeepSpeed\'s sequence parallel does not work with tensor parallel or pipeline parallel' + + if world_size % sequence_parallel_size != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by sequence_parallel_size {sequence_parallel_size})" + ) + + data_parallel_size: int = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size) + sequence_data_parallel_size: int = sequence_parallel_size * data_parallel_size + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + num_data_parallel_groups: int = world_size // data_parallel_size + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size + + # if virtual_pipeline_model_parallel_size is not None: + # if not pipeline_model_parallel_size > 2: + # raise RuntimeError( + # "pipeline-model-parallel size should be greater than 2 with " "interleaved schedule") + # global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + # global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + # _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 + # _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size + + # if pipeline_model_parallel_split_rank is not None: + # global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + # _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank + + rank = dist.get_rank() + + # Build the data-parallel groups. + global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GROUP_GLOO + global _DATA_PARALLEL_GLOBAL_RANKS + assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized' + all_data_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + + if sequence_parallel_size > 1: + tp_or_sp_size = sequence_parallel_size + else: + tp_or_sp_size = tensor_model_parallel_size + + for j in range(tp_or_sp_size): + ranks = range(start_rank + j, end_rank, tp_or_sp_size) + all_data_parallel_group_ranks.append(list(ranks)) + group = dist.new_group(ranks) + if use_distributed_optimizer: + group_gloo = dist.new_group(ranks, backend="gloo") + else: + group_gloo = None + if rank in ranks: + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GROUP_GLOO = group_gloo + _DATA_PARALLEL_GLOBAL_RANKS = ranks + + global _DATA_PARALLEL_ALL_GROUP_RANKS + _DATA_PARALLEL_ALL_GROUP_RANKS = all_data_parallel_group_ranks + + # Build the sequence parallel groups. 
+ global _SEQUENCE_PARALLEL_GROUP + assert _SEQUENCE_PARALLEL_GROUP is None, \ + 'sequence parallel group is already initialized' + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, + (i + 1) * sequence_parallel_size) + group = dist.new_group(ranks) + if rank in ranks: + _SEQUENCE_PARALLEL_GROUP = group + + # Build the sequence data parallel groups. + global _SEQUENCE_DATA_PARALLEL_GROUP + assert _SEQUENCE_DATA_PARALLEL_GROUP is None, \ + 'sequence data parallel group is already initialized' + all_data_sequence_parallel_group_ranks = [] + + if enable_ds_sequence_parallel: + for i in range(num_sequence_data_parallel_groups): + ranks = range(i * sequence_data_parallel_size, + (i + 1) * sequence_data_parallel_size) + group = dist.new_group(ranks) + all_data_sequence_parallel_group_ranks.append(list(ranks)) + if rank in ranks: + _SEQUENCE_DATA_PARALLEL_GROUP = group + else: + _SEQUENCE_DATA_PARALLEL_GROUP = _DATA_PARALLEL_GROUP + + # Build the model-parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized' + num_model_parallel_groups = sequence_data_parallel_size if enable_ds_sequence_parallel else data_parallel_size + model_parallel_group_ranks = all_data_sequence_parallel_group_ranks if enable_ds_sequence_parallel else all_data_parallel_group_ranks + for i in range(num_model_parallel_groups): + ranks = [parallel_group_ranks[i] + for parallel_group_ranks in model_parallel_group_ranks] + group = dist.new_group(ranks) + if rank in ranks: + _MODEL_PARALLEL_GROUP = group + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized' + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size) + group = dist.new_group(ranks) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + # Build the pipeline model-parallel groups and embedding groups + # (first and last rank in each pipeline model-parallel group). + global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized' + global _EMBEDDING_GROUP + global _EMBEDDING_GLOBAL_RANKS + assert _EMBEDDING_GROUP is None, 'embedding group is already initialized' + global _POSITION_EMBEDDING_GROUP + global _POSITION_EMBEDDING_GLOBAL_RANKS + assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized' + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, num_pipeline_model_parallel_groups) + group = dist.new_group(ranks) + if rank in ranks: + _PIPELINE_MODEL_PARALLEL_GROUP = group + _PIPELINE_GLOBAL_RANKS = ranks + # Setup embedding group (to exchange gradients between + # first and last stages). 
+ if len(ranks) > 1: + embedding_ranks = [ranks[0], ranks[-1]] + position_embedding_ranks = [ranks[0]] + # if pipeline_model_parallel_split_rank is not None: + # if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: + # embedding_ranks = [ + # ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]] + # if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks: + # position_embedding_ranks = [ + # ranks[0], ranks[pipeline_model_parallel_split_rank]] + else: + embedding_ranks = ranks + position_embedding_ranks = ranks + + group = dist.new_group(embedding_ranks) + if rank in embedding_ranks: + _EMBEDDING_GROUP = group + if rank in ranks: + _EMBEDDING_GLOBAL_RANKS = embedding_ranks + + group = dist.new_group(position_embedding_ranks) + if rank in position_embedding_ranks: + _POSITION_EMBEDDING_GROUP = group + if rank in ranks: + _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks + + # Build the FP8 groups. + global _AMAX_REDUCTION_GROUP + assert _AMAX_REDUCTION_GROUP is None, \ + 'FP8 amax reduction group is already initialized' + # if use_fp8: + # amax_group_size: int = tensor_model_parallel_size * data_parallel_size + # num_amax_groups: int = world_size // amax_group_size + # for i in range(num_amax_groups): + # start_rank = i * amax_group_size + # end_rank = (i + 1) * amax_group_size + # ranks = range(start_rank, end_rank) + # group = dist.new_group(ranks) + # if rank in ranks: + # _AMAX_REDUCTION_GROUP = group + + # Initialize global memory buffer + # This isn't really "parallel state" but there isn't another good place to + # put this. If we end up with a more generic initialization of megatron-core + # we could stick it there + # _set_global_memory_buffer() + + +def is_uninitialized(): + """Useful for code segments that may be accessed with or without mpu initialization""" + return _DATA_PARALLEL_GROUP is None + + +def model_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + return False + return True + + +def sequence_parallel_is_enabled(): + return model_parallel_is_initialized() and get_sequence_parallel_world_size() > 1 + + +def get_sequence_parallel_world_size_or_one(): + return get_sequence_parallel_world_size() if model_parallel_is_initialized() else 1 + + +def sequence_parallel_is_initialized(): + """Check if sequence and data parallel groups are initialized.""" + if _SEQUENCE_PARALLEL_GROUP is None or \ + _DATA_PARALLEL_GROUP is None: + return False + return True + + +def sequence_data_parallel_is_initialized(): + """Check if sequence data parallel groups are initialized.""" + if _SEQUENCE_DATA_PARALLEL_GROUP is None: + return False + return True + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized' + return _MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_group(check_initialized=True): + """Get the tensor model parallel group the caller rank belongs to.""" + if check_initialized: + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, 'tensor model parallel group is not initialized' + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_pipeline_model_parallel_group(): + """Get the pipeline model parallel group the caller rank belongs to.""" + assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, 'pipeline_model parallel group is not 
initialized' + return _PIPELINE_MODEL_PARALLEL_GROUP + + +def get_sequence_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + assert _SEQUENCE_PARALLEL_GROUP is not None, \ + 'sequence parallel group is not initialized' + return _SEQUENCE_PARALLEL_GROUP + + +def get_sequence_data_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, \ + 'sequence data parallel group is not initialized' + return _SEQUENCE_DATA_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_data_parallel_group_gloo(): + """Get the data parallel group-gloo the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP_GLOO is not None, \ + 'data parallel group-gloo is not initialized' + return _DATA_PARALLEL_GROUP_GLOO + + +def get_embedding_group(): + """Get the embedding group the caller rank belongs to.""" + assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized' + return _EMBEDDING_GROUP + + +def get_position_embedding_group(): + """Get the position embedding group the caller rank belongs to.""" + assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized' + return _POSITION_EMBEDDING_GROUP + + +def get_amax_reduction_group(): + """Get the FP8 amax reduction group the caller rank belongs to.""" + assert _AMAX_REDUCTION_GROUP is not None, \ + 'FP8 amax reduction group is not initialized' + return _AMAX_REDUCTION_GROUP + + +def set_tensor_model_parallel_world_size(world_size): + """Set the tensor model parallel size""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def set_sequence_parallel_world_size(world_size): + """Set the sequence parallel size""" + global _SEQUENCE_PARALLEL_WORLD_SIZE + _SEQUENCE_PARALLEL_WORLD_SIZE = world_size + + +def set_sequence_data_parallel_world_size(world_size): + """Set the sequence parallel size""" + global _SEQUENCE_DATA_PARALLEL_WORLD_SIZE + _SEQUENCE_DATA_PARALLEL_WORLD_SIZE = world_size + + +def set_pipeline_model_parallel_world_size(world_size): + """Set the pipeline model parallel size""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def set_virtual_pipeline_model_parallel_world_size(world_size): + """Set the virtual pipeline model parallel size""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + return dist.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_model_parallel_world_size(): + assert get_pipeline_model_parallel_world_size( + ) == 1, "legacy get_model_parallel_world_size is only supported if PP is disabled" + return get_tensor_model_parallel_world_size() + + +def get_sequence_parallel_world_size_or_none(): + if sequence_parallel_is_initialized(): + return get_sequence_parallel_world_size() + return None + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_WORLD_SIZE + if 
_SEQUENCE_PARALLEL_WORLD_SIZE is not None: + return _SEQUENCE_PARALLEL_WORLD_SIZE + return dist.get_world_size(group=get_sequence_parallel_group()) + + +def get_sequence_data_parallel_world_size(): + """Return world size for the sequence parallel group.""" + global _SEQUENCE_DATA_PARALLEL_WORLD_SIZE + if _SEQUENCE_DATA_PARALLEL_WORLD_SIZE is not None: + return _SEQUENCE_DATA_PARALLEL_WORLD_SIZE + return dist.get_world_size(group=get_sequence_data_parallel_group()) + + +def get_pipeline_model_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return dist.get_world_size(group=get_pipeline_model_parallel_group()) + + +def set_tensor_model_parallel_rank(rank): + """Set tensor model parallel rank.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = rank + + +def get_model_parallel_rank(): + assert get_pipeline_model_parallel_world_size( + ) == 1, "legacy get_model_parallel_rank is only supported if PP is disabled" + return get_tensor_model_parallel_rank() + + +def set_sequence_parallel_rank(rank): + """Set sequence parallel rank.""" + global _SEQUENCE_PARALLEL_RANK + _SEQUENCE_PARALLEL_RANK = rank + + +def set_sequence_data_parallel_rank(rank): + """Set sequence parallel rank.""" + global _SEQUENCE_DATA_PARALLEL_RANK + _SEQUENCE_DATA_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_rank(rank): + """Set pipeline model parallel rank.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_split_rank(rank): + """Set pipeline model parallel split rank.""" + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + return dist.get_rank(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_RANK + return dist.get_rank(group=get_pipeline_model_parallel_group()) + + +def get_pipeline_model_parallel_split_rank(): + """Return pipeline model parallel split rank.""" + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_RANK + if _SEQUENCE_PARALLEL_RANK is not None: + return _SEQUENCE_PARALLEL_RANK + return dist.get_rank(group=get_sequence_parallel_group()) + + +def get_sequence_data_parallel_rank(): + """Return my rank for the sequence data parallel group.""" + global _SEQUENCE_DATA_PARALLEL_RANK + if _SEQUENCE_DATA_PARALLEL_RANK is not None: + return _SEQUENCE_DATA_PARALLEL_RANK + return dist.get_rank(group=get_sequence_data_parallel_group()) + + +def is_pipeline_first_stage(ignore_virtual=False): + """Return True if in the first pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + if ( + get_virtual_pipeline_model_parallel_world_size() is not None + and get_virtual_pipeline_model_parallel_rank() != 0 + ): + return False + return 
get_pipeline_model_parallel_rank() == 0 + + +def is_pipeline_last_stage(ignore_virtual=False): + """Return True if in the last pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + virtual_pipeline_model_parallel_world_size = get_virtual_pipeline_model_parallel_world_size() + if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != ( + virtual_pipeline_model_parallel_world_size - 1 + ): + return False + return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1) + + +# def is_rank_in_embedding_group(ignore_virtual=False): +# """Return true if current rank is in embedding group, False otherwise.""" +# rank = dist.get_rank() +# global _EMBEDDING_GLOBAL_RANKS +# if ignore_virtual: +# return rank in _EMBEDDING_GLOBAL_RANKS +# if rank in _EMBEDDING_GLOBAL_RANKS: +# if rank == _EMBEDDING_GLOBAL_RANKS[0]: +# return is_pipeline_first_stage(ignore_virtual=False) +# elif rank == _EMBEDDING_GLOBAL_RANKS[-1]: +# return is_pipeline_last_stage(ignore_virtual=False) +# else: +# return True +# return False + + +# def is_rank_in_position_embedding_group(): +# """Return true if current rank is in position embedding group, False otherwise.""" +# rank = dist.get_rank() +# global _POSITION_EMBEDDING_GLOBAL_RANKS +# return rank in _POSITION_EMBEDDING_GLOBAL_RANKS + + +def is_pipeline_stage_before_split(rank=None): + """Return True if pipeline stage executes encoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + + +def is_pipeline_stage_after_split(rank=None): + """Return True if pipeline stage executes decoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + + +def is_pipeline_stage_at_split(): + """Return true if pipeline stage executes decoder block and next + stage executes encoder block for a model with both encoder and + decoder.""" + rank = get_pipeline_model_parallel_rank() + return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1) + + +def get_virtual_pipeline_model_parallel_rank(): + """Return the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + + +def set_virtual_pipeline_model_parallel_rank(rank): + """Set the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def get_virtual_pipeline_model_parallel_world_size(): + """Return the virtual pipeline-parallel world size.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = dist.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * 
local_world_size + + +def get_sequence_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the sequence parallel group.""" + global_rank = dist.get_rank() + local_world_size = get_sequence_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the data parallel group.""" + assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized" + return _DATA_PARALLEL_GLOBAL_RANKS[0] + + +def get_data_parallel_global_ranks(): + """Calculate the global rank corresponding to the first local rank + in the data parallel group.""" + assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized" + return _DATA_PARALLEL_GLOBAL_RANKS + + +def get_all_data_parallel_global_ranks(): + assert _DATA_PARALLEL_ALL_GROUP_RANKS is not None, "Data parallel group is not initialized" + return _DATA_PARALLEL_ALL_GROUP_RANKS + + +def get_pipeline_model_parallel_first_rank(): + """Return the global rank of the first process in the pipeline for the + current tensor parallel group""" + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + return _PIPELINE_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_last_rank(): + """Return the global rank of the last process in the pipeline for the + current tensor parallel group""" + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + last_rank_local = get_pipeline_model_parallel_world_size() - 1 + return _PIPELINE_GLOBAL_RANKS[last_rank_local] + + +def get_pipeline_model_parallel_next_rank(): + """Return the global rank that follows the caller in the pipeline""" + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + + +def get_pipeline_model_parallel_prev_rank(): + """Return the global rank that preceeds the caller in the pipeline""" + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return dist.get_world_size(group=get_data_parallel_group()) + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return dist.get_rank(group=get_data_parallel_group()) + + +# def _set_global_memory_buffer(): +# """Initialize global buffer""" +# global _GLOBAL_MEMORY_BUFFER +# assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized' +# _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + + +# def get_global_memory_buffer(): +# """Return the global GlobalMemoryBuffer object""" +# assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized' +# return _GLOBAL_MEMORY_BUFFER + + +# def destroy_global_memory_buffer(): +# """Sets the global memory buffer to None""" +# global _GLOBAL_MEMORY_BUFFER +# _GLOBAL_MEMORY_BUFFER = None + + +def destroy_model_parallel(): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + global _TENSOR_MODEL_PARALLEL_GROUP + 
_TENSOR_MODEL_PARALLEL_GROUP = None + global _PIPELINE_MODEL_PARALLEL_GROUP + _PIPELINE_MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None + global _SEQUENCE_PARALLEL_GROUP + _SEQUENCE_PARALLEL_GROUP = None + global _SEQUENCE_DATA_PARALLEL_GROUP + _SEQUENCE_DATA_PARALLEL_GROUP = None + global _EMBEDDING_GROUP + _EMBEDDING_GROUP = None + global _POSITION_EMBEDDING_GROUP + _POSITION_EMBEDDING_GROUP = None + global _AMAX_REDUCTION_GROUP + _AMAX_REDUCTION_GROUP = None + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = None + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = None + # global _GLOBAL_MEMORY_BUFFER + # _GLOBAL_MEMORY_BUFFER = None From ad2b23d6b4f66b8c36be8a880666337883d16030 Mon Sep 17 00:00:00 2001 From: Elsa Date: Wed, 3 Jul 2024 19:45:58 +0800 Subject: [PATCH 2/3] Fix ruff error --- src/accelerate/accelerator.py | 2 +- src/accelerate/data_loader.py | 5 ++--- src/accelerate/utils/parallel_state.py | 11 ++++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 4fc8f9821d5..683ac17535f 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -103,10 +103,10 @@ save_fsdp_optimizer, wait_for_everyone, ) +from .utils import parallel_state as mpu from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME from .utils.modeling import get_state_dict_offloaded_model from .utils.other import is_compiled_module -from .utils import parallel_state as mpu if is_deepspeed_available(): diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index f9f45ccf04b..a0396b16e51 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -17,7 +17,7 @@ from typing import Callable, List, Optional, Union import torch -from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, IterableDataset, RandomSampler from .logging import get_logger from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available @@ -34,7 +34,6 @@ slice_tensors, synchronize_rng_states, ) -from torch.utils.data import DistributedSampler logger = get_logger(__name__) @@ -942,7 +941,7 @@ def prepare_data_loader( generator = torch.Generator().manual_seed(42) dataloader.generator = generator dataloader.sampler.generator = generator - + is_distributed_sampler = isinstance( dataloader.sampler.sampler if sampler_is_batch_sampler else dataloader.sampler, DistributedSampler diff --git a/src/accelerate/utils/parallel_state.py b/src/accelerate/utils/parallel_state.py index b7000cba489..f47f354329b 100644 --- a/src/accelerate/utils/parallel_state.py +++ b/src/accelerate/utils/parallel_state.py @@ -1,4 +1,6 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from torch import distributed as dist + """Model and data parallel groups.""" @@ -9,9 +11,8 @@ # Intra-layer model parallel group that the current rank belongs to. 
-import torch -from torch import distributed as dist -from typing import Optional + + _TENSOR_MODEL_PARALLEL_GROUP = None # Inter-layer model parallel group that the current rank belongs to. _PIPELINE_MODEL_PARALLEL_GROUP = None @@ -165,7 +166,7 @@ def initialize_model_parallel( num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - num_data_parallel_groups: int = world_size // data_parallel_size + # num_data_parallel_groups: int = world_size // data_parallel_size num_sequence_parallel_groups: int = world_size // sequence_parallel_size num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size @@ -211,7 +212,7 @@ def initialize_model_parallel( _DATA_PARALLEL_GROUP = group _DATA_PARALLEL_GROUP_GLOO = group_gloo _DATA_PARALLEL_GLOBAL_RANKS = ranks - + global _DATA_PARALLEL_ALL_GROUP_RANKS _DATA_PARALLEL_ALL_GROUP_RANKS = all_data_parallel_group_ranks From 2e33e77900d51c7ebd6cdfc290e7fdc821ad0aac Mon Sep 17 00:00:00 2001 From: Elsa Date: Mon, 15 Jul 2024 22:55:53 +0800 Subject: [PATCH 3/3] Run make style, make quality --- src/accelerate/data_loader.py | 9 +- src/accelerate/utils/parallel_state.py | 159 +++++++++++-------------- 2 files changed, 74 insertions(+), 94 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index a0396b16e51..aee96601ce8 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -943,15 +943,18 @@ def prepare_data_loader( dataloader.sampler.generator = generator is_distributed_sampler = isinstance( - dataloader.sampler.sampler if sampler_is_batch_sampler else dataloader.sampler, - DistributedSampler + dataloader.sampler.sampler if sampler_is_batch_sampler else dataloader.sampler, DistributedSampler ) if is_distributed_sampler and split_batches: raise ValueError("Using `split_batches=True` with a `DistributedSampler` is not supported.") # No change if no multiprocess - if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches and not is_distributed_sampler: + if ( + (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) + and not dispatch_batches + and not is_distributed_sampler + ): if isinstance(new_dataset, IterableDataset): if getattr(dataloader.dataset, "generator", None) is not None: synchronized_generator = dataloader.dataset.generator diff --git a/src/accelerate/utils/parallel_state.py b/src/accelerate/utils/parallel_state.py index f47f354329b..7e74eaaa008 100644 --- a/src/accelerate/utils/parallel_state.py +++ b/src/accelerate/utils/parallel_state.py @@ -5,8 +5,7 @@ """Model and data parallel groups.""" """ -mpu: Optional: A model parallelism unit object that implements -get_{model,data}_parallel_{rank,group,world_size}() +mpu: Optional: A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}() """ @@ -84,65 +83,50 @@ def initialize_model_parallel( The number of GPUs to split individual tensors across. pipeline_model_parallel_size (int, default = 1): - The number of tensor parallel GPU groups to split the - Transformer layers across. For example, if - tensor_model_parallel_size is 4 and - pipeline_model_parallel_size is 2, the model will be split - into 2 groups of 4 GPUs. + The number of tensor parallel GPU groups to split the Transformer layers across. 
For example, if + tensor_model_parallel_size is 4 and pipeline_model_parallel_size is 2, the model will be split into 2 + groups of 4 GPUs. virtual_pipeline_model_parallel_size (int, optional): - The number of stages that each pipeline group will have, - interleaving as necessary. If None, no interleaving is - performed. For example, if tensor_model_parallel_size is 1, - pipeline_model_parallel_size is 4, - virtual_pipeline_model_parallel_size is 2, and there are - 16 transformer layers in the model, the model will be - split into 8 stages with two layers each and each GPU - would get 2 stages as such (layer number starting with 1): - - GPU 0: [1, 2] [9, 10] - GPU 1: [3, 4] [11, 12] - GPU 2: [5, 6] [13, 14] - GPU 3: [7, 8] [15, 16] + The number of stages that each pipeline group will have, interleaving as necessary. If None, no + interleaving is performed. For example, if tensor_model_parallel_size is 1, pipeline_model_parallel_size is + 4, virtual_pipeline_model_parallel_size is 2, and there are 16 transformer layers in the model, the model + will be split into 8 stages with two layers each and each GPU would get 2 stages as such (layer number + starting with 1): + + GPU 0: [1, 2] [9, 10] GPU 1: [3, 4] [11, 12] GPU 2: [5, 6] [13, 14] GPU 3: [7, 8] [15, 16] pipeline_model_parallel_split_rank (int, optional): - For models with both an encoder and decoder, the rank in - pipeline to switch between encoder and decoder (i.e. the - first rank of the decoder). This allows the user to set - the pipeline parallel size of the encoder and decoder - independently. For example, if - pipeline_model_parallel_size is 8 and - pipeline_model_parallel_split_rank is 3, then ranks 0-2 - will be the encoder and ranks 3-7 will be the decoder. + For models with both an encoder and decoder, the rank in pipeline to switch between encoder and decoder + (i.e. the first rank of the decoder). This allows the user to set the pipeline parallel size of the encoder + and decoder independently. For example, if pipeline_model_parallel_size is 8 and + pipeline_model_parallel_split_rank is 3, then ranks 0-2 will be the encoder and ranks 3-7 will be the + decoder. use_fp8 (bool, default = False): - Construct GPU groups needed for FP8 training, namely for - amax reduction across the product of the data-parallel and - tensor-parallel groups. - - Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we - use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize - the model pipeline. The present function will - create 8 tensor model-parallel groups, 4 pipeline model-parallel groups - and 8 data-parallel groups as: + Construct GPU groups needed for FP8 training, namely for amax reduction across the product of the + data-parallel and tensor-parallel groups. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we use 2 GPUs to parallelize the model tensor, and 4 + GPUs to parallelize the model pipeline. The present function will create 8 tensor model-parallel groups, 4 pipeline + model-parallel groups and 8 data-parallel groups as: 8 data_parallel groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] 8 tensor model-parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] 4 pipeline model-parallel groups: [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] - Note that for efficiency, the caller should make sure adjacent ranks - are on the same DGX box. 
For example if we are using 2 DGX-1 boxes - with a total of 16 GPUs, rank 0 to 7 belong to the first box and - ranks 8 to 15 belong to the second box. + Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example if we are + using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the + second box. """ # Get world size and rank. Ensure some consistencies. assert dist.is_initialized() world_size: int = dist.get_world_size() - assert tensor_model_parallel_size == 1, 'tensor model parallel size should be 1' - assert pipeline_model_parallel_size == 1, 'pipeline model parallel size should be 1' + assert tensor_model_parallel_size == 1, "tensor model parallel size should be 1" + assert pipeline_model_parallel_size == 1, "pipeline model parallel size should be 1" if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: raise RuntimeError( @@ -152,8 +136,9 @@ def initialize_model_parallel( enable_ds_sequence_parallel = sequence_parallel_size > 1 if enable_ds_sequence_parallel: - assert tensor_model_parallel_size == 1 and pipeline_model_parallel_size == 1, \ - 'DeepSpeed\'s sequence parallel does not work with tensor parallel or pipeline parallel' + assert ( + tensor_model_parallel_size == 1 and pipeline_model_parallel_size == 1 + ), "DeepSpeed's sequence parallel does not work with tensor parallel or pipeline parallel" if world_size % sequence_parallel_size != 0: raise RuntimeError( @@ -161,7 +146,8 @@ def initialize_model_parallel( ) data_parallel_size: int = world_size // ( - tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size) + tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size + ) sequence_data_parallel_size: int = sequence_parallel_size * data_parallel_size num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size @@ -189,7 +175,7 @@ def initialize_model_parallel( global _DATA_PARALLEL_GROUP global _DATA_PARALLEL_GROUP_GLOO global _DATA_PARALLEL_GLOBAL_RANKS - assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized' + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" all_data_parallel_group_ranks = [] for i in range(pipeline_model_parallel_size): start_rank = i * num_pipeline_model_parallel_groups @@ -218,25 +204,21 @@ def initialize_model_parallel( # Build the sequence parallel groups. global _SEQUENCE_PARALLEL_GROUP - assert _SEQUENCE_PARALLEL_GROUP is None, \ - 'sequence parallel group is already initialized' + assert _SEQUENCE_PARALLEL_GROUP is None, "sequence parallel group is already initialized" for i in range(num_sequence_parallel_groups): - ranks = range(i * sequence_parallel_size, - (i + 1) * sequence_parallel_size) + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) group = dist.new_group(ranks) if rank in ranks: _SEQUENCE_PARALLEL_GROUP = group # Build the sequence data parallel groups. 
global _SEQUENCE_DATA_PARALLEL_GROUP - assert _SEQUENCE_DATA_PARALLEL_GROUP is None, \ - 'sequence data parallel group is already initialized' + assert _SEQUENCE_DATA_PARALLEL_GROUP is None, "sequence data parallel group is already initialized" all_data_sequence_parallel_group_ranks = [] if enable_ds_sequence_parallel: for i in range(num_sequence_data_parallel_groups): - ranks = range(i * sequence_data_parallel_size, - (i + 1) * sequence_data_parallel_size) + ranks = range(i * sequence_data_parallel_size, (i + 1) * sequence_data_parallel_size) group = dist.new_group(ranks) all_data_sequence_parallel_group_ranks.append(list(ranks)) if rank in ranks: @@ -246,22 +228,22 @@ def initialize_model_parallel( # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP - assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized' + assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" num_model_parallel_groups = sequence_data_parallel_size if enable_ds_sequence_parallel else data_parallel_size - model_parallel_group_ranks = all_data_sequence_parallel_group_ranks if enable_ds_sequence_parallel else all_data_parallel_group_ranks + model_parallel_group_ranks = ( + all_data_sequence_parallel_group_ranks if enable_ds_sequence_parallel else all_data_parallel_group_ranks + ) for i in range(num_model_parallel_groups): - ranks = [parallel_group_ranks[i] - for parallel_group_ranks in model_parallel_group_ranks] + ranks = [parallel_group_ranks[i] for parallel_group_ranks in model_parallel_group_ranks] group = dist.new_group(ranks) if rank in ranks: _MODEL_PARALLEL_GROUP = group # Build the tensor model-parallel groups. global _TENSOR_MODEL_PARALLEL_GROUP - assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized' + assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized" for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, - (i + 1) * tensor_model_parallel_size) + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) group = dist.new_group(ranks) if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP = group @@ -270,13 +252,13 @@ def initialize_model_parallel( # (first and last rank in each pipeline model-parallel group). global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_GLOBAL_RANKS - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized' + assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized" global _EMBEDDING_GROUP global _EMBEDDING_GLOBAL_RANKS - assert _EMBEDDING_GROUP is None, 'embedding group is already initialized' + assert _EMBEDDING_GROUP is None, "embedding group is already initialized" global _POSITION_EMBEDDING_GROUP global _POSITION_EMBEDDING_GLOBAL_RANKS - assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized' + assert _POSITION_EMBEDDING_GROUP is None, "position embedding group is already initialized" for i in range(num_pipeline_model_parallel_groups): ranks = range(i, world_size, num_pipeline_model_parallel_groups) group = dist.new_group(ranks) @@ -313,8 +295,7 @@ def initialize_model_parallel( # Build the FP8 groups. 
global _AMAX_REDUCTION_GROUP - assert _AMAX_REDUCTION_GROUP is None, \ - 'FP8 amax reduction group is already initialized' + assert _AMAX_REDUCTION_GROUP is None, "FP8 amax reduction group is already initialized" # if use_fp8: # amax_group_size: int = tensor_model_parallel_size * data_parallel_size # num_amax_groups: int = world_size // amax_group_size @@ -355,8 +336,7 @@ def get_sequence_parallel_world_size_or_one(): def sequence_parallel_is_initialized(): """Check if sequence and data parallel groups are initialized.""" - if _SEQUENCE_PARALLEL_GROUP is None or \ - _DATA_PARALLEL_GROUP is None: + if _SEQUENCE_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: return False return True @@ -370,66 +350,62 @@ def sequence_data_parallel_is_initialized(): def get_model_parallel_group(): """Get the model parallel group the caller rank belongs to.""" - assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized' + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" return _MODEL_PARALLEL_GROUP def get_tensor_model_parallel_group(check_initialized=True): """Get the tensor model parallel group the caller rank belongs to.""" if check_initialized: - assert _TENSOR_MODEL_PARALLEL_GROUP is not None, 'tensor model parallel group is not initialized' + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized" return _TENSOR_MODEL_PARALLEL_GROUP def get_pipeline_model_parallel_group(): """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, 'pipeline_model parallel group is not initialized' + assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized" return _PIPELINE_MODEL_PARALLEL_GROUP def get_sequence_parallel_group(): """Get the sequence parallel group the caller rank belongs to.""" - assert _SEQUENCE_PARALLEL_GROUP is not None, \ - 'sequence parallel group is not initialized' + assert _SEQUENCE_PARALLEL_GROUP is not None, "sequence parallel group is not initialized" return _SEQUENCE_PARALLEL_GROUP def get_sequence_data_parallel_group(): """Get the sequence parallel group the caller rank belongs to.""" - assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, \ - 'sequence data parallel group is not initialized' + assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, "sequence data parallel group is not initialized" return _SEQUENCE_DATA_PARALLEL_GROUP def get_data_parallel_group(): """Get the data parallel group the caller rank belongs to.""" - assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized' + assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" return _DATA_PARALLEL_GROUP def get_data_parallel_group_gloo(): """Get the data parallel group-gloo the caller rank belongs to.""" - assert _DATA_PARALLEL_GROUP_GLOO is not None, \ - 'data parallel group-gloo is not initialized' + assert _DATA_PARALLEL_GROUP_GLOO is not None, "data parallel group-gloo is not initialized" return _DATA_PARALLEL_GROUP_GLOO def get_embedding_group(): """Get the embedding group the caller rank belongs to.""" - assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized' + assert _EMBEDDING_GROUP is not None, "embedding group is not initialized" return _EMBEDDING_GROUP def get_position_embedding_group(): """Get the position embedding group the caller rank belongs to.""" - assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not 
initialized' + assert _POSITION_EMBEDDING_GROUP is not None, "position embedding group is not initialized" return _POSITION_EMBEDDING_GROUP def get_amax_reduction_group(): """Get the FP8 amax reduction group the caller rank belongs to.""" - assert _AMAX_REDUCTION_GROUP is not None, \ - 'FP8 amax reduction group is not initialized' + assert _AMAX_REDUCTION_GROUP is not None, "FP8 amax reduction group is not initialized" return _AMAX_REDUCTION_GROUP @@ -440,13 +416,13 @@ def set_tensor_model_parallel_world_size(world_size): def set_sequence_parallel_world_size(world_size): - """Set the sequence parallel size""" + """Set the sequence parallel size""" global _SEQUENCE_PARALLEL_WORLD_SIZE _SEQUENCE_PARALLEL_WORLD_SIZE = world_size def set_sequence_data_parallel_world_size(world_size): - """Set the sequence parallel size""" + """Set the sequence parallel size""" global _SEQUENCE_DATA_PARALLEL_WORLD_SIZE _SEQUENCE_DATA_PARALLEL_WORLD_SIZE = world_size @@ -472,8 +448,9 @@ def get_tensor_model_parallel_world_size(): def get_model_parallel_world_size(): - assert get_pipeline_model_parallel_world_size( - ) == 1, "legacy get_model_parallel_world_size is only supported if PP is disabled" + assert ( + get_pipeline_model_parallel_world_size() == 1 + ), "legacy get_model_parallel_world_size is only supported if PP is disabled" return get_tensor_model_parallel_world_size() @@ -514,8 +491,9 @@ def set_tensor_model_parallel_rank(rank): def get_model_parallel_rank(): - assert get_pipeline_model_parallel_world_size( - ) == 1, "legacy get_model_parallel_rank is only supported if PP is disabled" + assert ( + get_pipeline_model_parallel_world_size() == 1 + ), "legacy get_model_parallel_rank is only supported if PP is disabled" return get_tensor_model_parallel_rank() @@ -658,8 +636,7 @@ def is_pipeline_stage_after_split(rank=None): def is_pipeline_stage_at_split(): """Return true if pipeline stage executes decoder block and next - stage executes encoder block for a model with both encoder and - decoder.""" + stage executes encoder block for a model with both encoder and decoder.""" rank = get_pipeline_model_parallel_rank() return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)