Skip to content

Commit

Permalink
change parallelize a model from AutoModelForCausalLM
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Sep 2, 2023
1 parent c15c80b commit 9ac8531
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tests/nn/tensor_parallel/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest
from transformers import AutoModel
from transformers import AutoModelForCausalLM

MODEL_NAME = "bigscience/bloom-560m"


@pytest.fixture(scope="session")
def model():
return AutoModel.from_pretrained(MODEL_NAME)
return AutoModelForCausalLM.from_pretrained(MODEL_NAME)
8 changes: 4 additions & 4 deletions tests/nn/tensor_parallel/test_parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def run_parallelize_linear(


@pytest.mark.parametrize("tensor_parallel_size, MODULE_NAME, get_module", [
(1, "transformer.h.0.mlp.dense_h_to_4h", lambda model: model.h[0].mlp.dense_h_to_4h),
(2, "transformer.h.0.mlp.dense_h_to_4h", lambda model: model.h[0].mlp.dense_h_to_4h),
(1, "transformer.h.0.mlp.dense_4h_to_h", lambda model: model.h[0].mlp.dense_4h_to_h),
(2, "transformer.h.0.mlp.dense_4h_to_h", lambda model: model.h[0].mlp.dense_4h_to_h),
(1, "transformer.h.0.mlp.dense_h_to_4h", lambda model: model.transformer.h[0].mlp.dense_h_to_4h),
(2, "transformer.h.0.mlp.dense_h_to_4h", lambda model: model.transformer.h[0].mlp.dense_h_to_4h),
(1, "transformer.h.0.mlp.dense_4h_to_h", lambda model: model.transformer.h[0].mlp.dense_4h_to_h),
(2, "transformer.h.0.mlp.dense_4h_to_h", lambda model: model.transformer.h[0].mlp.dense_4h_to_h),
])
def test_parallelize_linear(model, tensor_parallel_size, MODULE_NAME, get_module):
PIPELINE_PARALLEL_SIZE = 1
Expand Down

0 comments on commit 9ac8531

Please sign in to comment.