diff --git a/moe_peft/modules/lora_moes.py b/moe_peft/modules/lora_moes.py
index f105a49..e92550a 100644
--- a/moe_peft/modules/lora_moes.py
+++ b/moe_peft/modules/lora_moes.py
@@ -1,3 +1,4 @@
+import logging
 import math
 from typing import List, Optional, Tuple
 
@@ -14,6 +15,7 @@
     MixtralSparseMoe,
     SwitchRouterLoss,
     SwitchSparseMoe,
+    _entropy,
 )
 
 
@@ -58,6 +60,9 @@ def forward(
         route_weight = torch.nn.functional.softmax(
             self.gate_(hidden_states.to(self.dtype_)), dim=-1, dtype=torch.float32
         )
+        if self.router_profile_:
+            logging.info(f"entropy: {_entropy(route_weight)}")
+
         for expert_idx in range(self.experts_):
             expert_lora = lora_linear.loras_[
                 f"moe.{self.adapter_name_}.experts.{expert_idx}"
@@ -114,6 +119,9 @@ def forward(
         hidden_states = hidden_states.view(-1, hidden_dim).to(self.dtype_)
         router_logits = self.gate_(hidden_states)
         routing_weights_before = F.softmax(router_logits, dim=1, dtype=self.dtype_)
+        if self.router_profile_:
+            logging.info(f"entropy: {_entropy(routing_weights_before)}")
+
         routing_weights, selected_experts = torch.topk(
             routing_weights_before, self.topk_, dim=-1
         )
@@ -151,6 +159,7 @@ def forward(
         return residual + final_hidden_states
 
 
+@torch.jit.script
 def _dynamic_routing(
     router_logits: torch.Tensor,
     broadcast_threshhold: float,
@@ -158,8 +167,7 @@ def _dynamic_routing(
     eps: float = 1e-5,
 ):
     # calculate router entropy
-    probs_neg_log = -torch.log(router_logits + eps)  # eps for 'p=0, -plogp=0'
-    router_entropy = (router_logits * probs_neg_log).sum(dim=-1)
+    router_entropy = _entropy(router_logits, -1, eps)
     # broadcast if higher than threshhold
     broadcast_index, _ = torch.where(router_entropy >= broadcast_threshhold)
     # calculate top-p routing
@@ -170,7 +178,7 @@ def _dynamic_routing(
     threshold_indices = expert_mask.long().argmax(dim=-1)
     threshold_mask = torch.nn.functional.one_hot(
         threshold_indices, num_classes=router_logits.size(-1)
-    ).bool()
+    ).to(torch.bool)
     # calculate final mask
     expert_mask = (expert_mask & ~threshold_mask).index_fill(0, broadcast_index, False)
     sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
@@ -225,11 +233,11 @@ def forward(
             routing_weights, self.broadcast_threshhold_, self.top_p_, self.eps_
         )
         if self.router_profile_:
-            print(f"entropy: {router_entropy.mean()}")
+            logging.info(f"entropy: {router_entropy.mean()}")
             router_profile = (routing_weights > 0.0).long().sum(-1).float()
-            print(f"max activated: {router_profile.max()}")
-            print(f"min activated: {router_profile.min()}")
-            print(f"avg activated: {router_profile.mean()}")
+            logging.info(f"max activated: {router_profile.max()}")
+            logging.info(f"min activated: {router_profile.min()}")
+            logging.info(f"avg activated: {router_profile.mean()}")
         for expert_idx in range(self.experts_):
             expert_lora = lora_linear.loras_[
                 f"moe.{self.adapter_name_}.experts.{expert_idx}"
diff --git a/moe_peft/modules/mix_lora.py b/moe_peft/modules/mix_lora.py
index 46237f1..edfab38 100644
--- a/moe_peft/modules/mix_lora.py
+++ b/moe_peft/modules/mix_lora.py
@@ -8,6 +8,16 @@
 from .config import MixLoraConfig
 
 
+@torch.jit.script
+def _entropy(
+    logits: torch.Tensor,
+    dim: int = -1,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    probs_neg_log = -torch.log(logits + eps)  # eps for 'p=0, -plogp=0'
+    return (logits * probs_neg_log).sum(dim=dim)
+
+
 def _slice_tensor(
     data: torch.Tensor,
     slice: torch.Tensor,
@@ -263,6 +273,7 @@ def forward(
         return final_hidden_states, router_logits
 
 
+@torch.jit.script
 def _dynamic_routing(
     router_logits: torch.Tensor,
     broadcast_threshhold: float = 2.0,
@@ -270,8 +281,7 @@ def _dynamic_routing(
     eps: float = 1e-5,
 ):
     # calculate router entropy
-    probs_neg_log = -torch.log(router_logits + eps)  # eps for 'p=0, -plogp=0'
-    router_entropy = (router_logits * probs_neg_log).sum(dim=-1)
+    router_entropy = _entropy(router_logits, -1, eps)
     # broadcast if higher than threshhold
     broadcast_index = torch.nonzero(router_entropy >= broadcast_threshhold).squeeze(-1)
     # calculate top-p routing
@@ -282,7 +292,7 @@ def _dynamic_routing(
     threshold_indices = expert_mask.long().argmax(dim=-1)
     threshold_mask = torch.nn.functional.one_hot(
         threshold_indices, num_classes=sorted_indices.size(-1)
-    ).bool()
+    ).to(torch.bool)
     # calculate final mask
     expert_mask = (expert_mask & ~threshold_mask).index_fill(0, broadcast_index, False)
     # sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
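
Sanity check (not part of the patch): a minimal sketch, reusing only the _entropy definition added above, confirming the refactored helper matches the inline entropy computation it replaces in _dynamic_routing.

import torch

@torch.jit.script
def _entropy(logits: torch.Tensor, dim: int = -1, eps: float = 1e-5) -> torch.Tensor:
    # eps keeps 'p=0, -plogp=0' well-defined, exactly as in the pre-refactor inline code
    probs_neg_log = -torch.log(logits + eps)
    return (logits * probs_neg_log).sum(dim=dim)

probs = torch.softmax(torch.randn(4, 8), dim=-1)
inline = (probs * -torch.log(probs + 1e-5)).sum(dim=-1)  # the computation the patch removes
assert torch.allclose(_entropy(probs), inline)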