Merge pull request #51 from xrsrke/feature/moe
[BUG] Fix the bug where tokens can't be dispatched when the input has…
xrsrke authored Nov 30, 2023
2 parents 93dfb32 + b0fe0a6 commit bbc46b2
Showing 5 changed files with 217 additions and 24 deletions.
9 changes: 8 additions & 1 deletion pipegoose/nn/expert_parallel/experts.py
@@ -51,7 +51,14 @@ def forward(
# how do we detect this and pass the corresponding arguments to the expert?
# For example, hidden_states.shape = (batch_size, seq_len, hidden_size),
# but we need to dispatch the hidden_states to the corresponding expert
expert_output = expert(dispatched_inputs, *args[1][:, indices], **kwargs)

# NOTE: args[0] is the input embeddings
# args[1] is the hidden_states, so we pass the input embeddings along
# with the hidden_states to the expert
selected_embeddings = rearrange(args[1], "batch_size seq_len d_dim -> (batch_size seq_len) d_dim")[indices]
# selected_embeddings = rearrange(selected_embeddings, "(batch_size seq_len) d_dim -> batch_size seq_len d_dim", batch_size=inputs.shape[0])

expert_output = expert(dispatched_inputs, selected_embeddings, **kwargs)
else:
expert_output = expert(dispatched_inputs)

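For readers following the dispatch fix above: a minimal, self-contained sketch of the gather the new code performs (shapes and values here are illustrative, not the library's). Tokens are flattened across the batch and sequence dimensions and then selected by their flat indices, so the extra argument passed to each expert stays aligned with dispatched_inputs for any batch size.

# Illustrative sketch of the token gather above; shapes and indices are made up.
import torch
from einops import rearrange

batch_size, seq_len, d_model = 2, 3, 4
hidden_states = torch.randn(batch_size, seq_len, d_model)

# flat indices (over batch_size * seq_len) of the tokens routed to one expert
indices = torch.tensor([0, 2, 5])

flat = rearrange(hidden_states, "batch_size seq_len d_dim -> (batch_size seq_len) d_dim")
selected_embeddings = flat[indices]  # (num_routed_tokens, d_model)
assert selected_embeddings.shape == (3, d_model)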
2 changes: 1 addition & 1 deletion pipegoose/nn/expert_parallel/loss.py
@@ -6,7 +6,7 @@


class ExpertLoss:
def __init__(self, loss_func: Callable, aux_weight: float, z_weight: float):
def __init__(self, loss_func: Callable, aux_weight: float = 0.01, z_weight: float = 0.1):
self.loss_func = loss_func
self.aux_weight = aux_weight
self.z_weight = z_weight
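The hunk above only adds constructor defaults; how ExpertLoss combines the terms is not part of this diff. A hedged sketch of the Switch-style combination these weights suggest, i.e. task loss plus weighted auxiliary (load-balancing) loss and router z-loss:

# Sketch only; the real ExpertLoss implementation is not shown in this hunk.
def combined_loss(loss_func, logits, labels, aux_loss, z_loss,
                  aux_weight=0.01, z_weight=0.1):
    # total = task loss + weighted router losses
    return loss_func(logits, labels) + aux_weight * aux_loss + z_weight * z_loss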
23 changes: 7 additions & 16 deletions pipegoose/nn/expert_parallel/routers.py
@@ -1,12 +1,12 @@
import math
from abc import ABC, abstractmethod
from typing import Tuple, Optional
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torchtyping import TensorType
from dataclasses import dataclass


class RouterExplorationNoisePolicy(ABC):
@@ -21,7 +21,8 @@ class SwitchNoisePolicy(RouterExplorationNoisePolicy):
Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
by Fedus et al.
"""
def __init__(self, eps: float=0.1):

def __init__(self, eps: float = 0.1):
super().__init__()
self.eps = eps
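The body of the noise policy is not shown in this diff. As a reference point, the Switch Transformer paper describes multiplicative jitter: scale the router inputs by uniform noise in [1 - eps, 1 + eps] during training. A sketch under that assumption (not the repository's exact code):

import torch

def switch_jitter(inputs: torch.Tensor, eps: float = 0.1) -> torch.Tensor:
    # multiplicative exploration noise, applied only during training
    noise = torch.empty_like(inputs).uniform_(1.0 - eps, 1.0 + eps)
    return inputs * noise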

@@ -100,9 +101,7 @@ def _expert_capacity(self, total_tokens: int) -> int:
expert_capacity = math.ceil((total_tokens / self.num_experts) * c)
return expert_capacity

def forward(
self, inputs: TensorType["batch_size", "seq_len", "d_model"]
) -> RouterOutput:
def forward(self, inputs: TensorType["batch_size", "seq_len", "d_model"]) -> RouterOutput:
orig_dtype = inputs.dtype
total_tokens = inputs.shape[0] * inputs.shape[1]
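A quick worked example of the capacity formula shown above, assuming c plays the role of a capacity factor (the numbers are illustrative):

import math

total_tokens = 4 * 1024          # batch_size * seq_len
num_experts = 4
c = 1.25                         # assumed capacity factor
expert_capacity = math.ceil((total_tokens / num_experts) * c)
assert expert_capacity == 1280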

@@ -129,12 +128,7 @@ def forward(
# we don't limit the capacity of the experts
topk_weight = router_prob * topk_expert_mask
topk_weight = topk_weight.to(orig_dtype)
return RouterOutput(
dispatching_order=topk_expert_mask,
weight=topk_weight,
aux_loss=aux_loss,
z_loss=z_loss
)
return RouterOutput(dispatching_order=topk_expert_mask, weight=topk_weight, aux_loss=aux_loss, z_loss=z_loss)

# limit the number of tokens per expert
position_in_expert = torch.cumsum(topk_expert_mask, dim=0) * topk_expert_mask
@@ -149,10 +143,7 @@ def forward(
topk_weight = topk_weight.to(orig_dtype)

return RouterOutput(
dispatching_order=capacity_limited_topk_expert_mask,
weight=topk_weight,
aux_loss=aux_loss,
z_loss=z_loss
dispatching_order=capacity_limited_topk_expert_mask, weight=topk_weight, aux_loss=aux_loss, z_loss=z_loss
)


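For context on the capacity-limiting branch that the reformatted returns above belong to: a minimal sketch of how a cumulative count per expert can drop tokens beyond the capacity (shapes and names are illustrative, not the library's exact code):

import torch

total_tokens, num_experts, expert_capacity = 8, 2, 3
# one-hot top-1 assignment of each flattened token to an expert
expert_mask = torch.nn.functional.one_hot(
    torch.randint(num_experts, (total_tokens,)), num_experts
)

# running 1-based position of each token within its chosen expert
position_in_expert = torch.cumsum(expert_mask, dim=0) * expert_mask
# keep only the first `expert_capacity` tokens routed to each expert
capacity_limited_mask = expert_mask * (position_in_expert <= expert_capacity).long()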
194 changes: 194 additions & 0 deletions tests/convergence/run_ep.py
@@ -0,0 +1,194 @@
from copy import deepcopy

import torch
import torch.distributed as dist
from datasets import load_dataset
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn import ExpertParallel
from pipegoose.nn.expert_parallel import SwitchNoisePolicy, Top1Router


def get_model_params_size(model, fp_bytes=4):
params_size = 0
for p in model.parameters():
params_size += p.numel()
params_gb = params_size * fp_bytes / 2**30
return params_gb


def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


if __name__ == "__main__":
import wandb

DATA_PARALLEL_SIZE = 1
TENSOR_PARALLEL_SIZE = 2
PIPELINE_PARALLEL_SIZE = 1
MODEL = "bigscience/bloom-560m"
DATASET = "imdb"
NUM_EPOCHS = 2000
LR = 1e-3
SEED = 69
BATCH_SIZE = 4
CONTEXT_LENGTH = 1024

NUM_EXPERTS = 4

torch.cuda.empty_cache()
set_seed(SEED)

print(f"device_count: {torch.cuda.device_count()}")
print(f"is available: {torch.cuda.is_available()}")

parallel_context = ParallelContext.from_torch(
data_parallel_size=DATA_PARALLEL_SIZE,
tensor_parallel_size=TENSOR_PARALLEL_SIZE,
pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
)
rank = parallel_context.get_global_rank()

print(f"rank={rank}, initialized parallel_context")

train_dataset = load_dataset("imdb", split="train[:130]")
train_dataset = train_dataset.map(lambda x: {"text": x["text"][:10]}) # for demonstration purposes

dp_rank = parallel_context.get_local_rank(ParallelMode.DATA)
train_sampler = DistributedSampler(train_dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=SEED)
train_dataloader = DataLoader(
train_dataset, batch_size=BATCH_SIZE // DATA_PARALLEL_SIZE, shuffle=False, sampler=train_sampler
)

val_dataset = load_dataset("imdb", split="test[:130]")
val_dataset = val_dataset.map(lambda x: {"text": x["text"][:10]}) # for demonstration purposes
val_sampler = DistributedSampler(val_dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=SEED)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE // DATA_PARALLEL_SIZE, shuffle=False, sampler=val_sampler)

model = AutoModelForCausalLM.from_pretrained(MODEL)
# config = BloomConfig(n_layer=4)
# model = BloomForCausalLM(config)
ref_model = deepcopy(model)
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

print(f"rank={rank}, model size before parallelizing: {round(get_model_params_size(model), 3)} GB")

dist.barrier()

# model = TensorParallel(model, parallel_context).parallelize()
# model = DataParallel(model, parallel_context).parallelize()
D_MODEL = model.config.hidden_size
noise_policy = SwitchNoisePolicy()
router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL)

model = ExpertParallel(
model, num_experts=NUM_EXPERTS, mapping=[0, 2], router=router, parallel_context=parallel_context
).parallelize()

optim = SGD(model.parameters(), lr=LR)
# optim = DistributedOptimizer(optim, parallel_context)
model.to("cuda")
device = next(model.parameters()).device

print(f"rank={rank}, model size after parallelizing: {round(get_model_params_size(model), 3)} GB")
print(f"rank={rank}, model is moved to device: {device}")

ref_model.to(device)
# if DATA_PARALLEL_SIZE > 1:
# ref_model = torch.nn.parallel.DistributedDataParallel(ref_model, device_ids=[device])

ref_optim = SGD(ref_model.parameters(), lr=LR)

model.train()
ref_model.train()
step = 0
dist.barrier()

if rank == 0:

def get_time_name():
import datetime

today = datetime.datetime.now()
return today.strftime("%d/%m/%Y_%H:%M:%S")

wandb.init(
project="pipegoose",
name=f"{get_time_name()}.test_ep",
config={
"data_parallel_size": DATA_PARALLEL_SIZE,
"tensor_parallel_size": TENSOR_PARALLEL_SIZE,
"pipeline_parallel_size": PIPELINE_PARALLEL_SIZE,
"model": MODEL,
"dataset": DATASET,
"epochs": NUM_EPOCHS,
"learning_rate": LR,
"seed": SEED,
"batch_size": BATCH_SIZE,
"num_experts": NUM_EXPERTS,
},
)

for epoch in range(NUM_EPOCHS):
train_sampler.set_epoch(epoch)
print(f"rank={rank}, epoch={epoch}")

for batch in train_dataloader:
inputs = tokenizer(batch["text"][0], padding=True, truncation=True, max_length=CONTEXT_LENGTH, return_tensors="pt")
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
labels = inputs["input_ids"]

outputs = model(**inputs, labels=labels)
ref_outputs = ref_model(**inputs, labels=labels)

optim.zero_grad()
outputs.loss.backward()
optim.step()

ref_optim.zero_grad()
ref_outputs.loss.backward()
ref_optim.step()

print(f"epoch={epoch}, step={step}, rank={rank}, train_loss={outputs.loss}, ref_train_loss={ref_outputs.loss}")

if rank == 0:
wandb.log({"train_loss": outputs.loss, "ref_train_loss": ref_outputs.loss, "step": step, "epoch": epoch})

step += 1

model.eval()
ref_model.eval()
dist.barrier()

step = 0
val_sampler.set_epoch(1)

for batch in val_dataloader:
inputs = tokenizer(batch["text"][0], padding=True, truncation=True, max_length=CONTEXT_LENGTH, return_tensors="pt")
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
labels = inputs["input_ids"]

outputs = model(**inputs, labels=labels)
ref_outputs = ref_model(**inputs, labels=labels)

print(f"rank={rank}, val_loss={outputs.loss}, ref_val_loss={ref_outputs.loss}, step={step}")

if rank == 0:
wandb.log({"val_loss": outputs.loss, "ref_val_loss": ref_outputs.loss, "step": step})

step += 1

wandb.finish()
model.cpu()
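Note: the script builds its ParallelContext via ParallelContext.from_torch, which presumably reads the standard torch.distributed environment variables, so with the sizes configured above (tensor_parallel_size = 2) it would likely be launched as torchrun --nproc_per_node=2 tests/convergence/run_ep.py; wandb logging happens only on rank 0.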
13 changes: 7 additions & 6 deletions tests/nn/expert_parallel/test_expert_parallel.py
@@ -128,12 +128,12 @@ def log_routed_expert(module, grad_input, grad_output, key):
outputs = model(**kwargs["input"])

assert all(key in outputs for key in ["logits", "past_key_values"])
# NOTE: why so high tolerance?
assert torch.allclose(outputs.logits, REF_LOGITS, rtol=1e-1)
# TODO: fail at tp_size=2, expert_size=4
assert torch.allclose(outputs.logits, REF_LOGITS)

# compute the loss
logits = outputs.logits[..., :-1, :].view(-1, outputs.logits.shape[-1])
labels = kwargs["labels"][..., 1:].view(-1).to(logits.device)
logits = outputs.logits[..., :-1, :].contiguous().view(-1, outputs.logits.shape[-1])
labels = kwargs["labels"].view(-1).to(logits.device)
loss = loss_func(logits, labels)

assert torch.allclose(loss, REF_LOSS)
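The change above applies the standard causal-LM shift: the logits at position t are scored against the token at position t + 1. A small self-contained sketch of that shift (shapes are illustrative):

import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(vocab_size, (batch_size, seq_len))

# drop the last position's logits and the first token of the labels
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = input_ids[..., 1:].contiguous().view(-1)
loss = F.cross_entropy(shift_logits, shift_labels)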
@@ -160,19 +160,20 @@ def test_expert_parallel(model, tokenizer, tensor_parallel_size, num_experts):
DATA_PARALLEL_SIZE = 1
WORLD_SIZE = tensor_parallel_size * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE

BATCH_SIZE = 4
NUM_LAYERS = model.config.num_hidden_layers
NUM_EXPERT_LAYERS = 2

mapping = [layer_idx for layer_idx in random.sample(range(NUM_LAYERS - 1), NUM_EXPERT_LAYERS)]
router = DummyRouter(num_experts)

text = "Persistence is all you need."
text = ["Persistence is all you need." for _ in range(BATCH_SIZE)]
input = tokenizer(text, return_tensors="pt")
outputs = model(**input, labels=input["input_ids"])

kwargs = {
"input": input,
"labels": input["input_ids"],
"labels": input["input_ids"][..., 1:].contiguous(),
"model": model,
"mapping": mapping,
"num_experts": num_experts,
