syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License

Comments on the model #14

Open okpatil4u opened 1 year ago

okpatil4u commented 1 year ago

Hello @syncdoth, great work!

I just wanted to understand your view on the original paper. Do you think this model is as powerful as the authors claim it to be? Did you find any issues that would hinder the model from reaching its full potential?

syncdoth commented 11 months ago

Hello! I currently do not have enough compute to pretrain the models, so no definite answer from me :(. But one thing I noticed is that the decays used in the paper may limit the usable context length. According to my estimates, the model can only attend to roughly the final ~3000 tokens; beyond that point the decay factor effectively becomes 0, meaning earlier tokens are simply ignored. The paper reports training with a sequence length of 8192, which might suggest my estimates are wrong, but I am not 100% convinced it would perform as well as Transformers on long sequences.

This is particularly important since the main advantages of having an RNN are (in theory) unbounded input sequence length and low inference cost, which make a drastic difference as the sequence gets longer.
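For reference, a minimal sketch of how such an estimate can be made, assuming the per-head decays from the paper, γ_h = 1 − 2^(−5−h), and an arbitrary cutoff below which a token's contribution is treated as ignored (the cutoff is an assumption for illustration, not something from the paper or this repo):

```python
import math

# Per-head decays as defined in the RetNet paper: gamma_h = 1 - 2^(-5 - h), h = 0..7.
# CUTOFF is an arbitrary assumption used only to define "effectively ignored".
NUM_HEADS = 8
CUTOFF = 1e-4

for h in range(NUM_HEADS):
    gamma = 1.0 - 2.0 ** (-5 - h)
    # Smallest n such that gamma**n < CUTOFF, i.e. n > log(CUTOFF) / log(gamma).
    lookback = math.log(CUTOFF) / math.log(gamma)
    print(f"head {h}: gamma = {gamma:.6f}, decay drops below {CUTOFF:g} after ~{lookback:,.0f} tokens")
```

The numbers depend heavily on where the cutoff is drawn, so this only gives a rough, per-head estimate.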

okpatil4u commented 11 months ago

Thanks @syncdoth. This is super useful. I am not worried about sequence length. As long as training is O(N) and recurrent inference is O(1), this model proves its mettle.

syncdoth commented 11 months ago

Cool! Yeah, hopefully if this model is trained well, it will have a great advantage at inference time. It would probably be more edge-device friendly too.

https://github.com/syncdoth/RetNet/blob/official_implementation/play.ipynb

I double-checked the implementation to confirm that RNN-style inference is indeed O(1).
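For anyone wondering why the per-token cost is constant: here is a minimal single-head sketch of the recurrent retention update, S_n = γ S_{n−1} + k_nᵀ v_n and o_n = q_n S_n, ignoring group norm, gating, and rotations. This is just an illustration, not the code from this repo:

```python
import torch

def recurrent_retention_step(q_n, k_n, v_n, prev_state, gamma):
    """One decoding step of single-head retention in recurrent form.

    q_n, k_n, v_n: (batch, head_dim); prev_state: (batch, head_dim, head_dim).
    The state has a fixed size, so each step costs the same no matter how many
    tokens came before -> O(1) per token.
    """
    # S_n = gamma * S_{n-1} + k_n^T v_n (outer product)
    state = gamma * prev_state + k_n.unsqueeze(-1) * v_n.unsqueeze(-2)
    # o_n = q_n S_n
    o_n = torch.einsum("bd,bde->be", q_n, state)
    return o_n, state

# Toy usage: four decoding steps, each with identical cost.
batch, head_dim = 2, 64
gamma = 1 - 2 ** -5
state = torch.zeros(batch, head_dim, head_dim)
for _ in range(4):
    q, k, v = (torch.randn(batch, head_dim) for _ in range(3))
    o, state = recurrent_retention_step(q, k, v, state, gamma)
```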

syncdoth commented 11 months ago

Actually, my bad on the above plot: the effective sequence length differs per head. The correct plot would be:

[plot: per-head decay magnitude vs. number of tokens back]

Theoretically, the 8th head can retain some information up to 2^18 = 262,144 tokens back, although the decay magnitude is so small (1.5913e-28) that it wouldn't carry much information. Still, it may be able to look back 2^15 = 32,768 tokens (comparable to chatgpt-32k), since the decay at that point is 3.3514e-04.
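A quick sanity check of those numbers, assuming the 8th head's decay is γ = 1 − 2^(−12) as in the paper:

```python
# Decay of the 8th head (h = 7): gamma = 1 - 2^(-5 - 7) = 1 - 2^-12
gamma = 1 - 2 ** -12
print(gamma ** 2 ** 18)  # ~1.59e-28: essentially nothing survives this far back
print(gamma ** 2 ** 15)  # ~3.35e-04: small but non-negligible
```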