The `router.py` function,

```python
def _top_k(self, scores):
    if self.args.moe_top_k == 1:
        return scores.max(dim=-1)  # <-- causes the weight shape to become [S]
    return torch.topk(scores, self.args.moe_top_k, dim=-1)  # <-- shape is normally [S, K]
```
caused the expert weight norm to be calculated incorrectly:
```python
expert_weights, expert_indices = self._top_k(scores)
if self.args.moe_normalize_expert_weights:
    # this norm expects dim=-1 to contain only a single token's expert weights
    expert_weights = expert_weights / torch.norm(
        expert_weights, p=self.args.moe_normalize_expert_weights, dim=-1, keepdim=True)
```
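One way to fix this (a sketch of the shape fix, not necessarily the exact diff in this PR) is to keep the reduced dimension in the top-1 path, so both branches return `[S, K]`-shaped tensors, with `K = 1` for top-1:

```python
def _top_k(self, scores):
    if self.args.moe_top_k == 1:
        # keepdim=True preserves the expert dimension, so the top-1 path
        # returns (values, indices) of shape [S, 1] instead of [S],
        # matching the [S, K] shape of the torch.topk path below.
        return scores.max(dim=-1, keepdim=True)
    return torch.topk(scores, self.args.moe_top_k, dim=-1)
```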
After this PR, top-1 models with `moe_normalize_expert_weights=1` should always end up with final expert weights of exactly 1 (previously each token's single weight was divided by the norm taken across all tokens' weights, yielding essentially arbitrary values).
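To make the before/after concrete, here is a small standalone repro (the shapes and seed are illustrative, not from the PR): with `[S]`-shaped weights the p=1 norm is taken over the token dimension, while with `[S, 1]`-shaped weights every token's single weight normalizes to exactly 1.

```python
import torch

torch.manual_seed(0)
scores = torch.rand(4, 3).softmax(dim=-1)  # S=4 tokens, 3 experts

# Old top-1 path: max(dim=-1) drops the expert dim -> shape [S].
old = scores.max(dim=-1).values
old = old / torch.norm(old, p=1, dim=-1, keepdim=True)
# dim=-1 is now the *token* dim, so each weight is divided by the sum
# of all four tokens' weights: four fractions that sum to 1.

# Fixed top-1 path: keepdim=True -> shape [S, 1].
new = scores.max(dim=-1, keepdim=True).values
new = new / torch.norm(new, p=1, dim=-1, keepdim=True)

print(old)  # four values that wrongly sum to 1 across the token dim
print(new)  # tensor([[1.], [1.], [1.], [1.]])
```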