Add parallel mapping for whether a module is row-linear parallelizable or column-linear parallelizable
xrsrke committed Sep 1, 2023
1 parent 7675d05 commit 5ff9c6c
Showing 3 changed files with 44 additions and 33 deletions.
28 changes: 20 additions & 8 deletions pipegoose/nn/tensor_parallel/parallel_mapping.py
@@ -27,26 +27,38 @@ class ParallelMapping:
         ],
     }

-    def _search(cls, module_name: str) -> TensorParallelInformation:
+    @staticmethod
+    def _search(module_name: str) -> TensorParallelInformation:
         """
         Search for module_name in mappings.
         """
-        for _, items in cls.__MAPPING__.items():
+        module_name = ParallelMapping._extract_module_name(module_name)
+        for _, items in ParallelMapping.__MAPPING__.items():
             for item in items:
                 if module_name in item.module_name:
                     return item
         return None

-    @classmethod
-    def is_column_parallel(cls, module_name: str) -> bool:
-        item = cls._search(module_name)
+    @staticmethod
+    def _extract_module_name(module_name: str) -> str:
+        def _check_module_name_in_named_modules(module_name: str) -> bool:
+            return '.' in module_name
+
+        if _check_module_name_in_named_modules(module_name) is True:
+            return module_name.split('.')[-1]
+
+        return module_name
+
+    @staticmethod
+    def is_column_parallel(module_name: str) -> bool:
+        item = ParallelMapping._search(module_name)
         if item is None:
             return False
         return isinstance(item, Column)

-    @classmethod
-    def is_row_parallel(cls, module_name: str) -> bool:
-        item = cls._search(module_name)
+    @staticmethod
+    def is_row_parallel(module_name: str) -> bool:
+        item = ParallelMapping._search(module_name)
         if item is None:
             return False
         return isinstance(item, Row)
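With this change, the lookup accepts the fully qualified names produced by model.named_modules() (e.g. h.0.mlp.dense_h_to_4h): _extract_module_name strips everything up to the last dot before searching. Below is a minimal, self-contained sketch of the pattern; the Column/Row entries for BLOOM's MLP come from the expected mapping in the old test further down, while TensorParallelInformation is reduced to a bare dataclass for illustration — this is not the pipegoose source.

# Minimal sketch of the lookup pattern; the class shapes are simplified stand-ins.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TensorParallelInformation:
    module_name: List[str]


class Column(TensorParallelInformation):
    """Marker type: weight split along its output dimension."""


class Row(TensorParallelInformation):
    """Marker type: weight split along its input dimension."""


class ParallelMapping:
    __MAPPING__ = {
        "bloom": [
            Column(module_name=["dense_h_to_4h"]),
            Row(module_name=["dense_4h_to_h"]),
        ],
    }

    @staticmethod
    def _extract_module_name(module_name: str) -> str:
        # "h.0.mlp.dense_h_to_4h" -> "dense_h_to_4h"
        if "." in module_name:
            return module_name.split(".")[-1]
        return module_name

    @staticmethod
    def _search(module_name: str) -> Optional[TensorParallelInformation]:
        module_name = ParallelMapping._extract_module_name(module_name)
        for _, items in ParallelMapping.__MAPPING__.items():
            for item in items:
                if module_name in item.module_name:
                    return item
        return None

    @staticmethod
    def is_column_parallel(module_name: str) -> bool:
        # isinstance(None, Column) is False, covering the not-found case
        return isinstance(ParallelMapping._search(module_name), Column)

    @staticmethod
    def is_row_parallel(module_name: str) -> bool:
        return isinstance(ParallelMapping._search(module_name), Row)


assert ParallelMapping.is_column_parallel("h.0.mlp.dense_h_to_4h")
assert ParallelMapping.is_row_parallel("h.0.mlp.dense_4h_to_h")
assert not ParallelMapping.is_column_parallel("word_embeddings")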
33 changes: 16 additions & 17 deletions tests/nn/tensor_parallel/test_parallel_mapping.py
@@ -1,24 +1,23 @@
-# from transformers import AutoModel
+from pipegoose.nn.tensor_parallel.parallel_mapping import ParallelMapping

-# from pipegoose.nn.tensor_parallel.parallel_mapping import Column, ParallelMapping, Row

-def test_parallel_mapping():
+def test_is_column_parallel_mapping(model):
+    mappings = {}

-    # model = AutoModel.from_pretrained("bigscience/bloom-560m")
-    # EXPECTED_MAPPING = {
-    #     "dense_h_to_4h": Column,
-    #     "dense_4h_to_h": Row,
-    # }
+    for name, _ in model.named_modules():
+        mappings[name] = ParallelMapping.is_column_parallel(name)

-    # mappings = {}
+    for layer_idx in range(len(model.h)):
+        # TODO: add check attention layer
+        assert mappings[f"h.{layer_idx}.mlp.dense_h_to_4h"] is True

-    # for name, module in model.named_modules():
-    #     mappings[name] = ParallelMapping.is_column_parallel(module)

-    # for name, parallel_type in mappings.items():
-    #     assert isinstance(parallel_type, EXPECTED_MAPPING[name])
+def test_is_row_parallel_mapping(model):
+    mappings = {}

-    # module = model.h[-1].mlp.dense_4h_to_h
-    # output = ParallelMapping.is_row_parallel(module)
-    # assert isinstance(output, Row)
-    pass
+    for name, _ in model.named_modules():
+        mappings[name] = ParallelMapping.is_row_parallel(name)
+
+    for layer_idx in range(len(model.h)):
+        # TODO: add check attention layer
+        assert mappings[f"h.{layer_idx}.mlp.dense_4h_to_h"] is True
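Both tests now take a model fixture instead of building the model inline. The fixture is deleted from test_parallelize.py in this same commit (next diff), and no conftest appears among the three changed files, so where it now lives is an assumption — a shared tests/conftest.py would be the usual place. A hypothetical sketch, reusing the exact fixture body removed below:

# Hypothetical conftest.py — not part of this commit. The body is the
# fixture removed from test_parallelize.py in the next diff.
import pytest
from transformers import AutoModel

MODEL_NAME = "bigscience/bloom-560m"


@pytest.fixture
def model():
    return AutoModel.from_pretrained(MODEL_NAME)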
16 changes: 8 additions & 8 deletions tests/nn/tensor_parallel/test_parallelize.py
@@ -1,6 +1,5 @@
 import pytest
 import torch
-from transformers import AutoModel

 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.distributed.parallel_mode import ParallelMode
@@ -11,13 +10,6 @@
 )
 from pipegoose.testing.utils import spawn

-MODEL_NAME = "bigscience/bloom-560m"
-
-
-@pytest.fixture
-def model():
-    return AutoModel.from_pretrained(MODEL_NAME)
-

 def init_parallel_context(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
     parallel_context = ParallelContext(
@@ -89,6 +81,10 @@ def test_parallelize_embedding(model, tensor_parallel_size):
     )


+def test_parallelize_attention():
+    pass
+
+
 def run_parallelize_column_linear(
     rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, linear, input, output
 ):
@@ -157,3 +153,7 @@ def test_parallelize_row_linear(model, tensor_parallel_size):
         input=input.detach(),
         output=output.detach(),
     )
+
+
+def test_parallelize_layer_norm():
+    pass
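The Column/Row assignment exercised by these tests follows the standard Megatron-LM recipe for a transformer MLP: the first linear (dense_h_to_4h) is column-parallel, so each rank holds a slice of the 4h output features and the elementwise activation can be applied shard-locally, while the second (dense_4h_to_h) is row-parallel, so each rank consumes its local slice and the partial outputs are combined with a single all-reduce. A small illustrative check (not pipegoose code) that the two shardings compose to the unsharded result:

# Column-parallel (dense_h_to_4h) followed by row-parallel (dense_4h_to_h):
# summing the per-rank partials emulates the all-reduce. The elementwise
# activation between the two linears is omitted for brevity.
import torch

h, tp = 8, 2                     # hidden size, tensor-parallel world size
x = torch.randn(3, h)
w1 = torch.randn(4 * h, h)       # dense_h_to_4h weight
w2 = torch.randn(h, 4 * h)       # dense_4h_to_h weight

full = x @ w1.T @ w2.T           # unsharded reference

w1_shards = w1.chunk(tp, dim=0)  # split output features across ranks
w2_shards = w2.chunk(tp, dim=1)  # split input features across ranks

partials = [(x @ w1_shards[i].T) @ w2_shards[i].T for i in range(tp)]
assert torch.allclose(full, torch.stack(partials).sum(dim=0), atol=1e-4)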
