mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

Question about Mixtral MLP section #139

Closed lhallee closed 1 month ago

lhallee commented 3 months ago

Hello,

Great work! Is it accurate to describe this as a standard, vanilla MLP block? According to the Hugging Face implementation, there is an additional third linear layer and an added elementwise multiplication.


import torch.nn as nn

from transformers import MixtralConfig
from transformers.activations import ACT2FN


class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)  # down projection
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # extra up projection -- not standard

        self.act_fn = ACT2FN[config.hidden_act]  # "silu" for Mixtral

    def forward(self, hidden_states):
        # gate the activated w1 branch with the w3 branch, elementwise -- not standard
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states

I think this has been confusing to some readers, but perhaps this construction has been used before and I am unaware of it. Are there any insights you can offer about why this layer was added? It seems to add more expressiveness to the experts, but I didn't know whether you had experimented with and without it.
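
To make the contrast concrete, here is a minimal side-by-side sketch of a standard two-projection MLP and the gated three-projection block above (module names, dimensions, and the choice of SiLU are illustrative assumptions, not taken from the Mistral code):

    import torch
    import torch.nn as nn

    class VanillaMLP(nn.Module):
        """Standard two-projection transformer FFN: down(act(up(x)))."""
        def __init__(self, hidden_dim, ffn_dim):
            super().__init__()
            self.up = nn.Linear(hidden_dim, ffn_dim, bias=False)
            self.down = nn.Linear(ffn_dim, hidden_dim, bias=False)
            self.act = nn.SiLU()  # the original transformer used ReLU here

        def forward(self, x):
            return self.down(self.act(self.up(x)))

    class GatedMLP(nn.Module):
        """Three-projection gated FFN, the pattern used in the Mixtral block above."""
        def __init__(self, hidden_dim, ffn_dim):
            super().__init__()
            self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # gate
            self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)  # down
            self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # extra up projection
            self.act = nn.SiLU()

        def forward(self, x):
            return self.w2(self.act(self.w1(x)) * self.w3(x))

    x = torch.randn(2, 8, 64)
    print(VanillaMLP(64, 256)(x).shape)  # torch.Size([2, 8, 64])
    print(GatedMLP(64, 256)(x).shape)    # torch.Size([2, 8, 64])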

rookie-joe commented 1 month ago

a normal SwiGLU here (MLP)

lhallee commented 1 month ago

> a normal SwiGLU here (MLP)

This is showing up more often, but using w3 is definitely not the norm?

rookie-joe commented 1 month ago

> > a normal SwiGLU here (MLP)
>
> This is showing up more often, but using w3 is definitely not the norm?

I mean, it is a normal, i.e. vanilla, SwiGLU here, not a norm.
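
In case it helps, here is a small sketch checking that the block above is just the gated (SwiGLU-style) feed-forward, (silu(x W1ᵀ) ⊙ x W3ᵀ) W2ᵀ, with no normalization anywhere; the dimensions and tolerance are illustrative:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GatedFFN(nn.Module):
        # same layout as MixtralBlockSparseTop2MLP, with silu hard-coded
        def __init__(self, hidden_dim, ffn_dim):
            super().__init__()
            self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
            self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
            self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)

        def forward(self, x):
            return self.w2(F.silu(self.w1(x)) * self.w3(x))

    m = GatedFFN(64, 256)
    x = torch.randn(2, 64)
    # the same computation written out from the raw weight matrices
    manual = (F.silu(x @ m.w1.weight.T) * (x @ m.w3.weight.T)) @ m.w2.weight.T
    assert torch.allclose(m(x), manual, atol=1e-5)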

lhallee commented 1 month ago

> > > a normal SwiGLU here (MLP)
> >
> > This is showing up more often, but using w3 is definitely not the norm?
>
> I mean, it is a normal, i.e. vanilla, SwiGLU here, not a norm.

I meant "normal," not norm, sorry. Where is SwiGLU mentioned in papers? Most transformers do not have three Linear layers in the MLP, including the original / vanilla transformer.