Skip to content

Commit

Permalink
Add parallel mapping to determine whether a module in an attention la…
Browse files Browse the repository at this point in the history
…yer is row-linear parallelizable or column-linear parallelizable.
  • Loading branch information
xrsrke committed Sep 1, 2023
1 parent 5ff9c6c commit 9401783
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pipegoose/nn/tensor_parallel/parallel_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class ParallelMapping:
__MAPPING__ = {
"albert-base-v2": [Column(("query", "key", "value")), Row("attention.dense")],
"bloom-560m": [
Column(("dense_h_to_4h",)),
Row(("dense_4h_to_h",)),
Column(("dense_h_to_4h", "query_key_value")),
Row(("dense_4h_to_h", "dense")),
],
}

Expand Down
12 changes: 9 additions & 3 deletions tests/nn/tensor_parallel/test_parallel_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,28 @@


def test_is_column_parallel_mapping(model):
BLOOM_DENSE_H_TO_4H_NAME = "h.{}.mlp.dense_h_to_4h"
BLOOM_QKV_NAME = "h.{}.self_attention.query_key_value"
mappings = {}

for name, _ in model.named_modules():
mappings[name] = ParallelMapping.is_column_parallel(name)

for layer_idx in range(len(model.h)):
# TODO: add check attention layer
assert mappings[f"h.{layer_idx}.mlp.dense_h_to_4h"] is True
assert mappings[BLOOM_DENSE_H_TO_4H_NAME.format(layer_idx)] is True
assert mappings[BLOOM_QKV_NAME.format(layer_idx)] is True


def test_is_row_parallel_mapping(model):
BLOOM_DENSE_4H_TO_H_NAME = "h.{}.mlp.dense_4h_to_h"
BLOOM_ATTENTION_DENSE_NAME = "h.{}.self_attention.dense"

mappings = {}

for name, _ in model.named_modules():
mappings[name] = ParallelMapping.is_row_parallel(name)

for layer_idx in range(len(model.h)):
# TODO: add check attention layer
assert mappings[f"h.{layer_idx}.mlp.dense_4h_to_h"] is True
assert mappings[BLOOM_DENSE_4H_TO_H_NAME.format(layer_idx)] is True
assert mappings[BLOOM_ATTENTION_DENSE_NAME.format(layer_idx)] is True

0 comments on commit 9401783

Please sign in to comment.