[Refactor] Refactor the tests for MoE with data parallelism
xrsrke committed Dec 5, 2023
1 parent 4d7b344 commit e094986
Showing 3 changed files with 29 additions and 20 deletions.
15 changes: 14 additions & 1 deletion pipegoose/testing/utils.py
@@ -2,7 +2,7 @@
import random
import socket
from functools import partial
from typing import Callable
from typing import Callable, Tuple

import pytest
import torch
@@ -118,3 +118,16 @@ def calculate_parameter_similarity(module1: nn.Module, module2: nn.Module, rtol:

def count_model_parameters(model):
    return sum(p.numel() for p in model.parameters())


def get_microbatch(
    inputs, labels, parallel_context: ParallelContext, parallel_mode: ParallelMode
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    local_rank = parallel_context.get_local_rank(parallel_mode)
    world_size = parallel_context.get_world_size(parallel_mode)

    input_chunks = torch.chunk(inputs["input_ids"], chunks=world_size, dim=0)
    attention_chunks = torch.chunk(inputs["attention_mask"], chunks=world_size, dim=0)
    label_chunks = torch.chunk(labels, chunks=world_size, dim=0)

    return input_chunks[local_rank], attention_chunks[local_rank], label_chunks[local_rank]
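
For reference, the new helper shards the batch along dim 0 with torch.chunk and keeps the chunk at this rank's index. A minimal standalone sketch of that logic, with the rank and world size hard-coded as stand-ins for what ParallelContext would return at runtime:

import torch

world_size = 2  # stand-in for parallel_context.get_world_size(parallel_mode)
local_rank = 1  # stand-in for parallel_context.get_local_rank(parallel_mode)

input_ids = torch.arange(8).reshape(4, 2)  # a toy batch of 4 sequences
labels = torch.arange(4)

# Same chunking as get_microbatch: split into world_size shards along the
# batch dimension, then keep only this rank's shard.
input_chunks = torch.chunk(input_ids, chunks=world_size, dim=0)
label_chunks = torch.chunk(labels, chunks=world_size, dim=0)
print(input_chunks[local_rank])  # rank 1 gets rows 2-3: tensor([[4, 5], [6, 7]])
print(label_chunks[local_rank])  # tensor([2, 3])
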
11 changes: 3 additions & 8 deletions tests/nn/data_parallel/test_data_parallel.py
@@ -9,6 +9,7 @@
from pipegoose.nn import DataParallel
from pipegoose.testing.utils import (
    calculate_parameter_similarity,
    get_microbatch,
    init_parallel_context,
    skip_if_no_cuda,
    spawn,
@@ -89,13 +90,6 @@ def test_parallelize_a_transformer_and_inference(model, tokenizer, data_parallel
def run_backward_a_parallelized_transformers(
    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, kwargs
):
    def get_microbatch(inputs, labels):
        local_rank = parallel_context.get_local_rank(ParallelMode.DATA)
        input_chunks = torch.chunk(inputs["input_ids"], chunks=world_size, dim=0)
        attention_chunks = torch.chunk(inputs["attention_mask"], chunks=world_size, dim=0)
        label_chunks = torch.chunk(labels, chunks=world_size, dim=0)
        return input_chunks[local_rank], attention_chunks[local_rank], label_chunks[local_rank]

    model = deepcopy(kwargs["model"])
    UPDATED_MODEL = deepcopy(kwargs["updated_model"])
    LR = kwargs["lr"]
@@ -106,7 +100,8 @@ def get_microbatch(inputs, labels):
        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
    )

    input_ids, attention_mask, labels = get_microbatch(inputs, labels)
    # NOTE: each model replica only trains on a subset of the data
    input_ids, attention_mask, labels = get_microbatch(inputs, labels, parallel_context, ParallelMode.DATA)
    parallelized_model = DataParallel(model, parallel_context).parallelize()
    optim = SGD(parallelized_model.parameters(), lr=LR)

23 changes: 12 additions & 11 deletions tests/nn/expert_parallel/test_hybrid_expert_parallel.py
@@ -13,7 +13,7 @@
from pipegoose.nn.data_parallel.data_parallel import DataParallel
from pipegoose.nn.expert_parallel.loss import ExpertLoss
from pipegoose.nn.expert_parallel.routers import SwitchNoisePolicy, Top1Router
from pipegoose.testing.utils import init_parallel_context, spawn
from pipegoose.testing.utils import get_microbatch, init_parallel_context, spawn

MODEL_NAME = "bigscience/bloom-560m"

@@ -57,15 +57,20 @@ def run_expert_parallel_with_data_parallel(
        pipeline_parallel_size,
        data_parallel_size,
    )
    # NOTE: each model replica only trains on a subset of the data
    input_ids, attention_mask, labels = get_microbatch(
        kwargs["input"], kwargs["labels"], parallel_context, ParallelMode.EXPERT
    )
    loss_func = ExpertLoss(nn.CrossEntropyLoss())

    model = ExpertParallel(model, NUM_EXPERTS, mapping=mapping, router=router, parallel_context=parallel_context).parallelize()
    model = DataParallel(model, parallel_context).parallelize()
    loss_func = ExpertLoss(nn.CrossEntropyLoss())
    optim = Adam(model.parameters(), lr=1e-3)

    outputs = model(**kwargs["input"])
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    logits = outputs.logits[..., :-1, :].view(-1, outputs.logits.shape[-1])
    labels = kwargs["labels"][..., 1:].view(-1).to(logits.device)
    labels = labels[..., 1:].view(-1).to(logits.device)
    loss = loss_func(logits, labels)

    optim.zero_grad()
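
The logits/labels slicing above is the standard next-token shift for causal language models: the logits at position t are scored against the token at position t + 1. A toy sketch of the alignment, with made-up shapes:

import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 5, 7
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Drop the last position's logits and the first token of the labels so that
# each remaining logit row is paired with the token that follows it.
shift_logits = logits[..., :-1, :].reshape(-1, vocab)  # (batch * (seq_len - 1), vocab)
shift_labels = labels[..., 1:].reshape(-1)             # (batch * (seq_len - 1),)
loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)
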
@@ -76,11 +81,7 @@
    expert_grads = torch.chunk(expert_grads, chunks=data_parallel_size, dim=0)

    # NOTE: check if expert grads are the same across the data parallel dimension
    assert all(
        torch.all(torch.eq(expert_grads[i], expert_grads[j]))
        for i in range(len(expert_grads))
        for j in range(i + 1, len(expert_grads))
    )
    assert torch.allclose(*expert_grads)
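
Note that torch.allclose(*expert_grads) star-unpacks the tuple of chunks, so it type-checks only when there are exactly two of them, i.e. data_parallel_size == 2. A hedged sketch of an equivalent check for an arbitrary number of replicas, comparing every chunk against the first:

    # Hypothetical generalization for data_parallel_size > 2
    assert all(torch.allclose(expert_grads[0], grads) for grads in expert_grads[1:])
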

    optim.step()

@@ -100,8 +101,8 @@ def test_expert_parallel_with_data_parallel(model, tokenizer):
    noise_policy = SwitchNoisePolicy()
    router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL)

    text = "Persistence is all you need."
    input = tokenizer(text, return_tensors="pt")
    text = ["Persistence is all you need.", "Attention is all you need."]
    input = tokenizer(text, return_tensors="pt", padding=True)

    kwargs = {
        "input": input,
