databricks / megablocks

Apache License 2.0

Routing #118

Open alexliap opened 4 months ago

alexliap commented 4 months ago

Does the router implement the noisy top-k routing proposed in the paper "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer"?

In the router code, the noise appears to be applied to the router's input rather than to the router scores, as the paper does:

    def forward(self, x):
        if self.training and self.args.moe_jitter_eps is not None:
            x = x * self.jitter(x)

        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
        expert_weights, expert_indices = self._top_k(scores)
        if self.args.moe_normalize_expert_weights:
            expert_weights = expert_weights / torch.norm(
                expert_weights, p=self.args.moe_normalize_expert_weights, dim=-1, keepdim=True)

        expert_indices = (
            _uniform_expert_assignment(expert_indices, self.args.moe_num_experts)
            if self.args.uniform_expert_assignment else expert_indices
        )
        return scores, expert_weights, expert_indices
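To make the comparison concrete, here is a minimal, framework-free sketch of what the quoted `forward` computes when jitter is enabled. It rests on assumptions: `jitter` is taken to draw one multiplicative factor per input feature uniformly from [1 - eps, 1 + eps] (its body is not shown in the quote), `self.layer` is treated as a bias-free linear map, and `input_jitter_route`, `w_g`, and the toy numbers are hypothetical names chosen for illustration.

```python
import math
import random


def jitter(x, eps, rng):
    # ASSUMPTION: one multiplicative factor per feature, uniform in [1 - eps, 1 + eps].
    return [xi * rng.uniform(1.0 - eps, 1.0 + eps) for xi in x]


def softmax(v):
    # Numerically stable softmax over a list of floats.
    m = max(v)
    exps = [math.exp(vi - m) for vi in v]
    s = sum(exps)
    return [e / s for e in exps]


def top_k(scores, k):
    # Values and indices of the k largest scores, highest first.
    idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return [scores[i] for i in idx], idx


def input_jitter_route(x, w_g, k, eps=None, rng=None):
    # Hypothetical helper: noise on the router INPUT, then a bias-free linear map
    # (w_g is [d_model][num_experts]), softmax over ALL experts, then top-k.
    if eps is not None:
        x = jitter(x, eps, rng or random.Random())
    logits = [sum(xj * w_g[j][i] for j, xj in enumerate(x)) for i in range(len(w_g[0]))]
    scores = softmax(logits)
    weights, indices = top_k(scores, k)
    return scores, weights, indices


# Toy usage: one token with d_model = 3 and 4 experts (made-up numbers).
x = [1.0, -2.0, 0.5]
w_g = [[0.1, 0.2, -0.3, 0.0],
       [0.4, -0.1, 0.2, 0.3],
       [-0.2, 0.5, 0.1, -0.4]]
scores, weights, indices = input_jitter_route(x, w_g, k=2)  # indices -> [1, 2]
noisy_scores, noisy_weights, noisy_indices = input_jitter_route(
    x, w_g, k=2, eps=0.01, rng=random.Random(0))
```

In this sketch the noise enters before the projection, so logit i becomes the clean logit plus the sum over j of w_g[j][i] * x_j * (u_j - 1): its scale depends on the weights and the input, with no separate trainable noise parameter. Also, `weights` are probabilities from a softmax over all experts, so they sum to less than 1 unless `moe_normalize_expert_weights` is set.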

In the aforementioned paper, noisy top-k gating works as follows: [image: the paper's noisy top-k gating equations]
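For reference, the paper (Shazeer et al., 2017) defines noisy top-k gating as below, where W_g and W_noise are trainable weight matrices and k is the number of experts kept per token. This is written out from the paper itself, not recovered from the missing image.

```latex
G(x) = \mathrm{Softmax}\big(\mathrm{KeepTopK}(H(x), k)\big)

H(x)_i = (x \cdot W_g)_i + \mathrm{StandardNormal}() \cdot \mathrm{Softplus}\big((x \cdot W_{noise})_i\big)

\mathrm{KeepTopK}(v, k)_i =
\begin{cases}
v_i & \text{if } v_i \text{ is in the top } k \text{ elements of } v \\
-\infty & \text{otherwise}
\end{cases}
```

Here the noise is added to the gating logits with a per-expert, input-dependent standard deviation, and the softmax runs only over the k kept logits.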

Is this equivalent? I'm not arguing that it is wrong; I'm just trying to figure out whether the two are the same.
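A minimal sketch of the paper's gating, in the same framework-free style, makes the differences from the quoted router easier to see. In the paper, Gaussian noise is added to the logits with a learned scale Softplus(x·W_noise), and the softmax covers only the k selected logits, so the returned gates sum to 1. In the quoted router, the input is multiplied by `self.jitter(x)` and the softmax covers all experts before top-k. The names `noisy_top_k_gating`, `w_noise`, the `training` switch, and the toy numbers are hypothetical, chosen for this illustration only.

```python
import math
import random


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def matvec(x, w):
    # x: [d_model], w: [d_model][num_experts] -> [num_experts]
    return [sum(xj * w[j][i] for j, xj in enumerate(x)) for i in range(len(w[0]))]


def noisy_top_k_gating(x, w_g, w_noise, k, training=True, rng=None):
    # H(x)_i = (x W_g)_i + N(0, 1) * softplus((x W_noise)_i); the noise is
    # applied only when training=True (an assumption of this sketch).
    rng = rng or random.Random()
    clean = matvec(x, w_g)
    if training:
        noise_std = [softplus(z) for z in matvec(x, w_noise)]
        h = [c + rng.gauss(0.0, 1.0) * s for c, s in zip(clean, noise_std)]
    else:
        h = clean
    # KeepTopK: keep the k largest entries of H; the rest act as -inf.
    top = sorted(range(len(h)), key=lambda i: h[i], reverse=True)[:k]
    m = max(h[i] for i in top)
    exps = {i: math.exp(h[i] - m) for i in top}
    total = sum(exps.values())
    # Softmax over the kept logits only; dropped experts get exactly 0.
    return [exps[i] / total if i in exps else 0.0 for i in range(len(h))]


# Toy usage: one token with d_model = 3 and 4 experts (made-up numbers).
x = [1.0, -2.0, 0.5]
w_g = [[0.1, 0.2, -0.3, 0.0],
       [0.4, -0.1, 0.2, 0.3],
       [-0.2, 0.5, 0.1, -0.4]]
w_noise = [[0.0] * 4 for _ in range(3)]  # softplus(0) = log 2 noise std per expert
train_gates = noisy_top_k_gating(x, w_g, w_noise, k=2, rng=random.Random(0))
eval_gates = noisy_top_k_gating(x, w_g, w_noise, k=2, training=False)
```

With these toy weights and no noise, only experts 1 and 2 receive nonzero weight, and their two gates sum to 1, unlike the top-k probabilities in the quoted router.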

mvpatel2000 commented 4 months ago

@tgale96, since you implemented this, what do you think? It does seem different to me, but I'm not sure whether it was taken from another paper.