
Commit

refactor: extract hard-coded tensor/pipeline/data parallel sizes in the tensor-parallel tests into named constants
xrsrke committed Sep 2, 2023
1 parent 7c9642b commit c15c80b
Showing 4 changed files with 33 additions and 16 deletions.
7 changes: 5 additions & 2 deletions tests/nn/tensor_parallel/test_embedding.py
@@ -53,6 +53,9 @@ def get_partition(x):
 
 @pytest.mark.parametrize("tensor_parallel_size", [1, 2])
 def test_parallel_embedding(tensor_parallel_size):
+    PIPELINE_PARALLEL_SIZE = 1
+    DATA_PARALLEL_SIZE = 1
+
     NUM_EMBEDDING = 100
     EMBEDDING_DIM = 10
 
@@ -67,8 +70,8 @@ def test_parallel_embedding(tensor_parallel_size):
         run_parallel_embedding,
         world_size=tensor_parallel_size,
         tensor_parallel_size=tensor_parallel_size,
-        pipeline_parallel_size=1,
-        data_parallel_size=1,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        data_parallel_size=DATA_PARALLEL_SIZE,
         input=input.detach(),
         output=output.detach(),
         weight=weight.detach(),
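
(Note: a vocab-parallel embedding of the kind run_parallel_embedding exercises typically shards the embedding table along the vocabulary dimension, zeroes lookups for token ids owned by other ranks, and all-reduces the partial results. The following is a minimal single-process sketch of that idea; the chunking scheme and all names are illustrative assumptions, not this repository's actual implementation.)

import torch

# Single-process sketch of a vocab-sharded embedding lookup (assumed scheme):
# each "rank" owns a contiguous vocab slice, embeds only its own ids, and the
# partial results are summed, standing in for dist.all_reduce across ranks.
NUM_EMBEDDING, EMBEDDING_DIM, TENSOR_PARALLEL_SIZE = 100, 10, 2

torch.manual_seed(0)
full_weight = torch.randn(NUM_EMBEDDING, EMBEDDING_DIM)
input_ids = torch.randint(0, NUM_EMBEDDING, (5,))

vocab_per_rank = NUM_EMBEDDING // TENSOR_PARALLEL_SIZE
partials = []
for rank in range(TENSOR_PARALLEL_SIZE):
    start = rank * vocab_per_rank
    end = start + vocab_per_rank
    foreign = (input_ids < start) | (input_ids >= end)   # ids other ranks own
    local_ids = (input_ids - start).clamp(0, vocab_per_rank - 1)
    out = full_weight[start:end][local_ids]
    out[foreign] = 0.0                                   # zero foreign rows
    partials.append(out)

parallel_output = torch.stack(partials).sum(dim=0)       # the "all-reduce"
assert torch.allclose(parallel_output, full_weight[input_ids])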
12 changes: 8 additions & 4 deletions tests/nn/tensor_parallel/test_linear.py
@@ -128,6 +128,10 @@ def run_parallel_row_linear(
 
 @pytest.mark.parametrize("run_linear", [run_parallel_column_linear, run_parallel_row_linear])
 def test_parallel_linear(run_linear):
+    TENSOR_PARALLEL_SIZE = 2
+    PIPELINE_PARALLEL_SIZE = 1
+    DATA_PARALLEL_SIZE = 1
+
     batch_size = 5
     in_features = 10
     out_features = 20
 
@@ -150,10 +154,10 @@ def test_parallel_linear(run_linear):
 
     spawn(
         run_linear,
-        world_size=2,
-        tensor_parallel_size=2,
-        pipeline_parallel_size=1,
-        data_parallel_size=1,
+        world_size=TENSOR_PARALLEL_SIZE,
+        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        data_parallel_size=DATA_PARALLEL_SIZE,
         batch_size=batch_size,
         in_features=in_features,
         out_features=out_features,
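
(The constants also make the invariant behind world_size visible: spawn's world_size is the product of the three parallel sizes, which here collapses to TENSOR_PARALLEL_SIZE because the pipeline and data sizes are both 1. A hypothetical guard, not part of this commit, that would catch an inconsistent edit:)

TENSOR_PARALLEL_SIZE = 2
PIPELINE_PARALLEL_SIZE = 1
DATA_PARALLEL_SIZE = 1

# world_size must cover every (tp, pp, dp) coordinate exactly once.
world_size = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE
assert world_size == TENSOR_PARALLEL_SIZE  # holds only while pp == dp == 1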
10 changes: 7 additions & 3 deletions tests/nn/tensor_parallel/test_loss.py
@@ -55,11 +55,15 @@ def get_partition(logits):
 
 @pytest.mark.parametrize("tensor_parallel_size", [1, 2])
 def test_parallel_cross_entropy(tensor_parallel_size):
-    torch.manual_seed(69)
+    PIPELINE_PARALLEL_SIZE = 1
+    DATA_PARALLEL_SIZE = 1
+
     BATCH_SIZE = 1
     SEQ_LEN = 2
     VOCAB_SIZE = 4
 
+    torch.manual_seed(69)
+
     logits = torch.randn(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
     targets = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
 
@@ -72,8 +76,8 @@ def test_parallel_cross_entropy(tensor_parallel_size):
         run_parallel_cross_entropy,
         world_size=tensor_parallel_size,
         tensor_parallel_size=tensor_parallel_size,
-        pipeline_parallel_size=1,
-        data_parallel_size=1,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        data_parallel_size=DATA_PARALLEL_SIZE,
         logits=logits,
         targets=targets,
         loss=loss,
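
(This hunk also moves torch.manual_seed(69) from the top of the test to the line just before the random tensors are created, so setup code added above it later cannot advance the generator and silently change logits and targets. A small demonstration of why adjacency to the draw matters:)

import torch

torch.manual_seed(69)
a = torch.randn(2)            # seeded immediately before the draw

torch.manual_seed(69)
_ = torch.rand(1)             # an unrelated draw advances the generator
b = torch.randn(2)

assert not torch.equal(a, b)  # the intervening draw changed the result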
20 changes: 13 additions & 7 deletions tests/nn/tensor_parallel/test_parallelize.py
@@ -48,6 +48,9 @@ def run_parallelize_embedding(
 
 @pytest.mark.parametrize("tensor_parallel_size", [1, 2])
 def test_parallelize_embedding(model, tensor_parallel_size):
+    PIPELINE_PARALLEL_SIZE = 1
+    DATA_PARALLEL_SIZE = 1
+
     input = torch.arange(0, 10)
     embedding = model.get_input_embeddings()
     output = embedding(input)
@@ -56,8 +59,8 @@ def test_parallelize_embedding(model, tensor_parallel_size):
         run_parallelize_embedding,
         world_size=tensor_parallel_size,
         tensor_parallel_size=tensor_parallel_size,
-        pipeline_parallel_size=1,
-        data_parallel_size=1,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        data_parallel_size=DATA_PARALLEL_SIZE,
         embedding=embedding,
         input=input.detach(),
         output=output.detach(),
@@ -76,9 +79,9 @@ def run_parallelize_linear(
 
     torch.allclose(parallel_output, output, rtol=1e-4)
 
-    # NOTE: since we already test the backward pass
-    # of ColumnParallelLinear in another test, we don't
-    # need to test it here
+    # NOTE: since we already test the backward passes of
+    # ColumnParallelLinear and RowParallelLinear in another test,
+    # we don't need to test them here.
 
 
 @pytest.mark.parametrize("tensor_parallel_size, MODULE_NAME, get_module", [
@@ -88,6 +91,9 @@ def run_parallelize_linear(
     (2, "transformer.h.0.mlp.dense_4h_to_h", lambda model: model.h[0].mlp.dense_4h_to_h),
 ])
 def test_parallelize_linear(model, tensor_parallel_size, MODULE_NAME, get_module):
+    PIPELINE_PARALLEL_SIZE = 1
+    DATA_PARALLEL_SIZE = 1
+
     module = get_module(model)
     input_size = module.weight.shape[1]
 
@@ -98,8 +104,8 @@ def test_parallelize_linear(model, tensor_parallel_size, MODULE_NAME, get_module
         run_parallelize_linear,
         world_size=tensor_parallel_size,
         tensor_parallel_size=tensor_parallel_size,
-        pipeline_parallel_size=1,
-        data_parallel_size=1,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        data_parallel_size=DATA_PARALLEL_SIZE,
         module_name=MODULE_NAME,
         module=module,
         input=input_tensor.detach(),
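
(Every test in this commit funnels through the same spawn helper, whose call signature — a run function, world_size, the three parallel sizes, and per-test kwargs — is visible in the hunks above but whose body is not part of the diff. Purely as an assumption about how such a helper is commonly built on torch.multiprocessing, a sketch might look like:)

import torch.multiprocessing as mp

def _worker(rank, func, world_size, kwargs):
    # torch.multiprocessing.spawn always passes the process index first;
    # forward it along with everything the test supplied.
    func(rank=rank, world_size=world_size, **kwargs)

def spawn(func, world_size, **kwargs):
    """Hypothetical sketch: run `func` once per rank, one process each.

    The repository's real helper presumably also initializes the process
    group from the tensor/pipeline/data parallel sizes; that part is
    omitted here because it is not shown in the diff.
    """
    mp.spawn(_worker, args=(func, world_size, kwargs), nprocs=world_size)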
