Hi, what is the model size and what is your current perf gap?
~44M parameters. GLA is getting ~4 bpc (~2.7 nats), whereas RMT and Transformer can get ~1.4 bpc (~1.0 nats).
I think the performance gap is big enough to suggest that something is clearly not working, though.
(I just launched another run with minimal differences in training setup, and the training curves are already diverging significantly despite the similar parameter count.)
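(Side note on units: those figures are consistent, since nats per character = bpc × ln 2. A quick check:)

```python
import math

# bits-per-character -> nats-per-character: multiply by ln(2)
for name, bpc in [("GLA (current run)", 4.0), ("RMT / Transformer", 1.4)]:
    print(f"{name}: {bpc:.1f} bpc = {bpc * math.log(2):.2f} nats")
# GLA (current run): 4.0 bpc = 2.77 nats
# RMT / Transformer: 1.4 bpc = 0.97 nats
```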
What is your Triton version? Can you pass this test?
Ahh that might be it! I'm on Triton 2.1.0
Will upgrade Triton and get back to you.
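(For reference, a quick way to confirm which Triton build is actually being picked up, assuming a standard pip environment:)

```python
import triton

# The exchange above suggests 2.1.0 is too old for the GLA kernels; upgrade if you see it here.
print(triton.__version__)
```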
Just upgraded. Is this diff more reasonable, or still too far off?
Looks normal! I think your GLA training will be good after the Triton version update.
Also, for smaller models I would recommend the parameter allocation used in RetNet, i.e. d_key = d_model, d_value = 2 * d_model, and FFN expansion = 2.
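(To make that allocation concrete, here is a minimal sketch; the field names are illustrative, not the actual fla config API:)

```python
# RetNet-style parameter allocation for a small GLA model (names are hypothetical).
d_model = 512  # backbone hidden size; pick whatever matches your parameter budget

gla_block_cfg = dict(
    d_key=d_model,        # query/key projection width = d_model
    d_value=2 * d_model,  # value projection width     = 2 * d_model
    ffn_expansion=2,      # FFN hidden size = 2 * d_model (instead of the usual 4x)
)
```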
@luchris429 Hi, please refer to this fix commit (https://github.com/sustcsonglin/flash-linear-attention/commit/84a3940ab99bccdf1e64e040b78f65d110516365). Sorry for wasting your time.
It's doing better than Transformer now! Thanks so much @yzhangcs @sustcsonglin ! Great work.
Hello,
I've been playing with this architecture on nanoGPT. While I can get other architectures to play nicely there (e.g. RMT), I'm really struggling to get GLA to perform well.
Do you have any tips or reference code for training? For example, is there a repository you recommend, or are there key hyperparameter differences from standard Transformers?
Thanks!