add parallelization of the first linear layer in a transformer's MLP
xrsrke committed Aug 31, 2023
1 parent 90d479c commit a8abb39
Showing 2 changed files with 74 additions and 6 deletions.
32 changes: 29 additions & 3 deletions pipegoose/nn/tensor_parallel/parallelize.py
@@ -7,13 +7,20 @@
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.tensor_parallel._utils import VocabUtility, is_splitable
from pipegoose.nn.tensor_parallel.embedding import ParallelEmbedding
from pipegoose.nn.tensor_parallel.linear import ColumnParallelLinear


def _update_model_arguments(module: nn.Module, **kwargs):
    for key, value in kwargs.items():
        setattr(module, key, value)


def get_partition(data: torch.Tensor, parallel_context: ParallelContext, dim: int):
    rank = parallel_context.get_local_rank(ParallelMode.TENSOR)
    chunks = torch.chunk(data, parallel_context.get_world_size(ParallelMode.TENSOR), dim=dim)
    return chunks[rank].contiguous()


class ParallelizeModule(ABC):
    def __init__(self, module: nn.Module, parallel_context: ParallelContext):
        self.module = module
@@ -29,12 +36,31 @@ def deparallelize(self):


class ParallelizeLinear(ParallelizeModule):
    def parallelize(self):
        pass
    def parallelize(self) -> nn.Module:
        self._parallelize_column_linear()
        return self.module

    def deparallelize(self):
        pass

    def _parallelize_column_linear(self):
        self.module.weight.data = get_partition(self.module.weight, self.parallel_context, dim=0)

        if self.module.bias is not None:
            self.module.bias.data = get_partition(self.module.bias, self.parallel_context, dim=0)

        self.module.__class__ = ColumnParallelLinear
        _update_model_arguments(
            module=self.module,
            # NOTE: make this based on the parallel mapping;
            # a column-parallel linear doesn't gather the output
            gather_output=True,
            parallel_context=self.parallel_context,
        )

    def _parallelize_row_linear(self):
        pass


class ParallelizeEmbedding(ParallelizeModule):
    # TODO: refactor to staticmethod
@@ -69,7 +95,7 @@ def _split_weight(self):
        )

    def _resize_vocab_size(self):
        """Make vocab size divisible by world size."""
        """Pad embedding size to make it splittable across GPUs"""
        padding_size = 0

        vocab_size, embedding_dim = self.module.weight.size()
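For reference, the column-parallel scheme this commit wires up can be checked in a single process: chunk the weight (and bias) along dim=0, as get_partition does, run each chunk as a smaller linear, and concatenate the partial outputs, which is what gather_output=True corresponds to. The snippet below is only a sketch of the math with made-up layer sizes and a loop standing in for tensor-parallel ranks; it is not pipegoose's ColumnParallelLinear.

import torch
import torch.nn as nn
import torch.nn.functional as F

world_size = 2  # simulated tensor-parallel degree
linear = nn.Linear(in_features=8, out_features=6)
x = torch.randn(4, 8)

# Same split as get_partition(..., dim=0): each "rank" keeps one chunk of output rows.
weight_chunks = torch.chunk(linear.weight.data, world_size, dim=0)
bias_chunks = torch.chunk(linear.bias.data, world_size, dim=0)

# Each rank computes a slice of the output features.
partial_outputs = [
    F.linear(x, weight_chunks[rank], bias_chunks[rank]) for rank in range(world_size)
]

# gather_output=True corresponds to concatenating the slices back together.
gathered = torch.cat(partial_outputs, dim=-1)
assert torch.allclose(gathered, linear(x), rtol=1e-4, atol=1e-6)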
48 changes: 45 additions & 3 deletions tests/nn/tensor_parallel/test_parallelize.py
@@ -5,9 +5,19 @@
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.tensor_parallel._utils import VocabUtility, is_splitable
from pipegoose.nn.tensor_parallel.parallelize import ParallelizeEmbedding
from pipegoose.nn.tensor_parallel.parallelize import (
    ParallelizeEmbedding,
    ParallelizeLinear,
)
from pipegoose.testing.utils import spawn

MODEL_NAME = "bigscience/bloom-560m"


@pytest.fixture
def model():
    return AutoModel.from_pretrained(MODEL_NAME)


def init_parallel_context(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
    parallel_context = ParallelContext(
@@ -58,8 +68,7 @@ def get_new_embedding_size(vocab_size):


@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_parallelize_embedding(tensor_parallel_size):
    model = AutoModel.from_pretrained("gpt2")
def test_parallelize_embedding(model, tensor_parallel_size):
    input = torch.arange(0, 10)
    embedding = model.get_input_embeddings()
    output = embedding(input)
@@ -74,3 +83,36 @@ def test_parallelize_embedding(tensor_parallel_size):
        input=input.detach(),
        output=output.detach(),
    )


def run_parallelize_linear(
    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, linear, input, output
):
    parallel_context = init_parallel_context(
        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
    )
    parallelized_linear = ParallelizeLinear(linear, parallel_context).parallelize()
    parallel_output = parallelized_linear(input)

    assert torch.allclose(parallel_output, output, rtol=1e-4)


@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_parallelize_linear(model, tensor_parallel_size):
    # NOTE: this is column parallel linear
    linear = model.h[0].mlp.dense_h_to_4h
    input_size = linear.weight.shape[1]

    input = torch.randn(10, input_size)
    output = linear(input)

    spawn(
        run_parallelize_linear,
        world_size=tensor_parallel_size,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=1,
        data_parallel_size=1,
        linear=linear,
        input=input.detach(),
        output=output.detach(),
    )
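The diff above leaves _parallelize_row_linear as a stub. For context, the usual row-parallel counterpart (an assumption about where that stub is headed, not code from this commit) splits the weight along dim=1, gives each rank the matching slice of the input, and sums the partial outputs where column parallelism would concatenate them. A minimal single-process sketch under those assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

world_size = 2  # simulated tensor-parallel degree
linear = nn.Linear(in_features=8, out_features=6)
x = torch.randn(4, 8)

# Row-parallel split: partition the weight along dim=1 and the input along its last dim.
weight_chunks = torch.chunk(linear.weight.data, world_size, dim=1)
input_chunks = torch.chunk(x, world_size, dim=-1)

# Each rank computes a partial matmul over its slice of the input features.
partial_outputs = [
    F.linear(input_chunks[rank], weight_chunks[rank]) for rank in range(world_size)
]

# The "all-reduce" is a plain sum here; the bias is added once afterwards.
reduced = sum(partial_outputs) + linear.bias
assert torch.allclose(reduced, linear(x), rtol=1e-4, atol=1e-6)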
