Skip to content

Commit

Permalink
add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Aug 31, 2023
1 parent d2dc575 commit 90d479c
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pipegoose/nn/tensor_parallel/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def deparallelize(self):
class ParallelizeEmbedding(ParallelizeModule):
# TODO: refactor to staticmethod
def parallelize(self) -> nn.Module:
"""Parallelize nn.Embedding module."""
assert isinstance(self.module, nn.Embedding), "only parallelize nn.Embedding"
self._resize_vocab_size()
self._split_weight()
Expand All @@ -48,6 +49,7 @@ def deparallelize(self):
pass

def _split_weight(self):
"""Split weight into chunks and assign to each process."""
world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR)
rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR)

Expand Down

0 comments on commit 90d479c

Please sign in to comment.