Skip to content

Commit

Permalink
Fix ruff error
Browse files Browse the repository at this point in the history
  • Loading branch information
zeyugao committed Jul 3, 2024
1 parent a201735 commit 13dad66
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@
save_fsdp_optimizer,
wait_for_everyone,
)
from .utils import parallel_state as mpu
from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
from .utils.modeling import get_state_dict_offloaded_model
from .utils.other import is_compiled_module
from .utils import parallel_state as mpu


if is_deepspeed_available():
Expand Down
5 changes: 2 additions & 3 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Callable, List, Optional, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, IterableDataset, RandomSampler

from .logging import get_logger
from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available
Expand All @@ -34,7 +34,6 @@
slice_tensors,
synchronize_rng_states,
)
from torch.utils.data import DistributedSampler


logger = get_logger(__name__)
Expand Down Expand Up @@ -942,7 +941,7 @@ def prepare_data_loader(
generator = torch.Generator().manual_seed(42)
dataloader.generator = generator
dataloader.sampler.generator = generator

is_distributed_sampler = isinstance(
dataloader.sampler.sampler if sampler_is_batch_sampler else dataloader.sampler,
DistributedSampler
Expand Down
11 changes: 6 additions & 5 deletions src/accelerate/utils/parallel_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from torch import distributed as dist


"""Model and data parallel groups."""

Expand All @@ -9,9 +11,8 @@


# Intra-layer model parallel group that the current rank belongs to.
import torch
from torch import distributed as dist
from typing import Optional


_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
Expand Down Expand Up @@ -165,7 +166,7 @@ def initialize_model_parallel(

num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
num_data_parallel_groups: int = world_size // data_parallel_size
# num_data_parallel_groups: int = world_size // data_parallel_size
num_sequence_parallel_groups: int = world_size // sequence_parallel_size
num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size

Expand Down Expand Up @@ -211,7 +212,7 @@ def initialize_model_parallel(
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
_DATA_PARALLEL_GLOBAL_RANKS = ranks

global _DATA_PARALLEL_ALL_GROUP_RANKS
_DATA_PARALLEL_ALL_GROUP_RANKS = all_data_parallel_group_ranks

Expand Down

0 comments on commit 13dad66

Please sign in to comment.