diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 7eb3e65d099..0eea266c962 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -102,10 +102,10 @@
     save_fsdp_optimizer,
     wait_for_everyone,
 )
+from .utils import parallel_state as mpu
 from .utils.constants import FSDP_PYTORCH_VERSION
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module
-from .utils import parallel_state as mpu
 
 
 if is_deepspeed_available():
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index f9f45ccf04b..a0396b16e51 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -17,7 +17,7 @@
 from typing import Callable, List, Optional, Union
 
 import torch
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, IterableDataset, RandomSampler
 
 from .logging import get_logger
 from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available
@@ -34,7 +34,6 @@
     slice_tensors,
     synchronize_rng_states,
 )
-from torch.utils.data import DistributedSampler
 
 
 logger = get_logger(__name__)
@@ -942,7 +941,7 @@ def prepare_data_loader(
         generator = torch.Generator().manual_seed(42)
         dataloader.generator = generator
         dataloader.sampler.generator = generator
-    
+
     is_distributed_sampler = isinstance(
         dataloader.sampler.sampler if sampler_is_batch_sampler else dataloader.sampler, DistributedSampler
     )
diff --git a/src/accelerate/utils/parallel_state.py b/src/accelerate/utils/parallel_state.py
index b7000cba489..f47f354329b 100644
--- a/src/accelerate/utils/parallel_state.py
+++ b/src/accelerate/utils/parallel_state.py
@@ -1,4 +1,6 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+from torch import distributed as dist
+
 
 """Model and data parallel groups."""
 
@@ -9,9 +11,8 @@
 
 
 # Intra-layer model parallel group that the current rank belongs to.
-import torch
-from torch import distributed as dist
-from typing import Optional
+
+
 _TENSOR_MODEL_PARALLEL_GROUP = None
 # Inter-layer model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
@@ -165,7 +166,7 @@ def initialize_model_parallel(
 
     num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
-    num_data_parallel_groups: int = world_size // data_parallel_size
+    # num_data_parallel_groups: int = world_size // data_parallel_size
     num_sequence_parallel_groups: int = world_size // sequence_parallel_size
     num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size
 
@@ -211,7 +212,7 @@ def initialize_model_parallel(
                 _DATA_PARALLEL_GROUP = group
                 _DATA_PARALLEL_GROUP_GLOO = group_gloo
                 _DATA_PARALLEL_GLOBAL_RANKS = ranks
-    
+
     global _DATA_PARALLEL_ALL_GROUP_RANKS
     _DATA_PARALLEL_ALL_GROUP_RANKS = all_data_parallel_group_ranks
 
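For context, a minimal sketch of how the `parallel_state` module touched by this patch might be driven, assuming it keeps the Megatron-style `initialize_model_parallel()` entry point whose internals the last two hunks edit; the helper name `setup_model_parallel`, its defaults, and the keyword arguments are illustrative assumptions, not part of the PR.

```python
# Illustrative sketch only (not from this PR): wire up torch.distributed and the
# Megatron-style process groups managed by accelerate.utils.parallel_state.
import torch.distributed as dist

from accelerate.utils import parallel_state as mpu  # same alias the patch uses in accelerator.py


def setup_model_parallel(tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1) -> None:
    """Hypothetical helper: create tensor/pipeline/data parallel groups once per process."""
    if not dist.is_initialized():
        # NCCL backend assumed for GPU training; env:// picks up RANK/WORLD_SIZE/MASTER_ADDR.
        dist.init_process_group(backend="nccl", init_method="env://")
    # initialize_model_parallel() partitions world_size into tensor-, pipeline- and
    # data-parallel groups; the data-parallel rank lists are cached in
    # _DATA_PARALLEL_ALL_GROUP_RANKS (see the @@ -211,7 +212,7 @@ hunk above).
    # Keyword names follow Megatron's signature and are assumed here.
    mpu.initialize_model_parallel(
        tensor_model_parallel_size=tensor_parallel_size,
        pipeline_model_parallel_size=pipeline_parallel_size,
    )
```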