improve efficiency
mikecovlee committed Sep 12, 2024
1 parent 3e71593 commit defd4ec
Showing 2 changed files with 28 additions and 10 deletions.
22 changes: 15 additions & 7 deletions moe_peft/modules/lora_moes.py
@@ -1,3 +1,4 @@
+import logging
 import math
 from typing import List, Optional, Tuple

@@ -14,6 +15,7 @@
     MixtralSparseMoe,
     SwitchRouterLoss,
     SwitchSparseMoe,
+    _entropy,
 )


@@ -58,6 +60,9 @@ def forward(
         route_weight = torch.nn.functional.softmax(
             self.gate_(hidden_states.to(self.dtype_)), dim=-1, dtype=torch.float32
         )
+        if self.router_profile_:
+            logging.info(f"entropy: {_entropy(route_weight)}")
+
         for expert_idx in range(self.experts_):
             expert_lora = lora_linear.loras_[
                 f"moe.{self.adapter_name_}.experts.{expert_idx}"
@@ -114,6 +119,9 @@ def forward(
         hidden_states = hidden_states.view(-1, hidden_dim).to(self.dtype_)
         router_logits = self.gate_(hidden_states)
         routing_weights_before = F.softmax(router_logits, dim=1, dtype=self.dtype_)
+        if self.router_profile_:
+            logging.info(f"entropy: {_entropy(routing_weights_before)}")
+
         routing_weights, selected_experts = torch.topk(
             routing_weights_before, self.topk_, dim=-1
         )
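For context: the entropy these `router_profile_` hooks log is the Shannon entropy of each token's routing distribution, computed by the `_entropy` helper added to mix_lora.py below; low values mean the router concentrates on few experts, high values mean it spreads mass across many. A tiny illustrative check (example values, not part of the commit):

    import torch

    def _entropy(logits: torch.Tensor, dim: int = -1, eps: float = 1e-5) -> torch.Tensor:
        # same computation as the helper below; eps keeps log finite so p=0 contributes ~0
        return (logits * -torch.log(logits + eps)).sum(dim=dim)

    peaked = torch.tensor([[0.97, 0.01, 0.01, 0.01]])  # router is confident
    uniform = torch.full((1, 4), 0.25)                 # router is undecided
    print(_entropy(peaked))   # ~0.17
    print(_entropy(uniform))  # ~1.39, i.e. log(4)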
@@ -151,15 +159,15 @@ def forward(
         return residual + final_hidden_states


+@torch.jit.script
 def _dynamic_routing(
     router_logits: torch.Tensor,
     broadcast_threshhold: float,
     top_p: float,
     eps: float = 1e-5,
 ):
     # calculate router entropy
-    probs_neg_log = -torch.log(router_logits + eps)  # eps for 'p=0, -plogp=0'
-    router_entropy = (router_logits * probs_neg_log).sum(dim=-1)
+    router_entropy = _entropy(router_logits, -1, eps)
     # broadcast if higher than threshhold
     broadcast_index, _ = torch.where(router_entropy >= broadcast_threshhold)
     # calculate top-p routing
@@ -170,7 +178,7 @@ def _dynamic_routing(
     threshold_indices = expert_mask.long().argmax(dim=-1)
     threshold_mask = torch.nn.functional.one_hot(
         threshold_indices, num_classes=router_logits.size(-1)
-    ).bool()
+    ).to(torch.bool)
     # calculate final mask
     expert_mask = (expert_mask & ~threshold_mask).index_fill(0, broadcast_index, False)
     sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
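Reading the top-p mask above (the lines computing `expert_mask` are collapsed; presumably a cumulative-sum test against `top_p` over the descending-sorted weights): `expert_mask` marks the tail of experts to drop, the `one_hot` of its `argmax` re-admits the single expert that first crosses the threshold, and `index_fill` clears the mask for the high-entropy rows in `broadcast_index` so they keep every expert. A minimal sketch of the masking step under that assumption (illustrative values only):

    import torch

    sorted_logits = torch.tensor([[0.6, 0.3, 0.1]])    # one token, pre-sorted weights
    expert_mask = sorted_logits.cumsum(dim=-1) > 0.8   # assumed top-p test (top_p=0.8)
    keep = torch.nn.functional.one_hot(
        expert_mask.long().argmax(dim=-1), num_classes=3
    ).to(torch.bool)
    expert_mask = expert_mask & ~keep                  # re-admit the boundary expert
    print(sorted_logits.masked_fill(expert_mask, 0.0)) # tensor([[0.6000, 0.3000, 0.0000]])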
@@ -225,11 +233,11 @@ def forward(
             routing_weights, self.broadcast_threshhold_, self.top_p_, self.eps_
         )
         if self.router_profile_:
-            print(f"entropy: {router_entropy.mean()}")
+            logging.info(f"entropy: {router_entropy.mean()}")
             router_profile = (routing_weights > 0.0).long().sum(-1).float()
-            print(f"max activated: {router_profile.max()}")
-            print(f"min activated: {router_profile.min()}")
-            print(f"avg activated: {router_profile.mean()}")
+            logging.info(f"max activated: {router_profile.max()}")
+            logging.info(f"min activated: {router_profile.min()}")
+            logging.info(f"avg activated: {router_profile.mean()}")

         for expert_idx in range(self.experts_):
             expert_lora = lora_linear.loras_[
16 changes: 13 additions & 3 deletions moe_peft/modules/mix_lora.py
@@ -8,6 +8,16 @@
 from .config import MixLoraConfig


+@torch.jit.script
+def _entropy(
+    logits: torch.Tensor,
+    dim: int = -1,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    probs_neg_log = -torch.log(logits + eps)  # eps for 'p=0, -plogp=0'
+    return (logits * probs_neg_log).sum(dim=dim)
+
+
 def _slice_tensor(
     data: torch.Tensor,
     slice: torch.Tensor,
@@ -263,15 +273,15 @@ def forward(
         return final_hidden_states, router_logits


+@torch.jit.script
 def _dynamic_routing(
     router_logits: torch.Tensor,
     broadcast_threshhold: float = 2.0,
     top_p: float = 0.8,
     eps: float = 1e-5,
 ):
     # calculate router entropy
-    probs_neg_log = -torch.log(router_logits + eps)  # eps for 'p=0, -plogp=0'
-    router_entropy = (router_logits * probs_neg_log).sum(dim=-1)
+    router_entropy = _entropy(router_logits, -1, eps)
     # broadcast if higher than threshhold
     broadcast_index = torch.nonzero(router_entropy >= broadcast_threshhold).squeeze(-1)
     # calculate top-p routing
@@ -282,7 +292,7 @@ def _dynamic_routing(
     threshold_indices = expert_mask.long().argmax(dim=-1)
     threshold_mask = torch.nn.functional.one_hot(
         threshold_indices, num_classes=sorted_indices.size(-1)
-    ).bool()
+    ).to(torch.bool)
     # calculate final mask
     expert_mask = (expert_mask & ~threshold_mask).index_fill(0, broadcast_index, False)
     # sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
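Taken together, the commit factors the duplicated inline entropy computation out of both `_dynamic_routing` variants into the shared `_entropy` helper and compiles all three functions with `@torch.jit.script`, cutting Python overhead on the routing hot path, which is presumably the "improve efficiency" of the commit message; the `.bool()` → `.to(torch.bool)` rewrites likely keep the masking code TorchScript-friendly. A usage sketch of the scripted helper (import path inferred from the file layout above):

    import torch
    from moe_peft.modules.mix_lora import _entropy  # path per this commit

    probs = torch.softmax(torch.randn(2, 8), dim=-1)       # illustrative: 2 tokens, 8 experts
    print(_entropy(probs))                                 # per-token entropy, shape (2,)
    print(isinstance(_entropy, torch.jit.ScriptFunction))  # True: compiled at import time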
