databricks / megablocks

Apache License 2.0

Fix `moe_normalize_expert_weights` when `top_k=1` #87

Closed 152334H closed 5 months ago

152334H commented 5 months ago

The `_top_k` function in router.py,

    def _top_k(self, scores):
        if self.args.moe_top_k == 1:
            return scores.max(dim=-1) # <-- causes weight shape to become [S]
        return torch.topk(scores, self.args.moe_top_k, dim=-1) # <-- shape is normally [S,K]
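A minimal sketch (not from the PR, shapes and sizes illustrative) of why the two paths disagree: `Tensor.max(dim=-1)` drops the reduced dimension, while `torch.topk` keeps it.

```python
import torch

S, E, K = 4, 8, 2  # tokens, experts, top-k (illustrative values)
scores = torch.rand(S, E)

# max() squeezes the reduced dim away: values have shape [S]
max_vals, max_idx = scores.max(dim=-1)

# topk() keeps a trailing dim of size K: values have shape [S, K]
topk_vals, topk_idx = torch.topk(scores, K, dim=-1)

print(max_vals.shape)   # torch.Size([4])
print(topk_vals.shape)  # torch.Size([4, 2])
```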

causes the expert-weight norm to be calculated incorrectly:

        expert_weights, expert_indices = self._top_k(scores)
        if self.args.moe_normalize_expert_weights:
            # this function expects dim=-1 to only contain a single token's weights
        expert_weights = expert_weights / torch.norm(
            expert_weights, p=self.args.moe_normalize_expert_weights, dim=-1, keepdim=True)

After this PR, top-1 models with `moe_normalize_expert_weights=1` should always end up with final expert weights of 1. Previously, because the weights had shape `[S]`, each token's weight was divided by a norm computed across all tokens rather than over its own top-k weights.
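A hedged sketch of the bug and the intended behaviour (values illustrative; the `unsqueeze` here stands in for however the PR restores the `[S, K]` shape). With `p=1` and a single expert per token, each weight divided by its own norm is exactly 1:

```python
import torch

# Top-1 expert weights as returned by scores.max(dim=-1): shape [S]
w = torch.tensor([0.5, 0.9, 0.7])

# Buggy: dim=-1 spans all S tokens, so every weight is divided
# by the 1-norm of the whole batch (0.5 + 0.9 + 0.7 = 2.1).
buggy = w / torch.norm(w, p=1, dim=-1, keepdim=True)

# Intended: keep a trailing top-k dim so dim=-1 covers only one
# token's weights; with p=1 every top-1 weight normalizes to 1.
fixed = w.unsqueeze(-1)
fixed = fixed / torch.norm(fixed, p=1, dim=-1, keepdim=True)

print(buggy)              # weights wrongly share a single batch norm
print(fixed.squeeze(-1))  # tensor([1., 1., 1.])
```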

tgale96 commented 5 months ago

Thanks for the PR! And great catch on this bug!

tgale96 commented 5 months ago

Thanks for the update! One last small comment and then I think we're ok to merge!

tgale96 commented 5 months ago

Thanks for the contribution!