diff --git a/pipegoose/testing/utils.py b/pipegoose/testing/utils.py
index 7542497..097e777 100644
--- a/pipegoose/testing/utils.py
+++ b/pipegoose/testing/utils.py
@@ -2,7 +2,7 @@
 import random
 import socket
 from functools import partial
-from typing import Callable
+from typing import Callable, Tuple
 
 import pytest
 import torch
@@ -118,3 +118,16 @@ def calculate_parameter_similarity(module1: nn.Module, module2: nn.Module, rtol:
 def count_model_parameters(model):
     return sum(p.numel() for p in model.parameters())
+
+
+def get_microbatch(
+    inputs, labels, parallel_context: ParallelContext, parallel_mode: ParallelMode
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    local_rank = parallel_context.get_local_rank(parallel_mode)
+    world_size = parallel_context.get_world_size(parallel_mode)
+
+    input_chunks = torch.chunk(inputs["input_ids"], chunks=world_size, dim=0)
+    attention_chunks = torch.chunk(inputs["attention_mask"], chunks=world_size, dim=0)
+    label_chunks = torch.chunk(labels, chunks=world_size, dim=0)
+
+    return input_chunks[local_rank], attention_chunks[local_rank], label_chunks[local_rank]
diff --git a/tests/nn/data_parallel/test_data_parallel.py b/tests/nn/data_parallel/test_data_parallel.py
index 11bf58c..e296379 100644
--- a/tests/nn/data_parallel/test_data_parallel.py
+++ b/tests/nn/data_parallel/test_data_parallel.py
@@ -9,6 +9,7 @@
 from pipegoose.nn import DataParallel
 from pipegoose.testing.utils import (
     calculate_parameter_similarity,
+    get_microbatch,
     init_parallel_context,
     skip_if_no_cuda,
     spawn,
@@ -89,13 +90,6 @@ def test_parallelize_a_transformer_and_inference(model, tokenizer, data_parallel
 def run_backward_a_parallelized_transformers(
     rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, kwargs
 ):
-    def get_microbatch(inputs, labels):
-        local_rank = parallel_context.get_local_rank(ParallelMode.DATA)
-        input_chunks = torch.chunk(inputs["input_ids"], chunks=world_size, dim=0)
-        attention_chunks = torch.chunk(inputs["attention_mask"], chunks=world_size, dim=0)
-        label_chunks = torch.chunk(labels, chunks=world_size, dim=0)
-        return input_chunks[local_rank], attention_chunks[local_rank], label_chunks[local_rank]
-
     model = deepcopy(kwargs["model"])
     UPDATED_MODEL = deepcopy(kwargs["updated_model"])
     LR = kwargs["lr"]
@@ -106,7 +100,8 @@ def get_microbatch(inputs, labels):
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
 
-    input_ids, attention_mask, labels = get_microbatch(inputs, labels)
+    # NOTE: each model replica only trains on a subset of the data
+    input_ids, attention_mask, labels = get_microbatch(inputs, labels, parallel_context, ParallelMode.DATA)
 
     parallelized_model = DataParallel(model, parallel_context).parallelize()
     optim = SGD(parallelized_model.parameters(), lr=LR)
diff --git a/tests/nn/expert_parallel/test_hybrid_expert_parallel.py b/tests/nn/expert_parallel/test_hybrid_expert_parallel.py
index 045dce2..facc6bd 100644
--- a/tests/nn/expert_parallel/test_hybrid_expert_parallel.py
+++ b/tests/nn/expert_parallel/test_hybrid_expert_parallel.py
@@ -13,7 +13,7 @@
 from pipegoose.nn.data_parallel.data_parallel import DataParallel
 from pipegoose.nn.expert_parallel.loss import ExpertLoss
 from pipegoose.nn.expert_parallel.routers import SwitchNoisePolicy, Top1Router
-from pipegoose.testing.utils import init_parallel_context, spawn
+from pipegoose.testing.utils import get_microbatch, init_parallel_context, spawn
 
 MODEL_NAME = "bigscience/bloom-560m"
 
@@ -57,15 +57,20 @@ def run_expert_parallel_with_data_parallel(
         pipeline_parallel_size,
         data_parallel_size,
     )
+    # NOTE: each model replica only trains on a subset of the data
+    input_ids, attention_mask, labels = get_microbatch(
+        kwargs["input"], kwargs["labels"], parallel_context, ParallelMode.EXPERT
+    )
+    loss_func = ExpertLoss(nn.CrossEntropyLoss())
+
     model = ExpertParallel(model, NUM_EXPERTS, mapping=mapping, router=router, parallel_context=parallel_context).parallelize()
     model = DataParallel(model, parallel_context).parallelize()
 
-    loss_func = ExpertLoss(nn.CrossEntropyLoss())
     optim = Adam(model.parameters(), lr=1e-3)
 
-    outputs = model(**kwargs["input"])
+    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
 
     logits = outputs.logits[..., :-1, :].view(-1, outputs.logits.shape[-1])
-    labels = kwargs["labels"][..., 1:].view(-1).to(logits.device)
+    labels = labels[..., 1:].view(-1).to(logits.device)
     loss = loss_func(logits, labels)
 
     optim.zero_grad()
@@ -76,11 +81,7 @@ def run_expert_parallel_with_data_parallel(
     expert_grads = torch.chunk(expert_grads, chunks=data_parallel_size, dim=0)
 
     # NOTE: check if expert grads are the same across data parallel dimension
-    assert all(
-        torch.all(torch.eq(expert_grads[i], expert_grads[j]))
-        for i in range(len(expert_grads))
-        for j in range(i + 1, len(expert_grads))
-    )
+    assert torch.allclose(*expert_grads)
 
     optim.step()
@@ -100,8 +101,8 @@ def test_expert_parallel_with_data_parallel(model, tokenizer):
     noise_policy = SwitchNoisePolicy()
     router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL)
 
-    text = "Persistence is all you need."
-    input = tokenizer(text, return_tensors="pt")
+    text = ["Persistence is all you need.", "Attention is all you need."]
+    input = tokenizer(text, return_tensors="pt", padding=True)
 
     kwargs = {
         "input": input,
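
For context, here is a minimal, self-contained sketch of the sharding that the new `get_microbatch` helper performs. It hard-codes a hypothetical `world_size`/`local_rank` pair instead of querying a live `ParallelContext`, so it illustrates only the chunking logic, not pipegoose's distributed setup:

```python
import torch

# Hypothetical stand-ins for parallel_context.get_world_size(parallel_mode)
# and parallel_context.get_local_rank(parallel_mode).
world_size, local_rank = 2, 0  # pretend we are rank 0 of a 2-way group

# A toy global batch of 4 sequences, each 4 tokens long.
inputs = {
    "input_ids": torch.arange(16).reshape(4, 4),
    "attention_mask": torch.ones(4, 4, dtype=torch.long),
}
labels = torch.arange(16).reshape(4, 4)

# Same chunking as get_microbatch: split the batch dimension into
# world_size shards and keep only this rank's shard.
input_chunks = torch.chunk(inputs["input_ids"], chunks=world_size, dim=0)
attention_chunks = torch.chunk(inputs["attention_mask"], chunks=world_size, dim=0)
label_chunks = torch.chunk(labels, chunks=world_size, dim=0)

input_ids = input_chunks[local_rank]          # rows 0-1 for rank 0
attention_mask = attention_chunks[local_rank]
microbatch_labels = label_chunks[local_rank]

print(input_ids.shape)  # torch.Size([2, 4]), half of the global batch
```

This is why the hybrid test above needs a batch of two sentences: with each replica training on its own shard, the gradients only agree across the data-parallel dimension after the all-reduce that `DataParallel` performs, which is what the `torch.allclose(*expert_grads)` assertion checks for the two replicas.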