Commit

add parallelize 🤗 transformer's layer norm
xrsrke committed Sep 4, 2023
1 parent 9ebeb47 commit dd12134
Showing 2 changed files with 60 additions and 7 deletions.
19 changes: 14 additions & 5 deletions pipegoose/nn/tensor_parallel/parallelize.py
@@ -58,8 +58,6 @@ def _parallelize_column_linear(self, module: nn.Module) -> nn.Module:

        _update_model_arguments(
            module=module,
            # TODO: make this based on parallel mapping
            # column parallel don't gather the output
            gather_output=True,
            parallel_context=self.parallel_context,
        )
@@ -68,8 +66,6 @@ def _parallelize_column_linear(self, module: nn.Module) -> nn.Module:

    def _parallelize_row_linear(self, module: nn.Module) -> nn.Module:
        module.__class__ = RowParallelLinear
        # NOTE: it appears that row parallelism doesn't require splitting the bias,
        # as the final output is correct without splitting it
        module = self._slice_weight_and_bias(module, slice_bias=False, dim=1)

        _update_model_arguments(
@@ -137,7 +133,20 @@ def _resize_vocab_size(self):


class ParallelizeLayerNorm(ParallelizeModule):
    pass
    def parallelize(self) -> nn.Module:
        assert isinstance(self.module, nn.LayerNorm), "ParallelizeLayerNorm only supports nn.LayerNorm"

        _update_model_arguments(
            module=self.module,
            normalized_shape=self.module.normalized_shape,
            eps=self.module.eps,
            parallel_context=self.parallel_context,
        )

        return self.module

    def deparallelize(self):
        pass


class ParallelizeAttention(ParallelizeModule):
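A minimal usage sketch of the new ParallelizeLayerNorm (not taken from the commit): it follows the call pattern used in the test below and assumes a parallel_context has already been built, the way the tests do with an init_parallel_context helper inside each spawned rank, plus a hypothetical hidden size of 1024.

import torch
import torch.nn as nn

from pipegoose.nn.tensor_parallel.parallelize import ParallelizeLayerNorm

# Assumption: parallel_context was created elsewhere, e.g. via the
# init_parallel_context(...) helper used in the test suite.
layer_norm = nn.LayerNorm(1024)

# parallelize() updates the module's arguments (normalized_shape, eps,
# parallel_context) and returns the same module, so it can be used as usual.
parallel_layer_norm = ParallelizeLayerNorm(layer_norm, parallel_context).parallelize()
output = parallel_layer_norm(torch.randn(10, 5, 1024))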
48 changes: 46 additions & 2 deletions tests/nn/tensor_parallel/test_parallelize.py
@@ -6,6 +6,7 @@
from pipegoose.nn.tensor_parallel.parallelize import (
    ParallelizeEmbedding,
    ParallelizeLinear,
    ParallelizeLayerNorm,
)
from pipegoose.testing.utils import spawn

@@ -131,6 +132,49 @@ def test_parallelize_attention():
    pass


@pytest.mark.skip
def test_parallelize_layer_norm():
def run_parallelize_layernorm(
    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, module_name, module, input, output
):
    parallel_context = init_parallel_context(
        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
    )
    # TODO: make this based on parallel mapping
    parallelized_module = ParallelizeLayerNorm(module, parallel_context).parallelize()
    parallel_output = parallelized_module(input)

    assert torch.allclose(parallel_output, output)

    # NOTE: since we already test the backward pass of
    # ColumnParallelLinear and RowParallelLinear in another test,
    # we don't need to test it here.


@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_parallelize_layer_norm(model, tensor_parallel_size):
    DATA_PARALLEL_SIZE = 1
    PIPELINE_PARALLEL_SIZE = 1

    MODULE_NAME = "transformer.word_embeddings_layernorm"
    module = model.transformer.word_embeddings_layernorm

    BATCH_SIZE = 10
    SEQ_LEN = 5
    HIDDEN_SIZE = module.normalized_shape[0]
    input = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
    output = module(input)

    spawn(
        run_parallelize_layernorm,
        world_size=tensor_parallel_size,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
        data_parallel_size=DATA_PARALLEL_SIZE,
        module_name=MODULE_NAME,
        module=module,
        input=input.detach(),
        output=output.detach(),
    )


def test_parallelize_positional_embedding():
    pass
