
Commit

update dyn mixlora
mikecovlee committed Sep 9, 2024
1 parent 3bd0618 commit cd67017
Showing 2 changed files with 51 additions and 11 deletions.
55 changes: 47 additions & 8 deletions moe_peft/modules/mix_lora.py
@@ -273,7 +273,7 @@ def _dynamic_routing(
     probs_neg_log = -torch.log(router_logits + eps)  # eps for 'p=0, -plogp=0'
     router_entropy = (router_logits * probs_neg_log).sum(dim=-1)
     # broadcast if higher than threshhold
-    broadcast_index = router_entropy >= broadcast_threshhold
+    broadcast_index = torch.nonzero(router_entropy >= broadcast_threshhold).squeeze(-1)
     # calculate top-p routing
     sorted_logits, sorted_indices = torch.sort(router_logits, dim=-1, descending=True)
     cumulative_probs = sorted_logits.cumsum(dim=-1)
@@ -284,9 +284,8 @@ def _dynamic_routing(
         threshold_indices, num_classes=sorted_indices.size(-1)
     ).bool()
     # calculate final mask
-    expert_mask = expert_mask & ~threshold_mask
-    expert_mask[broadcast_index] = False
-    sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
+    expert_mask = (expert_mask & ~threshold_mask).index_fill(0, broadcast_index, False)
+    # sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
     sorted_indices = sorted_indices.masked_fill(expert_mask, -1)
     return router_entropy, sorted_logits, sorted_indices
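
Aside (not part of the commit): a minimal sketch checking that the new integer-index form, torch.nonzero(...).squeeze(-1) plus index_fill, clears the same rows as the old boolean-mask assignment. Shapes and values below are assumptions for illustration only.

import torch

# Toy setup: 4 tokens, 8 experts (hypothetical values).
router_entropy = torch.tensor([0.5, 2.3, 1.1, 2.9])
expert_mask = torch.rand(4, 8) > 0.5
broadcast_threshhold = 2.0

# Old form: in-place assignment through a boolean mask.
old = expert_mask.clone()
old[router_entropy >= broadcast_threshhold] = False

# New form: integer row indices plus an out-of-place index_fill.
broadcast_index = torch.nonzero(router_entropy >= broadcast_threshhold).squeeze(-1)
new = expert_mask.index_fill(0, broadcast_index, False)

# Rows whose routing entropy crosses the threshold end up with no experts
# masked out, i.e. those tokens are broadcast to all experts.
assert torch.equal(old, new)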

@@ -296,6 +295,7 @@ def _dynamic_load_balancing_loss_func(
     broadcast_threshhold: float = 2.0,
     top_p: float = 0.8,
     eps: float = 1e-5,
+    attention_mask: Optional[torch.Tensor] = None,
 ) -> float:
     num_experts = routing_weights.size(-1)

@@ -316,11 +316,49 @@

     expert_mask = expert_mask.permute(2, 1, 0)

-    # Compute the percentage of tokens routed to each experts
-    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand(
+                (
+                    num_hidden_layers,
+                    batch_size,
+                    sequence_length,
+                    num_experts,
+                    num_experts,
+                )
+            )
+            .reshape(-1, num_experts, num_experts)
+            .to(routing_weights.device)
+        )

-    # Compute the average probability of routing to these experts
-    router_prob_per_expert = torch.mean(routing_weights, dim=0)
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(routing_weights.device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)

     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return entropy_loss + overall_loss * num_experts
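
Aside (not part of the commit): a minimal sketch of the masked statistics added above. With an all-ones attention mask the weighted sum/sum reduces to the plain torch.mean path, while padded positions drop out of the statistics entirely. Toy sizes below are assumptions for illustration.

import torch

# 6 token positions from a padded batch, 4 experts (assumed toy sizes).
routing_weights = torch.softmax(torch.randn(6, 4), dim=-1)
attention_mask = torch.tensor([1, 1, 1, 1, 0, 0]).float().unsqueeze(-1)

# Masked average probability per expert, mirroring the else-branch above.
masked_mean = (routing_weights * attention_mask).sum(0) / attention_mask.sum(0)

# Equivalent to a plain mean over only the real (non-padding) tokens.
assert torch.allclose(masked_mean, routing_weights[:4].mean(0))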
@@ -342,6 +380,7 @@ def forward(self, gate_logits, attention_mask) -> torch.Tensor:
             self.broadcast_threshhold,
             self.top_p,
             self.eps,
+            attention_mask,
         )


7 changes: 4 additions & 3 deletions templates/mixlora_dynamic.json
@@ -11,9 +11,9 @@
             "optim": "adamw",
             "scheduler_type": "constant",
             "warmup_steps": 0,
-            "lr": 2e-4,
+            "lr": 1e-4,
             "batch_size": 16,
-            "micro_batch_size": 8,
+            "micro_batch_size": 4,
             "evaluate_batch_size": 16,
             "num_epochs": 2,
             "r": 16,
@@ -30,7 +30,8 @@
             },
             "routing_strategy": "mixlora-dynamic",
             "num_experts": 8,
-            "top_p": 0.8,
+            "broadcast_threshhold": 1.8,
+            "top_p": 0.75,
             "group_by_length": false
         }
     ]
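
A back-of-the-envelope note on the new template values (not part of the commit): the router entropy compared against broadcast_threshhold in _dynamic_routing is at most ln(num_experts) ≈ 2.079 for 8 experts, so broadcast_threshhold = 1.8 triggers the broadcast path only for tokens whose routing distribution is close to uniform, while top_p = 0.75 trims the tail of more peaked distributions.

import math
import torch

num_experts = 8
uniform = torch.full((num_experts,), 1.0 / num_experts)
max_entropy = -(uniform * uniform.log()).sum()
print(float(max_entropy), math.log(num_experts))  # both ≈ 2.079, above the 1.8 threshold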
