Skip to content

Commit

Permalink
merge tests for parallelizing column linear and row linear operations…
Browse files Browse the repository at this point in the history
… into one
  • Loading branch information
xrsrke committed Sep 2, 2023
1 parent 8f3ad92 commit 7c9642b
Showing 1 changed file with 11 additions and 32 deletions.
43 changes: 11 additions & 32 deletions tests/nn/tensor_parallel/test_parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,39 +81,18 @@ def run_parallelize_linear(
# need to test it here


@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_parallelize_column_linear(model, tensor_parallel_size):
# TODO: add test module the named_children() version
MODULE_NAME = "transformer.h.0.mlp.dense_h_to_4h"

# NOTE: this is column parallel linear
module = model.h[0].mlp.dense_h_to_4h
input_size = module.weight.shape[1]

input = torch.randn(10, input_size)
output = module(input)

spawn(
run_parallelize_linear,
world_size=tensor_parallel_size,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=1,
data_parallel_size=1,
module_name=MODULE_NAME,
module=module,
input=input.detach(),
output=output.detach(),
)


@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_parallelize_row_linear(model, tensor_parallel_size):
MODULE_NAME = "transformer.h.0.mlp.dense_4h_to_h"
module = model.h[0].mlp.dense_4h_to_h
@pytest.mark.parametrize("tensor_parallel_size, MODULE_NAME, get_module", [
(1, "transformer.h.0.mlp.dense_h_to_4h", lambda model: model.h[0].mlp.dense_h_to_4h),
(2, "transformer.h.0.mlp.dense_h_to_4h", lambda model: model.h[0].mlp.dense_h_to_4h),
(1, "transformer.h.0.mlp.dense_4h_to_h", lambda model: model.h[0].mlp.dense_4h_to_h),
(2, "transformer.h.0.mlp.dense_4h_to_h", lambda model: model.h[0].mlp.dense_4h_to_h),
])
def test_parallelize_linear(model, tensor_parallel_size, MODULE_NAME, get_module):
module = get_module(model)
input_size = module.weight.shape[1]

input = torch.randn(10, input_size)
output = module(input)
input_tensor = torch.randn(10, input_size)
output = module(input_tensor)

spawn(
run_parallelize_linear,
Expand All @@ -123,7 +102,7 @@ def test_parallelize_row_linear(model, tensor_parallel_size):
data_parallel_size=1,
module_name=MODULE_NAME,
module=module,
input=input.detach(),
input=input_tensor.detach(),
output=output.detach(),
)

Expand Down

0 comments on commit 7c9642b

Please sign in to comment.