meta-llama / llama

Inference code for Llama models

No SILU/GELU/ReLU activation in the Attention block?! #246

Open jxtps opened 1 year ago

jxtps commented 1 year ago

OK, this is more of a question about transformers in general than about Llama differing from the standard transformer architecture: why is there no activation on the assembled values, just before the output projection?

Yes, one could argue the softmax is an activation, but that's more about routing information, i.e. selecting which Values should be propagated to the output, which is very different from a "normal" activation. And I get that the output projection doesn't get an activation so that it can both add to and subtract from the residual connection.

But once that output has been assembled, it would normally have an activation applied?!

Reading the source code:

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)
        # output = F.silu(output)  <-- WOULD HAVE EXPECTED ACTIVATION HERE?!
        return self.wo(output)

???

jxtps commented 1 year ago

In fact, if you look at the vanilla Transformer architecture:

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

There's literally a single SILU/GELU/ReLU-style activation in each of those layers. Just sayin' ;)

unrealwill commented 1 year ago

In general, you want as few activations (i.e. non-linearities) as possible, because they make numerical optimization more difficult.

The main reason for the various linear layers inside the attention block is to convert from one shape to another.

Ideally you want plenty of these layers, because that's where the parameters that store information live.

The reason we sometimes need an activation is that if you put two linear layers in sequence without an activation in between, they are equivalent to a single linear layer whose weight is the matrix product of the two weights (so you make the problem ill-defined and the weights redundant).
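
A minimal sketch of that collapse (illustrative shapes, not anything from the repo):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)

    # Two stacked linear layers with no activation in between...
    w1 = nn.Linear(16, 32, bias=False)
    w2 = nn.Linear(32, 8, bias=False)

    # ...are equivalent to one linear layer whose weight is the matrix product.
    merged = nn.Linear(16, 8, bias=False)
    with torch.no_grad():
        merged.weight.copy_(w2.weight @ w1.weight)

    x = torch.randn(4, 16)
    print(torch.allclose(w2(w1(x)), merged(x), atol=1e-5))  # True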

Similarly, with convolutions you need an activation in between, except if you use separable convolutions ConvX ConvY: then you don't need an activation in between because they don't mix the same dimensions. But if you have ConvX ConvY ConvX you do need an activation, because separable convolutions commute, so it is equivalent to ConvX ConvX ConvY and you get redundant weights.
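
For example (single channel, zero "same" padding, purely illustrative):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)

    # One convolution mixes only the x axis (1x5 kernel), the other only the y axis (5x1 kernel).
    conv_x = nn.Conv2d(1, 1, kernel_size=(1, 5), padding=(0, 2), bias=False)
    conv_y = nn.Conv2d(1, 1, kernel_size=(5, 1), padding=(2, 0), bias=False)

    img = torch.randn(1, 1, 32, 32)

    # Because they mix different dimensions they commute, so without an activation
    # in between, ConvX ConvY ConvX is equivalent to ConvX ConvX ConvY and one of
    # the weights is redundant.
    print(torch.allclose(conv_y(conv_x(img)), conv_x(conv_y(img)), atol=1e-5))  # True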

The other non-linearities we sometimes add are the normalization layers; they help keep the variance of the inputs to the various layers equal to 1 so that optimization is easier. But a normalization layer is only a one-dimensional constraint (the length of the vector) applied to a vector of dimension n, where n >> 100, so it doesn't impede the flow of information too much.

In the attention layer we sometimes prefer to normalize per attention head, because then the attention layer is computing something like a nearest-neighbour search with a properly normalized cosine similarity. (And due to the curse of dimensionality, nearest-neighbour search works better in lower-dimensional spaces, which is why we project q and k down to a smaller dimension dim_head.) But to keep the rank we use multiple attention heads, with hidden_dim = n_heads * dim_head.
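
A sketch of that reading of the scores (note: Llama itself only scales by sqrt(head_dim), it does not normalize q and k like this):

    import torch
    import torch.nn.functional as F

    bs, n_heads, seqlen, head_dim = 2, 4, 8, 16
    q = torch.randn(bs, n_heads, seqlen, head_dim)
    k = torch.randn(bs, n_heads, seqlen, head_dim)

    # Unit-normalizing per head turns each score into a cosine similarity,
    # i.e. a nearest-neighbour search in the small head_dim-dimensional space.
    q_hat = F.normalize(q, dim=-1)
    k_hat = F.normalize(k, dim=-1)
    scores = q_hat @ k_hat.transpose(-2, -1)   # values in [-1, 1]
    attn = scores.softmax(dim=-1)
    print(attn.shape)  # torch.Size([2, 4, 8, 8])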

Rotary embeddings are a linear operation in the complex plane (but without learnable parameters).
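
A minimal sketch of that view (rotate_pairs is an illustrative helper with arbitrary angles; the real apply_rotary_emb derives its angles from fixed per-position frequencies):

    import torch

    def rotate_pairs(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        # View consecutive feature pairs as complex numbers and rotate them.
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        rot = torch.polar(torch.ones_like(theta), theta)  # e^{i * theta}, unit magnitude
        return torch.view_as_real(x_c * rot).flatten(-2)

    x = torch.randn(2, 4, 8)      # (batch, seq, dim)
    theta = torch.randn(2, 4, 4)  # one angle per complex pair
    y = rotate_pairs(x, theta)

    # The operation is linear in x and, being a rotation, preserves the norm.
    print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5))  # True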

Gated architectures ( y = W1( W2(x) W3(x) ) ) use the fact that a matrix product and a Hadamard product don't simplify into each other. But similarly, if you stack them you need some activation in between, otherwise along the chain of computations there are operations that could be simplified (collapsing two linear layers into one). Note that if you don't put an activation on W2(x), then because multiplication commutes, W2 and W3 play symmetric roles, which makes the problem ill-defined. (If x happened to come from a space where multiplication doesn't commute, for example SO(3), i.e. quaternion-like rotations, so that the symmetry between W2 and W3 is broken, then you could remove the non-linearity.) But be aware that multiplying things together tends to make them grow exponentially, so that has to be mitigated; that's why rotations are great: you can multiply them and they keep the length of the vector constant. So a gated architecture like y = W1(x) W2(x), where x ∈ so(3)^n and W1, W2 ∈ so(3)^(n×n) (so(3) being the Lie algebra associated with the Lie group SO(3)), can be stacked without non-linearities, provided your optimizer is designed to optimize over Lie groups (via the exponential map).
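
This is essentially the shape of Llama's SwiGLU-style feed-forward. A plain-nn.Linear sketch of the pattern (class and argument names are illustrative; the real FeedForward uses model-parallel linear layers and a hidden-size rounding rule):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GatedFeedForward(nn.Module):
        # Sketch of y = w2(silu(w1(x)) * w3(x)): a Hadamard product of two
        # branches, with the silu on one branch breaking the symmetry between them.
        def __init__(self, dim: int, hidden_dim: int):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gated branch
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear branch

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(F.silu(self.w1(x)) * self.w3(x))

    ff = GatedFeedForward(dim=16, hidden_dim=64)
    print(ff(torch.randn(2, 16)).shape)  # torch.Size([2, 16])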

Similarly, we can see that the normalization non-linearity at line 236 is important to prevent having two consecutive linear layers (with the other linear being the last linear inside the feed-forward of the last layer): https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L234-L237

Projection layers are linear layers, but with the added difficulty that they reduce the rank of the matrices and therefore impede the backward flow of information. Similarly, non-reversible activations like ReLU impede the backward flow of information and may result in vanishing-gradient problems.

The whole game of network design is making sure the rank doesn't collapse, while still reducing dimensions to speed things up, while still providing the network with enough input dimensions that it doesn't miss information useful for its decisions, and while making the paths the information has to flow through as short as possible.

Ideally you would want the "probability flow" to move unimpeded in both the forward and backward passes, and just mix information in a learnable fashion.

kcarnold commented 1 year ago

Thoughtful comment! This paper argues that a low-rank bias is actually helpful, and that even back-to-back supposedly-extraneous linear layers end up implicitly regularizing networks towards that state.

Equim-chan commented 1 year ago

[...] if you put two linear layers in sequence without an activation in between, they are equivalent to a single linear layer whose weight is the matrix product of the two weights (so you make the problem ill-defined and the weights redundant).

Speaking of this, is the weight in RMSNorm actually redundant?

https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L43-L45

I found that every call to RMSNorm in the code is directly followed by a linear layer; wouldn't that make those weights redundant, since they could be merged into the next linear layer?
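
For example, a minimal sketch of the fold I have in mind (rms_normalize is an illustrative stand-in for the parameter-free part of RMSNorm; this is inference-time only, and whether the separate weight still matters for training is another question):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    dim, out_dim = 16, 8

    def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # RMSNorm without its elementwise weight.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    g = torch.randn(dim)  # RMSNorm's elementwise weight
    linear = nn.Linear(dim, out_dim, bias=False)

    # Fold g into the following linear layer by scaling its columns.
    folded = nn.Linear(dim, out_dim, bias=False)
    with torch.no_grad():
        folded.weight.copy_(linear.weight * g)  # g broadcasts across the columns

    x = torch.randn(4, dim)
    print(torch.allclose(linear(rms_normalize(x) * g),
                         folded(rms_normalize(x)), atol=1e-5))  # True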