add router loss to DynMoLE
mikecovlee committed Sep 12, 2024
1 parent defd4ec commit 103eea7
Showing 3 changed files with 146 additions and 11 deletions.
moe_peft/model.py: 19 additions & 6 deletions
@@ -442,13 +442,26 @@ def _call_decoder_stack(
                 cache_position,
                 past_key_value,
             )
-            if len(router_logits) == 0:
-                continue
             # collecting router logits
-            assert len(router_logits) == num_adapters
-            for idx in range(num_adapters):
-                if router_logits[idx] is not None:
-                    all_router_logits[idx].append(router_logits[idx])
+            if len(router_logits) == 0:
+                attn_proj, mlp_proj = decoder_layer.state_dict()
+                all_proj = copy.copy(attn_proj)
+                all_proj.update(mlp_proj)
+                for idx in range(num_adapters):
+                    adapter_name = input_args.batch_configs_[idx].adapter_name_
+                    for proj in all_proj.values():
+                        if adapter_name in proj.moes_ and hasattr(
+                            proj.moes_[adapter_name], "router_logits_"
+                        ):
+                            all_router_logits[idx].append(
+                                proj.moes_[adapter_name].router_logits_
+                            )
+                            proj.moes_[adapter_name].router_logits_ = None
+            else:
+                assert len(router_logits) == num_adapters
+                for idx in range(num_adapters):
+                    if router_logits[idx] is not None:
+                        all_router_logits[idx].append(router_logits[idx])

         hidden_states = self.model_.norm(hidden_states)

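The new branch above assumes each MoE block caches its raw gate output on a `router_logits_` attribute so the decoder stack can harvest it afterwards. A minimal sketch of that caching pattern, using a hypothetical `ToyMoeBlock` that is not part of the repository:

```python
import torch


class ToyMoeBlock(torch.nn.Module):
    """Illustrative stand-in for an MoE block that exposes its gate logits."""

    def __init__(self, hidden_size: int, num_experts: int) -> None:
        super().__init__()
        self.gate_ = torch.nn.Linear(hidden_size, num_experts, bias=False)
        self.router_logits_ = None  # filled on forward, drained by the caller

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # cache the raw gate output so a later pass can build the router loss
        self.router_logits_ = self.gate_(hidden_states)
        return hidden_states  # expert mixing omitted in this sketch


block = ToyMoeBlock(hidden_size=16, num_experts=4)
_ = block(torch.randn(8, 16))
collected = block.router_logits_  # shape (8, 4), appended to all_router_logits
block.router_logits_ = None  # cleared after collection, as in the diff above
```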
moe_peft/modules/config.py: 19 additions & 0 deletions
@@ -432,6 +432,9 @@ class DynMoleConfig(LoraConfig):
     num_experts_: int = None
     router_init_range_: float = None
     routing_strategy_: str = "dynmole"
+    router_aux_loss_coef_: float = None
+    router_dyn_loss_coef_: float = None
+    router_loss_: bool = True

     def check(self) -> "DynMoleConfig":
         super().check()
@@ -445,6 +448,15 @@ def check(self) -> "DynMoleConfig":
         assert (
             isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
         )
+        assert (
+            isinstance(self.router_aux_loss_coef_, float)
+            and self.router_aux_loss_coef_ >= 0
+        )
+        assert (
+            isinstance(self.router_dyn_loss_coef_, float)
+            and self.router_dyn_loss_coef_ >= 0
+        )
+        assert isinstance(self.router_loss_, bool)

         return self

@@ -456,6 +468,13 @@ def from_config(config: Dict[str, any]) -> "DynMoleConfig":
             top_p_=config.get("top_p", 0.75),
             num_experts_=config["num_experts"],
             router_init_range_=config.get("router_init_range", 5.0),
+            router_aux_loss_coef_=config.get(
+                "router_aux_loss_coef", 0.001
+            ),  # for training
+            router_dyn_loss_coef_=config.get(
+                "router_dyn_loss_coef", 0.001
+            ),  # for training
+            router_loss_=config.get("router_loss", True),
             **LoraConfig.from_config(config).__dict__,
         )

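For reference, the two new coefficients and the on/off switch would be supplied through the adapter's dict-style config consumed by `from_config`. A hedged example using the defaults shown above; the example values, the `routing_strategy` selector key, and the omitted LoRA keys are assumptions, not taken verbatim from the repository:

```python
# Only keys visible in this diff are shown; the usual LoRA keys are omitted.
dynmole_adapter_config = {
    "routing_strategy": "dynmole",   # assumed selector for the DynMole config/loss
    "num_experts": 8,                # required (no default in from_config)
    "top_p": 0.75,                   # default shown in from_config
    "router_init_range": 5.0,        # default shown in from_config
    # new in this commit:
    "router_aux_loss_coef": 0.001,   # weight of the load-balancing term
    "router_dyn_loss_coef": 0.001,   # weight of the entropy (dynamic routing) term
    "router_loss": True,             # set False to disable the router loss
}
# config = DynMoleConfig.from_config(dynmole_adapter_config)  # plus the usual LoRA keys
```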
moe_peft/modules/lora_moes.py: 108 additions & 5 deletions
@@ -169,9 +169,9 @@ def _dynamic_routing(
     # calculate router entropy
     router_entropy = _entropy(router_logits, -1, eps)
     # broadcast if higher than threshhold
-    broadcast_index, _ = torch.where(router_entropy >= broadcast_threshhold)
+    broadcast_index = torch.where(router_entropy >= broadcast_threshhold)[0]
     # calculate top-p routing
-    sorted_logits, _ = torch.sort(router_logits, dim=-1, descending=True)
+    sorted_logits = torch.sort(router_logits, dim=-1, descending=True)[0]
     cumulative_probs = sorted_logits.cumsum(dim=-1)
     expert_mask = cumulative_probs > top_p
     # maintain top-1 if no experts selected
@@ -182,10 +182,111 @@ def _dynamic_routing(
     # calculate final mask
     expert_mask = (expert_mask & ~threshold_mask).index_fill(0, broadcast_index, False)
     sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
-    # sorted_indices = sorted_indices.masked_fill(expert_mask, -1)
     return router_entropy, sorted_logits


+def _dynamic_load_balancing_loss_func(
+    routing_logits: torch.Tensor,
+    broadcast_threshhold: float,
+    top_p: float,
+    eps: float = 1e-5,
+    aux_loss_coef: float = 0.001,
+    dyn_loss_coef: float = 0.001,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> float:
+    routing_weights = F.softmax(routing_logits, dim=-1)
+
+    router_entropy, routing_weights = _dynamic_routing(
+        routing_weights, broadcast_threshhold, top_p, eps
+    )
+
+    entropy_loss = torch.mean(router_entropy, dim=0)
+
+    expert_mask = (routing_weights > 0.0).long()
+    num_experts = routing_weights.size(-1)
+
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape as expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand(
+                (
+                    num_hidden_layers,
+                    batch_size,
+                    sequence_length,
+                    num_experts,
+                )
+            )
+            .reshape(-1, num_experts)
+            .to(routing_weights.device)
+        )
+
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape as tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(routing_weights.device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
+
+    return entropy_loss * dyn_loss_coef + overall_loss * num_experts * aux_loss_coef
+
+
+def _unpack_router_logits(gate_logits: List[torch.Tensor]):
+    compute_device = gate_logits[0].device
+    concatenated_gate_logits = torch.cat(
+        [
+            layer_gate.to(compute_device).reshape(-1, layer_gate.size(-1))
+            for layer_gate in gate_logits
+        ],
+        dim=0,
+    )
+    return concatenated_gate_logits
+
+
+class DynMoleRouterLoss(torch.nn.Module):
+    def __init__(self, config: DynMoleConfig) -> None:
+        super().__init__()
+        self.aux_loss_coef = config.router_aux_loss_coef_
+        self.dyn_loss_coef = config.router_dyn_loss_coef_
+        self.broadcast_threshhold: float = config.broadcast_threshhold_
+        self.top_p: float = config.top_p_
+        self.eps: float = config.entropy_eps_
+
+    def forward(self, gate_logits, attention_mask) -> torch.Tensor:
+        concatenated_gate_logits = _unpack_router_logits(gate_logits)
+        return _dynamic_load_balancing_loss_func(
+            concatenated_gate_logits,
+            self.broadcast_threshhold,
+            self.top_p,
+            self.eps,
+            self.aux_loss_coef,
+            self.dyn_loss_coef,
+            attention_mask,
+        )
+
+
 class DynMole(LLMMoeBlock):
     def __init__(
         self,
@@ -209,6 +310,7 @@ def __init__(
         self.top_p_: float = config.top_p_
         self.eps_: float = config.entropy_eps_
         self.experts_: int = config.num_experts_
+        self.router_logits_: torch.Tensor = None
         self.router_profile_: bool = False
         self.profiler_: List[int] = None

@@ -227,8 +329,8 @@ def forward(
         lora_linear: Optional[Linear] = None,
     ) -> Tuple:
         assert lora_linear is not None
-        router_logits = self.gate_(hidden_states.to(self.dtype_))
-        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
+        self.router_logits_ = self.gate_(hidden_states.to(self.dtype_))
+        routing_weights = F.softmax(self.router_logits_, dim=-1, dtype=torch.float32)
         router_entropy, routing_weights = _dynamic_routing(
             routing_weights, self.broadcast_threshhold_, self.top_p_, self.eps_
         )
@@ -255,6 +357,7 @@ def forward(
     "mixlora": MixtralRouterLoss,
     "mixlora-dynamic": DynamicRouterLoss,
    "mixlora-switch": SwitchRouterLoss,
+    "dynmole": DynMoleRouterLoss,
 }

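Putting the pieces together: assuming the collected gate logits are flattened to a (tokens, num_experts) matrix by `_unpack_router_logits`, the returned router loss combines the mean routing entropy with a Mixtral-style load-balancing term, roughly

$$
\mathcal{L}_{\text{router}} = \lambda_{\text{dyn}}\,\bar{H} \;+\; \lambda_{\text{aux}}\, N \sum_{i=1}^{N} f_i\, P_i
$$

where $\bar{H}$ is the mean per-token entropy of the routing weights, $f_i$ the fraction of tokens assigned to slot $i$ after DynMoLE's top-p/broadcast masking, $P_i$ the mean routing probability for that slot, $N$ the number of experts, and $\lambda_{\text{dyn}}$, $\lambda_{\text{aux}}$ the two new coefficients (both defaulting to 0.001).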

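A rough sketch of how training code might apply the new loss, assuming a `DynMoleRouterLoss` has been built from a `DynMoleConfig` elsewhere; the shapes and the `router_loss_fn` name here are illustrative, not taken from the repository:

```python
import torch

# gate logits collected per decoder layer, as in the model.py change above:
# each entry has shape (batch * seq_len, num_experts)
num_layers, tokens, num_experts = 2, 8, 4
all_router_logits = [torch.randn(tokens, num_experts) for _ in range(num_layers)]
attention_mask = None  # or a (batch, seq_len) mask so padding tokens are ignored

# with router_loss_fn = DynMoleRouterLoss(dynmole_config):
#     router_loss = router_loss_fn(all_router_logits, attention_mask)
#     loss = task_loss + router_loss  # only when config.router_loss_ is True
```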