huggingface / transformers


Jitter Noise added to input being passed to experts in Switch Transformers #33969

Open karan-uppal3 opened 6 days ago

karan-uppal3 commented 6 days ago

System Info

Who can help?

@ArthurZucker @younesbelkada

Information

Tasks

Reproduction

import torch
import torch.nn as nn
from transformers import (
    SwitchTransformersConfig,
    SwitchTransformersTop1Router,
)
from transformers.models.switch_transformers.modeling_switch_transformers import SwitchTransformersDenseActDense

class MySwitchTransformersSparseMLP(nn.Module):
    r"""
    Implementation of the Switch Transformers Sparse MLP module.
    """

    def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense):
        super().__init__()
        # Step 1: Get the correct router according to its class
        self.router = SwitchTransformersTop1Router(config)

        # Step 2: Get the experts
        self.experts = nn.ModuleDict()
        for idx in range(config.num_experts):
            self.experts[f"expert_{idx}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:

        1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
        and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
        hidden states: they are broadcast to the hidden states values (can be interpreted as a scaling factor).

        2- Dispatches the tokens to their associated experts. We do a classic for loop over the experts and assign
        the corresponding hidden states to each expert.

        """

        # Save a copy of the input so we can check whether the router mutates it in place
        prev_save = hidden_states.clone()

        # Step 1: Get the router_mask from the router as well as the probabilities
        router_mask, router_probs, router_logits = self.router(hidden_states)
        expert_index = torch.argmax(router_mask, dim=-1)

        # If the router leaves its input untouched, this should print True and a mean difference of 0
        print(torch.allclose(prev_save, hidden_states))
        print(torch.mean(prev_save - hidden_states))

        # The router might not always map all the tokens to an expert, which means that some hidden states
        # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.

        next_states = hidden_states.clone()

        router_mask = router_mask.bool()
        batch_size, seq_len, num_experts = router_mask.shape
        idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
            0
        ].tolist()  # length: number of "activated" expert / value: index
        for idx in idx_mask:
            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
                hidden_states[router_mask[:, :, idx]]
            )

        hidden_states = router_probs * next_states
        return hidden_states, (router_logits, expert_index)

config = SwitchTransformersConfig()
model = MySwitchTransformersSparseMLP(config)

model.train()
in_data = torch.ones(1, 1, 768)
out = model(in_data)

The output is

False
tensor(-0.0001)

whereas ideally it should print True and the mean difference should be zero.
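Note that the noisy path is only taken in training mode with a non-zero jitter value, which the reproduction ensures via `model.train()` and the config default. A quick guard makes this explicit (the attribute name `router_jitter_noise` is taken from the config documentation; adjust if it differs):

# Sanity check; assumption: the config field is named `router_jitter_noise`
assert model.training and getattr(config, "router_jitter_noise", 0) > 0, (
    "jitter is only applied in training mode with a non-zero jitter value"
)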

This is because SwitchTransformersTop1Router multiplies the hidden_states by the jitter noise in place, so the noise persists in the tensor that is later passed to the experts.

https://github.com/huggingface/transformers/blob/e71a01a104dd663c730e494eb0b6467bb51df357/src/transformers/models/switch_transformers/modeling_switch_transformers.py#L159-L161
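For reference, the linked lines apply the jitter roughly as in the sketch below (a paraphrase, not the exact source). Because the multiplication is in-place, the tensor the caller later dispatches to the experts is modified too:

import torch

def jitter_in_place(hidden_states: torch.Tensor, jitter_noise: float = 0.01) -> torch.Tensor:
    # Rough paraphrase of the linked jitter step: scale the input in place by
    # uniform noise drawn from [1 - jitter_noise, 1 + jitter_noise].
    hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - jitter_noise, 1.0 + jitter_noise)
    return hidden_states

x = torch.ones(1, 1, 4)
x_before = x.clone()
jitter_in_place(x)
print(torch.allclose(x_before, x))  # False: the caller's tensor was mutated as well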

Expected behavior

Ideally, no jitter noise should be present in the input that is passed to the experts, so the check above would return True with a mean difference of 0.
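One way to work around this in user code (an illustration only, not the library's fix) is to hand the router a copy, so any in-place jitter stays local to the routing computation:

# In MySwitchTransformersSparseMLP.forward above, route on a clone so that
# jitter applied in place by the router never reaches the expert inputs.
router_mask, router_probs, router_logits = self.router(hidden_states.clone())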

ArthurZucker commented 6 days ago

Hey! Where is the ideal scenario coming from? 🤗 We tried to follow the original implementation on this!

karan-uppal3 commented 4 days ago

Hello @ArthurZucker! According to the pseudo code given in Figure 15 and Figure 16 of the original paper, the input to the experts doesn't contain the additional jitter noise.
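In other words (a minimal sketch of that behaviour, not a transcription of the paper's pseudo code), the jitter would only affect a local copy used for the routing decision, while the experts receive the original input:

import torch

def route_with_local_jitter(hidden_states: torch.Tensor, classifier: torch.nn.Linear,
                            jitter_noise: float, training: bool):
    # Sketch only: the jitter is applied to a local copy that is used to compute
    # the routing probabilities; the caller's hidden_states are left untouched.
    routing_input = hidden_states
    if training and jitter_noise > 0:
        routing_input = hidden_states * torch.empty_like(hidden_states).uniform_(
            1.0 - jitter_noise, 1.0 + jitter_noise
        )
    router_logits = classifier(routing_input)
    router_probs = torch.softmax(router_logits, dim=-1)
    # The caller then dispatches the original, un-jittered hidden_states to the experts.
    return router_probs, router_logits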