Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License

Is RetNet equivalent to ordinary GPT when the decay is set to 1? #37

Open xuanyaoming opened 8 months ago

xuanyaoming commented 8 months ago

I'm a little confused about what RetNet does in practice. In the formula Retention(X) = (Q @ K.T * D) @ V, if the decay is 1, the mathematical derivation proving the equivalence between the recurrent form and RetNet's Transformer-style parallel form still holds, and D becomes the normal causal attention mask used by almost all existing GPT models. Does that mean all existing GPT models can be turned into RetNets simply by modifying the inference function, without any further training? Am I correct, or am I missing something?
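To be concrete, this is roughly what I mean by D (just a sketch in PyTorch, not the code from this repo, single head only): with decay = 1 it reduces to the plain lower-triangular causal mask.

```python
import torch

def decay_mask(seq_len: int, gamma: float) -> torch.Tensor:
    # D[n, m] = gamma ** (n - m) for n >= m, and 0 for future positions
    n = torch.arange(seq_len).unsqueeze(1)   # query (current) positions
    m = torch.arange(seq_len).unsqueeze(0)   # key (past) positions
    return (gamma ** (n - m).clamp(min=0).float()) * (n >= m)

print(decay_mask(4, gamma=0.9))   # decayed lower-triangular mask
print(decay_mask(4, gamma=1.0))   # plain causal mask of ones
print(torch.equal(decay_mask(4, 1.0), torch.tril(torch.ones(4, 4))))   # True
```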

yunusskeete commented 5 months ago

I am certainly not an expert, but I believe the motivation behind RetNets is that nonlinear functions (like the softmax in attention) can't be written in recurrent form: the softmax normalises over the whole sequence, so the order of operations matters and you can't just chunk up the calculation and sum the results (which is exactly the idea behind the recurrent/chunkwise-recurrent RetNet). The Transformers behind GPT use the softmax nonlinearity to normalise the attention scores and to capture more powerful relationships. RetNets remove the softmax, removing the nonlinearity, and use a group norm for normalisation instead. Now we can chunk up the calculation and sum the chunks (introducing recurrence) and get the same results as the parallel implementation.
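Here's a rough numerical check of that equivalence (a sketch in PyTorch, not the code from this repo: single head, no group norm, no xPos rotation, no multi-scale decays). Because there is no softmax in between, the parallel form and the recurrent form produce the same outputs:

```python
import torch

torch.manual_seed(0)
T, d = 6, 8                     # sequence length, head dimension
gamma = 0.9                     # decay; the check also passes with gamma = 1.0
Q, K, V = (torch.randn(T, d) for _ in range(3))

# Parallel form: Retention(X) = (Q @ K.T * D) @ V
n = torch.arange(T).unsqueeze(1)
m = torch.arange(T).unsqueeze(0)
D = (gamma ** (n - m).clamp(min=0).float()) * (n >= m)
parallel = (Q @ K.T * D) @ V

# Recurrent form: S_t = gamma * S_{t-1} + k_t^T v_t,  o_t = q_t S_t
S = torch.zeros(d, d)
outputs = []
for t in range(T):
    S = gamma * S + K[t].unsqueeze(1) @ V[t].unsqueeze(0)  # outer product k_t v_t^T
    outputs.append(Q[t] @ S)
recurrent = torch.stack(outputs)

# No softmax in between, so the recurrence reproduces the parallel result exactly
print(torch.allclose(parallel, recurrent, atol=1e-5))  # True
```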

Retention introduces a decay within the mask, whereas attention does not. Setting the decay to 1 just leaves the causal mask. With the parallel implementation, a RetNet without decay is effectively just a linear Transformer (without the softmax, and with a group norm). Attention(X) = softmax(Q @ K.T + M) @ V, where M is an additive causal mask (-inf above the diagonal), whereas Retention(X) = (Q @ K.T * D) @ V, where D is a multiplicative causal decay mask. These changes are minimal, but they probably mean that Transformers can't easily be converted into RetNets without retraining.
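To make that concrete, here's a rough side-by-side (again just a sketch, single head, no 1/sqrt(d) scaling or group norm, not this repo's code): even with decay = 1 and identical Q, K and V, the two compute different outputs, so a trained Transformer's weights wouldn't behave the same if dropped into a RetNet.

```python
import torch

torch.manual_seed(0)
T, d = 5, 4
Q, K, V = (torch.randn(T, d) for _ in range(3))

causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = Q @ K.T

# Attention: softmax over additively masked scores (M = -inf above the diagonal)
attn_out = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ V

# Retention with decay = 1: the same scores, multiplicatively masked, no softmax
D = causal.float()              # gamma ** (n - m) with gamma = 1
reten_out = (scores * D) @ V

# Same Q, K, V, different outputs -> no drop-in conversion without retraining
print(torch.allclose(attn_out, reten_out))  # False
```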

I found this YouTube video very useful for understanding RetNets, and this implementation very useful for understanding how Transformers are implemented.

xuanyaoming commented 5 months ago

> RetNets remove the softmax, removing the nonlinearity, and use a group norm for normalisation instead.

Thanks. I realised the same thing when I was trying to recreate RetNet based only on the paper. I agree that there is no easy way to convert an existing GPT model into a RetNet model. If there were such a method, all the leading companies would have used it in their products by now, but I haven't heard any news of that yet.

Speaking of the effect of the nonlinearities, is there an official RetNet repo maintained by the team who wrote the paper? I'm very curious about their test results and how good RetNet is in practice. This repo only contains the basic RetNet structures.

yunusskeete commented 5 months ago

> Speaking of the effect of the nonlinearities, is there an official RetNet repo maintained by the team who wrote the paper? I'm very curious about their test results and how good RetNet is in practice. This repo only contains the basic RetNet structures.

I believe the torchscale repo is the official implementation, although I didn't find that implementation helpful at all. I found fkodom's implementation far more useful.