
Commit

refactor
xrsrke committed Sep 2, 2023
1 parent 3c34cc7 commit 8f3ad92
Showing 1 changed file with 6 additions and 33 deletions.
39 changes: 6 additions & 33 deletions tests/nn/tensor_parallel/test_parallelize.py
@@ -3,7 +3,6 @@

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.tensor_parallel._utils import VocabUtility, is_splitable
from pipegoose.nn.tensor_parallel.parallelize import (
    ParallelizeEmbedding,
    ParallelizeLinear,
@@ -37,25 +36,9 @@ def run_parallelize_embedding(
    )
    world_size = parallel_context.get_world_size(parallel_mode=ParallelMode.TENSOR)

    def get_new_embedding_size(vocab_size):
        padding_size = 0
        while not is_splitable(vocab_size + padding_size, parallel_context):
            padding_size += 1

        new_vocab_size = vocab_size + padding_size
        new_partition_size = new_vocab_size // world_size
        return new_vocab_size, new_partition_size

    vocab_size, embedding_dim = embedding.weight.size()
    new_vocab_size, new_partition_size = get_new_embedding_size(vocab_size)
    vocab_start_idx, vocab_end_idx = VocabUtility.get_vocab_range_from_global_vocab_size(world_size, rank, new_vocab_size)

    parallelized_embedding = ParallelizeEmbedding(embedding, parallel_context).parallelize()
    parallel_output = parallelized_embedding(input)

    assert parallelized_embedding.vocab_start_idx == vocab_start_idx
    assert parallelized_embedding.vocab_end_idx == vocab_end_idx
    assert parallelized_embedding.weight.shape == (new_partition_size, embedding_dim)
    assert torch.allclose(parallel_output, output)

# NOTE: since we already test the backward pass
@@ -81,12 +64,13 @@ def test_parallelize_embedding(model, tensor_parallel_size):
    )


def run_parallelize_column_linear(
def run_parallelize_linear(
    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, module_name, module, input, output
):
    parallel_context = init_parallel_context(
        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
    )
    # TODO: make this based on parallel mapping
    parallelized_module = ParallelizeLinear(module, parallel_context).parallelize(module_name)
    parallel_output = parallelized_module(input)

@@ -110,7 +94,7 @@ def test_parallelize_column_linear(model, tensor_parallel_size):
    output = module(input)

    spawn(
        run_parallelize_column_linear,
        run_parallelize_linear,
        world_size=tensor_parallel_size,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=1,
@@ -122,19 +106,6 @@ def test_parallelize_column_linear(model, tensor_parallel_size):
    )


def run_parallelize_row_linear(
    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, module_name, module, input, output
):
    parallel_context = init_parallel_context(
        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
    )
    # TODO: make this based on parallel mapping
    parallelized_module = ParallelizeLinear(module, parallel_context).parallelize(module_name)
    parallel_output = parallelized_module(input)

    torch.allclose(parallel_output, output, rtol=1e-4)


@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_parallelize_row_linear(model, tensor_parallel_size):
    MODULE_NAME = "transformer.h.0.mlp.dense_4h_to_h"
@@ -145,7 +116,7 @@ def test_parallelize_row_linear(model, tensor_parallel_size):
    output = module(input)

    spawn(
        run_parallelize_row_linear,
        run_parallelize_linear,
        world_size=tensor_parallel_size,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=1,
@@ -157,9 +128,11 @@ def test_parallelize_row_linear(model, tensor_parallel_size):
    )


@pytest.mark.skip
def test_parallelize_attention():
    pass


@pytest.mark.skip
def test_parallelize_layer_norm():
    pass

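For reference, the assertions removed in the second hunk re-derived the embedding's vocab padding and per-rank range by hand before comparing them against `ParallelizeEmbedding`; after this change the test appears to keep only the output comparison. A minimal standalone sketch of that arithmetic, assuming `is_splitable` simply checks divisibility by the tensor-parallel world size and that ranks own contiguous, equal slices as the deleted assertions imply (helper names below are illustrative, not pipegoose APIs):

```python
# Standalone sketch (not the pipegoose API): the padding/partitioning arithmetic
# that the deleted `get_new_embedding_size` helper and the
# `VocabUtility.get_vocab_range_from_global_vocab_size` call appear to perform.


def pad_vocab_size(vocab_size: int, world_size: int) -> int:
    """Pad the vocab up to the next multiple of the tensor-parallel world size,
    assuming `is_splitable` means exactly that kind of divisibility."""
    padding_size = 0
    while (vocab_size + padding_size) % world_size != 0:
        padding_size += 1
    return vocab_size + padding_size


def vocab_range_for_rank(world_size: int, rank: int, padded_vocab_size: int) -> tuple[int, int]:
    """Return the [start, end) vocab slice owned by `rank`, assuming contiguous,
    equally sized partitions per rank."""
    partition_size = padded_vocab_size // world_size
    start = rank * partition_size
    return start, start + partition_size


if __name__ == "__main__":
    # Example: a 50,257-token vocab split across 2 tensor-parallel ranks.
    padded = pad_vocab_size(50_257, world_size=2)   # 50_258
    print(vocab_range_for_rank(world_size=2, rank=1, padded_vocab_size=padded))  # (25129, 50258)
```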