fixed: couldn't run the forward pass in the parallelized transformer's embedding module
xrsrke committed Aug 31, 2023
1 parent cd24b87 commit d2dc575
Showing 2 changed files with 16 additions and 6 deletions.
13 changes: 8 additions & 5 deletions pipegoose/nn/tensor_parallel/parallelize.py
@@ -54,11 +54,17 @@ def _split_weight(self):
         vocab_size = self.module.weight.shape[0]
         vocab_start_idx, vocab_end_idx = VocabUtility.get_vocab_range_from_global_vocab_size(world_size, rank, vocab_size)
         weight_chunks = torch.chunk(self.module.weight, world_size, dim=0)
-        self.module.weight.data = weight_chunks[rank]
 
+        self.module.weight.data = weight_chunks[rank]
         self.module.__class__ = ParallelEmbedding
 
-        _update_model_arguments(module=self.module, vocab_start_idx=vocab_start_idx, vocab_end_idx=vocab_end_idx)
+        _update_model_arguments(
+            module=self.module,
+            parallel_context=self.parallel_context,
+            vocab_start_idx=vocab_start_idx,
+            vocab_end_idx=vocab_end_idx,
+            world_size=world_size,
+        )
 
     def _resize_vocab_size(self):
         """Make vocab size divisible by world size."""
@@ -74,9 +80,6 @@ def _resize_vocab_size(self):
 
         self.module.weight.data = new_embeddings
 
-    # def _is_text_embedding(self, module: nn.Module) -> bool:
-    #     return True if module is self.module.get_input_embeddings() else False
-
 
 class ParallelizeLayerNorm(ParallelizeModule):
     pass
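The diff above only wires parallel_context, the vocab range, and world_size into the module; the forward pass itself lives in ParallelEmbedding, which is not shown in this commit. For context, here is a minimal sketch of how a vocab-parallel embedding forward typically works once those arguments are available (Megatron-style; the function name, the process_group argument, and the use of raw torch.distributed are assumptions for illustration, not pipegoose's actual API):

import torch
import torch.distributed as dist
import torch.nn.functional as F


def vocab_parallel_embedding_forward(input_ids, weight, vocab_start_idx, vocab_end_idx, world_size, process_group=None):
    # Sketch only: `weight` holds just the rows [vocab_start_idx, vocab_end_idx)
    # of the full embedding table on this rank.
    if world_size == 1:
        return F.embedding(input_ids, weight)

    # Mask out token ids owned by other ranks and shift the rest
    # into this rank's local row range.
    out_of_range = (input_ids < vocab_start_idx) | (input_ids >= vocab_end_idx)
    local_ids = input_ids.clone() - vocab_start_idx
    local_ids[out_of_range] = 0

    embeddings = F.embedding(local_ids, weight)
    embeddings[out_of_range, :] = 0.0

    # Each rank contributes only its own rows; summing across the
    # tensor-parallel group reconstructs the full embedding output.
    dist.all_reduce(embeddings, op=dist.ReduceOp.SUM, group=process_group)
    return embeddings

This is why the embedding now needs parallel_context and world_size: without a handle on the tensor-parallel group, the partial per-rank lookups could not be combined into the full output, which is what made the forward pass fail before this commit.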
9 changes: 8 additions & 1 deletion tests/nn/tensor_parallel/test_parallelize.py
@@ -1,4 +1,5 @@
 import pytest
+import torch
 from transformers import AutoModel
 
 from pipegoose.distributed.parallel_context import ParallelContext
@@ -27,7 +28,7 @@ def init_parallel_context(rank, world_size, port, tensor_parallel_size, pipeline
 
 
 def run_parallelize_embedding(
-    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, embedding
+    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, embedding, input, output
 ):
     parallel_context = init_parallel_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
@@ -48,16 +49,20 @@ def get_new_embedding_size(vocab_size):
     vocab_start_idx, vocab_end_idx = VocabUtility.get_vocab_range_from_global_vocab_size(world_size, rank, new_vocab_size)
 
     parallelized_embedding = ParallelizeEmbedding(embedding, parallel_context).parallelize()
+    parallel_output = parallelized_embedding(input)
 
     assert parallelized_embedding.vocab_start_idx == vocab_start_idx
     assert parallelized_embedding.vocab_end_idx == vocab_end_idx
     assert parallelized_embedding.weight.shape == (new_partition_size, embedding_dim)
+    assert torch.allclose(parallel_output, output)
 
 
 @pytest.mark.parametrize("tensor_parallel_size", [1, 2])
 def test_parallelize_embedding(tensor_parallel_size):
     model = AutoModel.from_pretrained("gpt2")
+    input = torch.arange(0, 10)
     embedding = model.get_input_embeddings()
+    output = embedding(input)
 
     spawn(
         run_parallelize_embedding,
@@ -66,4 +71,6 @@ def test_parallelize_embedding(tensor_parallel_size):
         pipeline_parallel_size=1,
         data_parallel_size=1,
         embedding=embedding,
+        input=input.detach(),
+        output=output.detach(),
     )
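The new assertions compare each rank's parallelized forward pass against the output of the original, unsharded GPT-2 embedding. The vocab range they check is the usual even split of the (resized) vocabulary across tensor-parallel ranks; a minimal sketch of that partition math, assuming VocabUtility follows the standard Megatron-style convention (the real pipegoose helper may differ):

def get_vocab_range_from_global_vocab_size(world_size: int, rank: int, vocab_size: int) -> tuple[int, int]:
    # Assumes vocab_size is already divisible by world_size;
    # the test resizes the vocab beforehand to guarantee this.
    per_partition = vocab_size // world_size
    start = rank * per_partition
    end = start + per_partition  # exclusive upper bound
    return start, end

For example, with tensor_parallel_size=2 and GPT-2's 50257-token vocab padded to 50258, rank 0 would own ids [0, 25129) and rank 1 would own [25129, 50258) under this assumed scheme.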
