Skip to content

Commit

Permalink
add is_first_rank and is_last_rank
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Sep 4, 2023
1 parent 79bb372 commit 9ebeb47
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
9 changes: 9 additions & 0 deletions pipegoose/distributed/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,15 @@ def get_prev_local_rank(self, rank, parallel_mode: ParallelMode) -> int:
world_size = self.get_world_size(parallel_mode)
return (rank - 1) % world_size

def is_first_rank(self, parallel_mode: ParallelMode) -> bool:
local_rank = self.get_local_rank(parallel_mode)
return local_rank == 0

def is_last_rank(self, parallel_mode: ParallelMode) -> bool:
local_rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
return local_rank == world_size - 1

def destroy(self):
assert self.is_initialized(ParallelMode.GLOBAL), "Global group must be initialized before destroying."
for mode, group in self._groups.items():
Expand Down
3 changes: 3 additions & 0 deletions tests/distributed/test_parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def init_parallel_context(
prev_local_rank = parallel_context.get_prev_local_rank(local_rank, parallel_mode)
assert prev_local_rank == LOCAL_RANK_TO_PREV_RANK[world_size][parallel_mode][local_rank]

assert parallel_context.is_first_rank(parallel_mode) == (local_rank == 0)
assert parallel_context.is_last_rank(parallel_mode) == (local_rank == parallel_context.get_world_size(parallel_mode) - 1)

parallel_context.destroy()

if pipeline_parallel_size > 1:
Expand Down

0 comments on commit 9ebeb47

Please sign in to comment.