[Feature] Add support for tensor parallelism in MoE
xrsrke committed Dec 6, 2023
1 parent e094986 commit dfda88c
Showing 8 changed files with 116 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pipegoose/distributed/_initializers/initialize_expert.py
@@ -23,7 +23,7 @@ def init_dist_group(self) -> ProcessGroupResult:
process_group = None
local_world_size = None
ranks_in_group = None
- parallel_mode = ParallelMode.EXPERT
parallel_mode = ParallelMode.EXPERT_DATA

for i in range(num_tensor_parallel_groups):
ranks = list(range(i * self.tensor_parallel_size, (i + 1) * self.tensor_parallel_size))
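
Note: a minimal standalone sketch of the grouping performed by the loop above. Consecutive ranks are chunked into groups of tensor_parallel_size, and each chunk becomes one expert-data process group. The sizes and the num_tensor_parallel_groups formula below are illustrative assumptions, not values taken from this commit.

# Hypothetical sizes, for illustration only.
world_size = 8
tensor_parallel_size = 2
num_tensor_parallel_groups = world_size // tensor_parallel_size  # assumed formula

# Mirrors the loop above: chunk consecutive ranks into groups of
# `tensor_parallel_size`.
groups = [
    list(range(i * tensor_parallel_size, (i + 1) * tensor_parallel_size))
    for i in range(num_tensor_parallel_groups)
]
print(groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
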
4 changes: 2 additions & 2 deletions pipegoose/distributed/parallel_context.py
@@ -277,12 +277,12 @@ def map_rank_to_device(self):
# NOTE: In 3D parallelism for MoE, the gpu assignment only depends on
# tensor parallelism, pipeline parallelism and data parallelism.
# according to the paper: Pipeline MoE: A Flexible MoE Implementation
# with Pipeline Parallelism by Xin Chen et al
# https://arxiv.org/abs/2304.11414
modes_and_ranks = {
mode: rank
for mode, rank in zip(self._local_ranks.keys(), _rank_tensor.tolist())
- if mode != ParallelMode.EXPERT
if mode != ParallelMode.EXPERT_DATA
}
self._ranks_to_device[tuple(modes_and_ranks.items())] = _rank

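
Note: a rough illustration of the filtering above. When the rank-to-device key is built, the expert-data coordinate is dropped, so device placement depends only on the tensor, pipeline, data (and global) ranks. The trimmed enum and the local ranks below are made up for the example; only the DATA and EXPERT_DATA values come from this diff.

from enum import Enum

class ParallelMode(Enum):  # trimmed copy for illustration; some values assumed
    GLOBAL = "global"
    TENSOR = "tensor"
    PIPELINE = "pipeline"
    DATA = "data"
    EXPERT_DATA = "expert"

# Hypothetical local ranks for one process.
local_ranks = {
    ParallelMode.GLOBAL: 3,
    ParallelMode.TENSOR: 1,
    ParallelMode.PIPELINE: 0,
    ParallelMode.DATA: 1,
    ParallelMode.EXPERT_DATA: 1,
}

# Same idea as the dict comprehension above: drop the expert-data axis.
modes_and_ranks = {
    mode: rank for mode, rank in local_ranks.items() if mode != ParallelMode.EXPERT_DATA
}
print(modes_and_ranks)  # no EXPERT_DATA key in the device-mapping key
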
2 changes: 1 addition & 1 deletion pipegoose/distributed/parallel_mode.py
@@ -9,4 +9,4 @@ class ParallelMode(Enum):
DATA = "data"

# NOTE: for expert data parallelism
- EXPERT = "expert"
EXPERT_DATA = "expert"
2 changes: 1 addition & 1 deletion pipegoose/nn/data_parallel/data_parallel.py
@@ -39,5 +39,5 @@ def _average_grad(self, grad: torch.Tensor, is_expert: bool):
grad,
op=dist.ReduceOp.SUM,
parallel_context=self.parallel_context,
- parallel_mode=ParallelMode.EXPERT if is_expert else ParallelMode.DATA,
parallel_mode=ParallelMode.EXPERT_DATA if is_expert else ParallelMode.DATA,
)
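
Note: a sketch of what the call above amounts to. Expert gradients are summed across the expert-data group rather than the plain data-parallel group, and averaging means dividing the sum by the group size (the division may live elsewhere in pipegoose). The helper below fakes the collective with a plain list so it runs without a distributed backend; the name is illustrative.

import torch

def average_across_group(per_rank_grads):
    # Stand-in for all_reduce(op=SUM) across the group, followed by
    # division by the group's world size.
    summed = torch.stack(per_rank_grads).sum(dim=0)
    return summed / len(per_rank_grads)

grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
print(average_across_group(grads))  # tensor([2., 3.])
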
22 changes: 21 additions & 1 deletion pipegoose/nn/tensor_parallel/tensor_parallel.py
@@ -4,6 +4,7 @@
from torch import nn

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.expert_parallel.layers import ExpertLayer
from pipegoose.nn.parallel import Parallel
from pipegoose.nn.tensor_parallel.parallelizer import (
EmbeddingParallelizer,
@@ -42,10 +42,29 @@ def parallelize(self) -> nn.Module:
return module

def _get_leaf_modules(self, model: nn.Module) -> List[Tuple[str, nn.Module]]:
"""Return non-expert leaf modules."""
leaf_modules = []
expert_names = []

def is_child_of_expert(module_name):
# NOTE: suppose an mlp expert has name "transformer.h.0.mlp"
# then its children will have names like "transformer.h.0.mlp.{child_name}"
# so we can check if a module is a child of an expert by checking if its name
# starts with "transformer.h.0.mlp"
for expert_name in expert_names:
if module_name.startswith(expert_name):
return True
return False

for module_name, module in model.named_modules():
- if list(module.children()):
if isinstance(module, ExpertLayer):
expert_names.append(module_name)
continue

# NOTE: skip leaf modules that belong to ExpertLayer
if is_child_of_expert(module_name) or list(module.children()):
continue

leaf_modules.append((module_name, module))

return leaf_modules
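
Note: a self-contained sketch of the prefix check used by _get_leaf_modules above. named_modules() yields dotted names, so any module whose name starts with an expert layer's name belongs to that expert and is skipped when collecting leaf modules for tensor parallelism. ToyExpertLayer and the toy model are stand-ins invented for this example, not pipegoose classes.

import torch.nn as nn

class ToyExpertLayer(nn.Module):  # stand-in for pipegoose's ExpertLayer
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

model = nn.Sequential()
model.add_module("attn", nn.Linear(4, 4))
model.add_module("mlp", ToyExpertLayer())

expert_names = []
leaf_modules = []
for name, module in model.named_modules():
    if isinstance(module, ToyExpertLayer):
        expert_names.append(name)  # e.g. "mlp"
        continue
    # Skip children of an expert (e.g. "mlp.fc") and non-leaf containers.
    if any(name.startswith(expert) for expert in expert_names) or list(module.children()):
        continue
    leaf_modules.append((name, module))

print([name for name, _ in leaf_modules])  # ['attn']
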
@@ -43,7 +43,7 @@ def init_tensor_parallel_group(
assert result["ranks_in_group"] == expected_ranks
assert dist.get_process_group_ranks(result["process_group"]) == expected_ranks

- assert result["parallel_mode"] == ParallelMode.EXPERT
assert result["parallel_mode"] == ParallelMode.EXPERT_DATA

dist.barrier()
dist.destroy_process_group(result["process_group"])
8 changes: 4 additions & 4 deletions tests/distributed/test_parallel_context.py
@@ -17,14 +17,14 @@
LOCAL_RANK_TO_NEXT_RANK = {
1: {
ParallelMode.TENSOR: {0: 0},
- ParallelMode.EXPERT: {0: 0},
ParallelMode.EXPERT_DATA: {0: 0},
ParallelMode.PIPELINE: {0: 0},
ParallelMode.DATA: {0: 0},
ParallelMode.GLOBAL: {0: 0},
},
8: {
ParallelMode.TENSOR: {0: 1, 1: 0},
- ParallelMode.EXPERT: {0: 1, 1: 0},
ParallelMode.EXPERT_DATA: {0: 1, 1: 0},
ParallelMode.PIPELINE: {
0: 1,
1: 0,
@@ -37,14 +37,14 @@
LOCAL_RANK_TO_PREV_RANK = {
1: {
ParallelMode.TENSOR: {0: 0},
- ParallelMode.EXPERT: {0: 0},
ParallelMode.EXPERT_DATA: {0: 0},
ParallelMode.PIPELINE: {0: 0},
ParallelMode.DATA: {0: 0},
ParallelMode.GLOBAL: {0: 0},
},
8: {
ParallelMode.TENSOR: {0: 1, 1: 0},
- ParallelMode.EXPERT: {0: 1, 1: 0},
ParallelMode.EXPERT_DATA: {0: 1, 1: 0},
ParallelMode.PIPELINE: {
0: 1,
1: 0,
111 changes: 85 additions & 26 deletions tests/nn/expert_parallel/test_hybrid_expert_parallel.py
@@ -13,6 +13,7 @@
from pipegoose.nn.data_parallel.data_parallel import DataParallel
from pipegoose.nn.expert_parallel.loss import ExpertLoss
from pipegoose.nn.expert_parallel.routers import SwitchNoisePolicy, Top1Router
from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel
from pipegoose.testing.utils import get_microbatch, init_parallel_context, spawn

MODEL_NAME = "bigscience/bloom-560m"
@@ -30,6 +31,30 @@ def tokenizer():
return AutoTokenizer.from_pretrained(MODEL_NAME)


def get_inputs(model, tokenizer):
NUM_EXPERTS = 2
NUM_EXPERT_LAYERS = 2
NUM_LAYERS = model.config.num_hidden_layers
D_MODEL = model.config.hidden_size

mapping = [layer_idx for layer_idx in random.sample(range(NUM_LAYERS - 1), NUM_EXPERT_LAYERS)]
noise_policy = SwitchNoisePolicy()
router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL)

text = ["Persistence is all you need.", "Attention is all you need."]
input = tokenizer(text, return_tensors="pt", padding=True)

kwargs = {
"input": input,
"labels": input["input_ids"],
"model": model,
"mapping": mapping,
"num_experts": NUM_EXPERTS,
"router": router,
}
return kwargs


def run_expert_parallel_with_data_parallel(
rank,
world_size,
@@ -50,16 +75,11 @@ def run_expert_parallel_with_data_parallel(
torch.manual_seed(42)

parallel_context = init_parallel_context(
- rank,
- world_size,
- port,
- tensor_parallel_size,
- pipeline_parallel_size,
- data_parallel_size,
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)
# NOTE: each model replica only trains on a subset of the data
input_ids, attention_mask, labels = get_microbatch(
kwargs["input"], kwargs["labels"], parallel_context, ParallelMode.EXPERT
kwargs["input"], kwargs["labels"], parallel_context, ParallelMode.EXPERT_DATA
)
loss_func = ExpertLoss(nn.CrossEntropyLoss())

@@ -77,7 +97,7 @@
loss.backward()

expert_grad = list(model.transformer.h[0].mlp.parameters())[0]
- expert_grads = all_gather(expert_grad, parallel_context=parallel_context, parallel_mode=ParallelMode.EXPERT)
expert_grads = all_gather(expert_grad, parallel_context=parallel_context, parallel_mode=ParallelMode.EXPERT_DATA)
expert_grads = torch.chunk(expert_grads, chunks=data_parallel_size, dim=0)

# NOTE: check if expert grads are the same across data parallel dimension
@@ -92,29 +112,68 @@ def test_expert_parallel_with_data_parallel(model, tokenizer):
DATA_PARALLEL_SIZE = 2
WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE

- NUM_EXPERTS = 2
- NUM_EXPERT_LAYERS = 2
- NUM_LAYERS = model.config.num_hidden_layers
- D_MODEL = model.config.hidden_size
kwargs = get_inputs(model, tokenizer)

- mapping = [layer_idx for layer_idx in random.sample(range(NUM_LAYERS - 1), NUM_EXPERT_LAYERS)]
- noise_policy = SwitchNoisePolicy()
- router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL)
spawn(
run_expert_parallel_with_data_parallel,
world_size=WORLD_SIZE,
tensor_parallel_size=TENSOR_PARALLEL_SIZE,
pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
data_parallel_size=DATA_PARALLEL_SIZE,
kwargs=kwargs,
)

- text = ["Persistence is all you need.", "Attention is all you need."]
- input = tokenizer(text, return_tensors="pt", padding=True)

- kwargs = {
- "input": input,
- "labels": input["input_ids"],
- "model": model,
- "mapping": mapping,
- "num_experts": NUM_EXPERTS,
- "router": router,
- }
def run_expert_parallel_with_tensor_parallel(
rank,
world_size,
port,
tensor_parallel_size,
pipeline_parallel_size,
data_parallel_size,
kwargs,
):
model = kwargs["model"]
mapping = kwargs["mapping"]
router = kwargs["router"]
NUM_EXPERTS = kwargs["num_experts"]

# TODO: remove after adding seed to parallel_context
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)
loss_func = ExpertLoss(nn.CrossEntropyLoss())

model = ExpertParallel(model, NUM_EXPERTS, mapping=mapping, router=router, parallel_context=parallel_context).parallelize()
model = TensorParallel(model, parallel_context).parallelize()
optim = Adam(model.parameters(), lr=1e-3)

outputs = model(**kwargs["input"])

logits = outputs.logits[..., :-1, :].contiguous().view(-1, outputs.logits.shape[-1])
labels = kwargs["labels"][..., 1:].contiguous().view(-1).to(logits.device)
loss = loss_func(logits, labels)

optim.zero_grad()
loss.backward()

optim.step()


def test_expert_parallel_with_tensor_parallel(model, tokenizer):
TENSOR_PARALLEL_SIZE = 2
PIPELINE_PARALLEL_SIZE = 1
DATA_PARALLEL_SIZE = 1
WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE

kwargs = get_inputs(model, tokenizer)

spawn(
- run_expert_parallel_with_data_parallel,
run_expert_parallel_with_tensor_parallel,
world_size=WORLD_SIZE,
tensor_parallel_size=TENSOR_PARALLEL_SIZE,
pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
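
Note: condensing the new test above into a usage sketch. Experts are parallelized first, then the remaining non-expert modules are partitioned by TensorParallel. The call signatures mirror the test; the ExpertParallel import path is assumed, and this is a sketch under those assumptions rather than a verified recipe.

from torch.optim import Adam

from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel  # import path assumed
from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel

def wrap_moe_model(model, num_experts, mapping, router, parallel_context):
    # Expert layers are handled first, then the remaining (non-expert)
    # leaf modules are split across the tensor-parallel group.
    model = ExpertParallel(
        model, num_experts, mapping=mapping, router=router, parallel_context=parallel_context
    ).parallelize()
    model = TensorParallel(model, parallel_context).parallelize()
    optim = Adam(model.parameters(), lr=1e-3)
    return model, optim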
