refactor and parallelize the two dense layers in 🤗 transformer's MLP layer using the same API .parallelize()
xrsrke committed Sep 2, 2023
1 parent 55533cb commit 3c34cc7
Showing 6 changed files with 49 additions and 121 deletions.
4 changes: 4 additions & 0 deletions pipegoose/nn/tensor_parallel/parallel_mapping.py
@@ -62,3 +62,7 @@ def is_row_parallel(module_name: str) -> bool:
         if item is None:
             return False
         return isinstance(item, Row)
+
+    @staticmethod
+    def is_linear(module_name: str) -> bool:
+        pass
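
For context, is_column_parallel and is_row_parallel presumably resolve a module's fully qualified name against a static mapping from 🤗 module-name suffixes to a partition style, and the new is_linear stub will do the same. A minimal sketch of that idea, with the caveat that Column, Row, and the table below are illustrative stand-ins, not the repository's real definitions:

    # Hypothetical sketch of a name-to-partition-style mapping; not
    # pipegoose's actual ParallelMapping internals.
    class Column:
        """Marker: split the weight along the output dimension."""

    class Row:
        """Marker: split the weight along the input dimension."""

    TENSOR_PARALLEL_MAPPING = {
        "mlp.dense_h_to_4h": Column(),  # first MLP projection
        "mlp.dense_4h_to_h": Row(),     # second MLP projection
    }

    def search(module_name: str):
        # Return the entry whose suffix matches the qualified module name.
        for suffix, item in TENSOR_PARALLEL_MAPPING.items():
            if module_name.endswith(suffix):
                return item
        return None

    def is_row_parallel(module_name: str) -> bool:
        return isinstance(search(module_name), Row)

With such a table, is_linear would just check membership, e.g. search(module_name) is not None.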
13 changes: 11 additions & 2 deletions pipegoose/nn/tensor_parallel/parallelize.py
@@ -7,6 +7,7 @@
 from pipegoose.distributed.parallel_mode import ParallelMode
 from pipegoose.nn.tensor_parallel.embedding import ParallelEmbedding
 from pipegoose.nn.tensor_parallel.linear import ColumnParallelLinear, RowParallelLinear
+from pipegoose.nn.tensor_parallel.parallel_mapping import ParallelMapping
 from pipegoose.nn.tensor_parallel._utils import VocabUtility, is_splitable


@@ -36,8 +37,16 @@ def deparallelize(self):


 class ParallelizeLinear(ParallelizeModule):
-    def parallelize(self) -> nn.Module:
-        module = self._parallelize_column_linear(self.module)
+    def parallelize(self, module_name: str) -> nn.Module:
+        assert isinstance(self.module, nn.Linear), "only parallelize nn.Linear"
+
+        if ParallelMapping.is_column_parallel(module_name):
+            module = self._parallelize_column_linear(self.module)
+        elif ParallelMapping.is_row_parallel(module_name):
+            module = self._parallelize_row_linear(self.module)
+        else:
+            raise ValueError(f"module {module_name} is not supported")
+
         return module

     def deparallelize(self):
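
The column-then-row split is the Megatron-LM MLP scheme: the h→4h projection is partitioned along its output dimension, the 4h→h projection along its input dimension, so each rank can run the whole MLP on its shard and a single all-reduce at the end restores the full output. A self-contained sketch of the algebra with two simulated ranks in plain torch (illustrative only, not pipegoose's actual kernels):

    import torch

    torch.manual_seed(0)
    batch, hidden = 4, 8
    x = torch.randn(batch, hidden)
    w1 = torch.randn(hidden, 4 * hidden)  # plays the role of dense_h_to_4h
    w2 = torch.randn(4 * hidden, hidden)  # plays the role of dense_4h_to_h

    # Column-parallel: each rank holds a slice of w1's output columns.
    w1_a, w1_b = w1.chunk(2, dim=1)
    # Row-parallel: each rank holds the matching slice of w2's input rows.
    w2_a, w2_b = w2.chunk(2, dim=0)

    # Each "rank" computes its shard end to end with no communication.
    # (In the real MLP an elementwise GeLU sits between the two matmuls;
    # being elementwise, it applies to each column shard independently.)
    partial_a = (x @ w1_a) @ w2_a
    partial_b = (x @ w1_b) @ w2_b

    # One all-reduce (here simply a sum) recovers the unsharded result.
    assert torch.allclose(partial_a + partial_b, x @ w1 @ w2, atol=1e-4)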
36 changes: 7 additions & 29 deletions pipegoose/nn/tensor_parallel/tensor_parallel.py
@@ -3,18 +3,15 @@

 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.nn.tensor_parallel.parallelize import (
-    ParallelizeAttention,
+    # ParallelizeAttention,
     ParallelizeEmbedding,
     ParallelizeLayerNorm,
     ParallelizeLinear,
 )


 class TensorParallel:
-    """Turn a sequential model into a tensor-parallel model.
-    Inspired by OSLO's TensorParallel: https://github.com/EleutherAI/oslo/blob/00e3be56446df37a0372a93a094863ffc89a2f8b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py#L51
-    """
+    """Turn a transformers model into a tensor parallel model."""

     def __init__(self, module: nn.Module, parallel_context: ParallelContext):
         super().__init__()
@@ -23,29 +20,10 @@ def __init__(self, module: nn.Module, parallel_context: ParallelContext):

     @torch.no_grad()
     def parallelize(self):
-        paralleler = {
-            "linear": ParallelizeLinear,
-            "embedding": ParallelizeEmbedding,
-            "layer_norm": ParallelizeLayerNorm,
-            "attention": ParallelizeAttention,
-        }
-
-        for name, module in self.module.named_modules():
-            if name in paralleler:
-                paralleler[name](module, self.parallel_context).parallelize()
-
-    def _parallelize_embedding(self):
-        for module_name, module in self.module.named_modules():
-            if isinstance(module, nn.Embedding):
-                pass
-
-    def _parallelize_linear(self):
-        pass
-
-    def _parallize_layernorm(self):
-        for _, module in self.module.named_modules():
-            if isinstance(module, nn.LayerNorm):
-                pass
-
-    def _resize_vocab_size(self, module: nn.Module):
-        pass
+        for module_name, module in self.module.named_modules():
+            if isinstance(module, nn.Embedding):
+                ParallelizeEmbedding(module_name, module, self.parallel_context).parallelize()
+            elif isinstance(module, nn.Linear):
+                ParallelizeLinear(module_name, module, self.parallel_context).parallelize()
+            elif isinstance(module, nn.LayerNorm):
+                ParallelizeLayerNorm(module_name, module, self.parallel_context).parallelize()
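
Dispatch now keys on module type (plus, for linears, the name-based mapping) instead of exact module names, which is what lets the two MLP dense layers share the one .parallelize() entry point. A hedged usage sketch — the ParallelContext construction below is an assumption about the API, not something shown in this commit:

    from transformers import AutoModel

    from pipegoose.distributed.parallel_context import ParallelContext
    from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel

    # Assumed factory/arguments; the real ParallelContext setup may differ.
    parallel_context = ParallelContext.from_torch(tensor_parallel_size=2)

    model = AutoModel.from_pretrained("bigscience/bloom-560m")
    # Walks named_modules() and swaps nn.Embedding / nn.Linear / nn.LayerNorm
    # for their tensor-parallel counterparts.
    TensorParallel(model, parallel_context).parallelize()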
63 changes: 0 additions & 63 deletions tests/nn/tensor_parallel/test_layers.py

This file was deleted.

42 changes: 24 additions & 18 deletions tests/nn/tensor_parallel/test_parallelize.py
@@ -81,18 +81,14 @@ def test_parallelize_embedding(model, tensor_parallel_size):
     )


-def test_parallelize_attention():
-    pass
-
-
 def run_parallelize_column_linear(
-    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, linear, input, output
+    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, module_name, module, input, output
 ):
     parallel_context = init_parallel_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
-    parallelized_linear = ParallelizeLinear(linear, parallel_context).parallelize()
-    parallel_output = parallelized_linear(input)
+    parallelized_module = ParallelizeLinear(module, parallel_context).parallelize(module_name)
+    parallel_output = parallelized_module(input)

     assert torch.allclose(parallel_output, output, rtol=1e-4)

@@ -103,57 +99,67 @@ def run_parallelize_column_linear(

 @pytest.mark.parametrize("tensor_parallel_size", [1, 2])
 def test_parallelize_column_linear(model, tensor_parallel_size):
+    # TODO: add a test for the named_children() version
+    MODULE_NAME = "transformer.h.0.mlp.dense_h_to_4h"
+
     # NOTE: this is a column-parallel linear
-    linear = model.h[0].mlp.dense_h_to_4h
-    input_size = linear.weight.shape[1]
+    module = model.h[0].mlp.dense_h_to_4h
+    input_size = module.weight.shape[1]

     input = torch.randn(10, input_size)
-    output = linear(input)
+    output = module(input)

     spawn(
         run_parallelize_column_linear,
         world_size=tensor_parallel_size,
         tensor_parallel_size=tensor_parallel_size,
         pipeline_parallel_size=1,
         data_parallel_size=1,
-        linear=linear,
+        module_name=MODULE_NAME,
+        module=module,
         input=input.detach(),
         output=output.detach(),
     )


 def run_parallelize_row_linear(
-    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, linear, input, output
+    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, module_name, module, input, output
 ):
     parallel_context = init_parallel_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
-    # TODO: make this based on parallel mapping
-    parallelized_linear = ParallelizeLinear(linear, parallel_context)._parallelize_row_linear(linear)
-    parallel_output = parallelized_linear(input)
+    parallelized_module = ParallelizeLinear(module, parallel_context).parallelize(module_name)
+    parallel_output = parallelized_module(input)

     assert torch.allclose(parallel_output, output, rtol=1e-4)


 @pytest.mark.parametrize("tensor_parallel_size", [1, 2])
 def test_parallelize_row_linear(model, tensor_parallel_size):
-    linear = model.h[0].mlp.dense_4h_to_h
-    input_size = linear.weight.shape[1]
+    MODULE_NAME = "transformer.h.0.mlp.dense_4h_to_h"
+    module = model.h[0].mlp.dense_4h_to_h
+    input_size = module.weight.shape[1]

     input = torch.randn(10, input_size)
-    output = linear(input)
+    output = module(input)

     spawn(
         run_parallelize_row_linear,
         world_size=tensor_parallel_size,
         tensor_parallel_size=tensor_parallel_size,
         pipeline_parallel_size=1,
         data_parallel_size=1,
-        linear=linear,
+        module_name=MODULE_NAME,
+        module=module,
         input=input.detach(),
         output=output.detach(),
     )


+def test_parallelize_attention():
+    pass
+
+
 def test_parallelize_layer_norm():
     pass
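
In plain torch terms, the column test asserts that a column shard of an nn.Linear computes a contiguous slice of the full output, so concatenating per-rank slices reproduces the original layer. A minimal single-process sketch with two simulated ranks (the real tests run each rank in its own process via spawn):

    import torch
    from torch import nn

    torch.manual_seed(0)
    full = nn.Linear(8, 32)  # stands in for dense_h_to_4h
    x = torch.randn(10, 8)

    # nn.Linear stores weight as (out_features, in_features), so a
    # column-parallel split partitions dim 0 of the weight and the
    # matching slice of the bias.
    shards = []
    for w, b in zip(full.weight.chunk(2, dim=0), full.bias.chunk(2, dim=0)):
        shard = nn.Linear(8, 16)
        with torch.no_grad():
            shard.weight.copy_(w)
            shard.bias.copy_(b)
        shards.append(shard)

    # Gathering per-rank outputs along the feature dim recovers the original;
    # the row-parallel case sums partial outputs instead, as exercised by
    # run_parallelize_row_linear.
    gathered = torch.cat([shard(x) for shard in shards], dim=-1)
    assert torch.allclose(gathered, full(x), rtol=1e-4)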
12 changes: 3 additions & 9 deletions tests/nn/tensor_parallel/test_tensor_parallel.py
@@ -27,19 +27,13 @@ def init_parallel_context(rank, world_size, port, tensor_parallel_size, pipeline
 @pytest.mark.skip
 def test_parallelize_a_transformers():
     parallel_context = init_parallel_context()
-    world_size = parallel_context.get_world_size(parallel_mode=ParallelMode.TENSOR)
+    # world_size = parallel_context.get_world_size(parallel_mode=ParallelMode.TENSOR)

     model = AutoModel.from_pretrained("gpt2")
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    input = tokenizer.tokenize("Persistence is all you need", return_tensors="pt")
-
-    with pytest.raises(ValueError):
-        vocab_size = model.get_input_embeddings().weight.shape[0]
-        assert vocab_size % world_size == 0
+    input = tokenizer("Persistence is all you need.", return_tensors="pt")

     parallelized_model = TensorParallel(model, parallel_context)
     parallelized_model.parallelize()

-    assert vocab_size % world_size == 0
-
-    parallelized_model(**input)
+    generated_ids = parallelized_model(**input)
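
The deleted assertions concerned vocab sharding: a vocab-parallel embedding needs vocab_size to be divisible by the tensor-parallel world size, which is presumably what _resize_vocab_size and VocabUtility handle elsewhere in the package. A hedged sketch of the per-rank vocab range computation (illustrative; the repository's VocabUtility may compute this differently):

    def vocab_range_for_rank(vocab_size: int, world_size: int, rank: int):
        # Half-open [start, end) slice of the vocabulary owned by `rank`.
        # Assumes vocab_size was already padded to a multiple of world_size,
        # e.g. GPT-2's 50257 padded to 50258 for tensor_parallel_size=2.
        assert vocab_size % world_size == 0
        per_partition = vocab_size // world_size
        return rank * per_partition, (rank + 1) * per_partition

    assert vocab_range_for_rank(50258, 2, 0) == (0, 25129)
    assert vocab_range_for_rank(50258, 2, 1) == (25129, 50258)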
