add parallelize 🤗 transformer's language model head
xrsrke committed Sep 4, 2023
1 parent dd12134 commit 1a9d67c
Showing 3 changed files with 71 additions and 7 deletions.
23 changes: 23 additions & 0 deletions pipegoose/nn/tensor_parallel/parallelize.py
@@ -151,3 +151,26 @@ def deparallelize(self):

 class ParallelizeAttention(ParallelizeModule):
     pass
+
+
+class ParallelizeLMHead(ParallelizeModule):
+    """Parallelize language model head."""
+    def parallelize(self) -> nn.Module:
+        module = self.module
+        module.__class__ = ColumnParallelLinear
+        module = self._slice_weight(module, dim=0)
+
+        _update_model_arguments(
+            module=module,
+            gather_output=True,
+            parallel_context=self.parallel_context,
+        )
+
+        return module
+
+    def _slice_weight(self, module: nn.Module, dim: int) -> nn.Module:
+        module.weight.data = get_partition(module.weight, self.parallel_context, dim=dim)
+        return module
+
+    def deparallelize(self):
+        pass
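The new `ParallelizeLMHead` slices the head along the vocabulary dimension (`dim=0` of the `(vocab_size, hidden_size)` weight) and sets `gather_output=True`, so the forward pass all-gathers the per-rank logit slices back into the full vocabulary. A minimal sketch of what the slicing step does, assuming `get_partition` returns the current rank's contiguous chunk along `dim` (the explicit `rank`/`world_size` arguments here stand in for the `parallel_context` the real helper receives):

```python
import torch

def get_partition_sketch(tensor: torch.Tensor, rank: int, world_size: int, dim: int) -> torch.Tensor:
    """Hypothetical stand-in for pipegoose's get_partition: return this
    rank's contiguous chunk of `tensor` along `dim`."""
    chunks = torch.chunk(tensor, world_size, dim=dim)
    return chunks[rank].contiguous()

# e.g. a (vocab=16, hidden=8) lm_head weight split across 2 tensor-parallel ranks
weight = torch.randn(16, 8)
shard = get_partition_sketch(weight, rank=0, world_size=2, dim=0)
assert shard.shape == (8, 8)
```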
12 changes: 6 additions & 6 deletions tests/nn/tensor_parallel/test_parallel_mapping.py
@@ -2,28 +2,28 @@
 
 
 def test_is_column_parallel_mapping(model):
-    BLOOM_DENSE_H_TO_4H_NAME = "h.{}.mlp.dense_h_to_4h"
-    BLOOM_QKV_NAME = "h.{}.self_attention.query_key_value"
+    BLOOM_DENSE_H_TO_4H_NAME = "transformer.h.{}.mlp.dense_h_to_4h"
+    BLOOM_QKV_NAME = "transformer.h.{}.self_attention.query_key_value"
     mappings = {}
 
     for name, _ in model.named_modules():
         mappings[name] = ParallelMapping.is_column_parallel(name)
 
-    for layer_idx in range(len(model.h)):
+    for layer_idx in range(len(model.transformer.h)):
         assert mappings[BLOOM_DENSE_H_TO_4H_NAME.format(layer_idx)] is True
         assert mappings[BLOOM_QKV_NAME.format(layer_idx)] is True
 
 
 def test_is_row_parallel_mapping(model):
-    BLOOM_DENSE_4H_TO_H_NAME = "h.{}.mlp.dense_4h_to_h"
-    BLOOM_ATTENTION_DENSE_NAME = "h.{}.self_attention.dense"
+    BLOOM_DENSE_4H_TO_H_NAME = "transformer.h.{}.mlp.dense_4h_to_h"
+    BLOOM_ATTENTION_DENSE_NAME = "transformer.h.{}.self_attention.dense"
 
     mappings = {}
 
     for name, _ in model.named_modules():
         mappings[name] = ParallelMapping.is_row_parallel(name)
 
-    for layer_idx in range(len(model.h)):
+    for layer_idx in range(len(model.transformer.h)):
         # TODO: add check attention layer
         assert mappings[BLOOM_DENSE_4H_TO_H_NAME.format(layer_idx)] is True
         assert mappings[BLOOM_ATTENTION_DENSE_NAME.format(layer_idx)] is True
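The renames above presumably track a change of the `model` fixture from the bare transformer to the full causal-LM model (the new LM-head test below needs `model.lm_head`), so names yielded by `named_modules()` now carry a `transformer.` prefix. A suffix-based mapping check is insensitive to such prefixes; here is a hedged sketch of that idea — the suffix tuples and function names are illustrative, not pipegoose's actual `ParallelMapping` internals:

```python
# Hypothetical sketch of a suffix-based parallel mapping; the real
# ParallelMapping in pipegoose may differ.
COLUMN_PARALLEL_SUFFIXES = ("mlp.dense_h_to_4h", "self_attention.query_key_value")
ROW_PARALLEL_SUFFIXES = ("mlp.dense_4h_to_h", "self_attention.dense")

def is_column_parallel(module_name: str) -> bool:
    # str.endswith accepts a tuple, so any matching suffix qualifies;
    # a "transformer." prefix on the name does not affect the result
    return module_name.endswith(COLUMN_PARALLEL_SUFFIXES)

def is_row_parallel(module_name: str) -> bool:
    return module_name.endswith(ROW_PARALLEL_SUFFIXES)

assert is_column_parallel("transformer.h.0.mlp.dense_h_to_4h")
assert is_row_parallel("transformer.h.3.self_attention.dense")
```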
43 changes: 42 additions & 1 deletion tests/nn/tensor_parallel/test_parallelize.py
@@ -6,7 +6,8 @@
 from pipegoose.nn.tensor_parallel.parallelize import (
     ParallelizeEmbedding,
     ParallelizeLinear,
-    ParallelizeLayerNorm
+    ParallelizeLayerNorm,
+    ParallelizeLMHead
 )
 from pipegoose.testing.utils import spawn

@@ -178,3 +179,43 @@ def test_parallelize_layer_norm(model, tensor_parallel_size):
 
 def test_parallelize_positional_embedding():
     pass
+
+
+def run_parallelize_lm_head(
+    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, module_name, module, input, output
+):
+    parallel_context = init_parallel_context(
+        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
+    )
+
+    parallelized_module = ParallelizeLMHead(module, parallel_context).parallelize()
+    parallel_output = parallelized_module(input)
+
+    assert torch.allclose(parallel_output, output)
+
+
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
+def test_parallelize_lm_head(model, tensor_parallel_size):
+    DATA_PARALLEL_SIZE = 1
+    PIPELINE_PARALLEL_SIZE = 1
+
+    MODULE_NAME = "lm_head"
+    module = model.lm_head
+
+    BATCH_SIZE = 10
+    HIDDEN_SIZE = module.weight.shape[1]
+
+    input = torch.randn(BATCH_SIZE, HIDDEN_SIZE)
+    output = module(input)
+
+    spawn(
+        run_parallelize_lm_head,
+        world_size=tensor_parallel_size,
+        tensor_parallel_size=tensor_parallel_size,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        data_parallel_size=DATA_PARALLEL_SIZE,
+        module_name=MODULE_NAME,
+        module=module,
+        input=input.detach(),
+        output=output.detach(),
+    )
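The test compares the gathered parallel output against the unsharded module's output (note the `assert` in front of `torch.allclose`, which otherwise returns a bool that would be silently discarded). The equality holds because a column-parallel linear computes disjoint slices of the logits on each rank, and concatenating those slices reproduces the full matmul up to floating-point tolerance. A self-contained sketch of this identity under made-up sizes (not pipegoose code):

```python
import torch

# Made-up sizes for illustration; this mirrors what ColumnParallelLinear with
# gather_output=True is expected to compute across tensor-parallel ranks.
torch.manual_seed(0)
batch, hidden, vocab, world_size = 4, 8, 16, 2

weight = torch.randn(vocab, hidden)      # lm_head weight laid out as (vocab, hidden)
x = torch.randn(batch, hidden)

full = x @ weight.t()                    # single-device logits: (batch, vocab)

# slicing dim=0 of the weight keeps one contiguous row-slice per rank
shards = torch.chunk(weight, world_size, dim=0)
partials = [x @ w.t() for w in shards]   # each rank's slice of the logits
gathered = torch.cat(partials, dim=-1)   # the gather that gather_output=True performs

assert torch.allclose(full, gathered)
```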
