berlino / gated_linear_attention

MIT License

Tips for training from scratch? #8

Closed: luchris429 closed this issue 8 months ago

luchris429 commented 8 months ago

Hello,

I've been playing with this architecture on nanoGPT. While I can get other architectures to play nicely there (e.g. RMT), I'm really struggling to get GLA to perform well.

Do you have any tips or code for training? For example, do you have a repository you recommend or key hyperparameter differences to normal transformers?

Thanks!

sustcsonglin commented 8 months ago

Hi, what is the model size and what is your current perf gap?


luchris429 commented 8 months ago

~44M parameters. GLA is getting ~4 bpc (~2.7 nats), whereas RMT and Transformer can get ~1.4 bpc (~1.0 nats).
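
The losses above are quoted in two units. Bits and nats differ only in the base of the logarithm, so converting bits per character to nats per character means multiplying by ln 2 ≈ 0.693 (PyTorch's `F.cross_entropy` uses the natural log, so a logged training loss is in nats). A minimal sketch; the function names are illustrative:

```python
import math

LN2 = math.log(2)  # number of nats in one bit


def bpc_to_nats(bpc):
    """Bits per character -> nats per character (cross-entropy in base e)."""
    return bpc * LN2


def nats_to_bpc(nats):
    """Nats per character -> bits per character (cross-entropy in base 2)."""
    return nats / LN2


if __name__ == "__main__":
    for bpc in (4.0, 1.4):
        print(f"{bpc} bpc = {bpc_to_nats(bpc):.2f} nats")
```

For the figures above, 4 bpc is about 2.77 nats and 1.4 bpc about 0.97 nats, consistent with the rounded values quoted.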

I think the gap is large enough to suggest that something is clearly broken.

(I just launched another run with a nearly identical training setup; the training curves are already diverging significantly despite similar parameter counts.)

[Screenshot, 2024-03-04 5:08:30 PM]

luchris429 commented 8 months ago

Have you run into any sharp bits (common pitfalls) in installation or setup?

If it's useful to other people:

Small repo to test: here.

Code diff for adding GLA (excluding recent upstream updates to nanoGPT): here

sustcsonglin commented 8 months ago

What is your Triton version? Can you pass this test?
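
The linked test is not reproduced in this thread. As a hedged illustration of what this kind of correctness check does, the sketch below computes the same gated recurrence two ways (step by step, and in an unrolled "parallel" form) and checks that they agree; a fused kernel would be compared against a slow reference in the same spirit. It assumes the GLA recurrence S_t = diag(alpha_t) S_{t-1} + k_t^T v_t, o_t = q_t S_t for a single head, with no normalization or output gate; all function names are hypothetical.

```python
import random

# Illustrative check only (assumed single-head GLA recurrence, no norm, no output gate):
#   S_t = diag(alpha_t) S_{t-1} + k_t^T v_t,   o_t = q_t S_t


def recurrent_gla(q, k, v, alpha):
    """Step-by-step recurrence over a (d_k x d_v) state matrix."""
    d_k, d_v = len(q[0]), len(v[0])
    state = [[0.0] * d_v for _ in range(d_k)]
    outputs = []
    for t in range(len(q)):
        for i in range(d_k):
            for j in range(d_v):
                # decay row i of the state by alpha_t[i], then add the outer product k_t^T v_t
                state[i][j] = alpha[t][i] * state[i][j] + k[t][i] * v[t][j]
        outputs.append([sum(q[t][i] * state[i][j] for i in range(d_k)) for j in range(d_v)])
    return outputs


def parallel_gla(q, k, v, alpha):
    """Unrolled form: o_t = sum_{s<=t} ((q_t * b_t) . (k_s / b_s)) v_s with b_t = cumprod(alpha)."""
    d_k, d_v = len(q[0]), len(v[0])
    cumulative, running = [], [1.0] * d_k
    for a_t in alpha:
        running = [running[i] * a_t[i] for i in range(d_k)]  # new list each step, no aliasing
        cumulative.append(running)
    outputs = []
    for t in range(len(q)):
        row = [0.0] * d_v
        for s in range(t + 1):
            score = sum(q[t][i] * cumulative[t][i] / cumulative[s][i] * k[s][i] for i in range(d_k))
            for j in range(d_v):
                row[j] += score * v[s][j]
        outputs.append(row)
    return outputs


def max_abs_diff(a, b):
    return max(abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))


def random_matrix(rng, rows, cols, lo, hi):
    return [[rng.uniform(lo, hi) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    T, d_k, d_v = 6, 4, 3
    q, k = random_matrix(rng, T, d_k, -1, 1), random_matrix(rng, T, d_k, -1, 1)
    v = random_matrix(rng, T, d_v, -1, 1)
    alpha = random_matrix(rng, T, d_k, 0.8, 0.99)  # forget gates strictly inside (0, 1)
    diff = max_abs_diff(recurrent_gla(q, k, v, alpha), parallel_gla(q, k, v, alpha))
    print("max |recurrent - parallel| =", diff)
```

Dividing by cumulative gate products is only safe for short sequences like this toy one; practical kernels work in chunks, partly to avoid that instability.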

luchris429 commented 8 months ago

Ahh, that might be it! I'm on Triton 2.1.0.

[Screenshot, 2024-03-04 6:35:39 PM]

Will upgrade triton and get back to you.
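
Before rerunning, it can help to confirm which Triton version is actually installed in the training environment. A minimal sketch using only the standard library; the threshold is an assumption based solely on this thread (2.1.0 is the version in use when the mismatch appeared), not an officially documented minimum, and the helper names are hypothetical:

```python
from importlib import metadata

# Assumption: 2.1.0 is simply the version that misbehaved in this thread,
# not a documented minimum for the GLA kernels.
SUSPECT_UP_TO = (2, 1, 0)


def parse_version(text):
    """Parse 'MAJOR.MINOR.PATCH[suffix]' into an int tuple, ignoring suffixes like '+git...'."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (3 - len(parts))  # pad short versions such as "3" to (3, 0, 0)
    return tuple(parts)


def triton_version_is_suspect(version):
    """True if the version is at or below the one that misbehaved in this thread."""
    return parse_version(version) <= SUSPECT_UP_TO


if __name__ == "__main__":
    try:
        installed = metadata.version("triton")
    except metadata.PackageNotFoundError:
        print("triton is not installed")
    else:
        status = "consider upgrading" if triton_version_is_suspect(installed) else "ok"
        print(f"triton {installed}: {status}")
```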

luchris429 commented 8 months ago

Just upgraded. Is this diff more reasonable, or is it still too far off?

[Screenshot, 2024-03-04 6:39:38 PM]

sustcsonglin commented 8 months ago

Looks normal! I think your GLA training will be fine after the Triton update.

sustcsonglin commented 8 months ago

Also, for smaller models I would recommend the parameter allocation used in RetNet, i.e., d_key = d_model, d_value = 2*d_model, and an FFN expansion ratio of 2.
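
A quick weight count shows how this allocation compares with a standard Transformer block. This is a sketch under stated assumptions, not the repository's actual layer: the GLA block is assumed to have query, key, value, output-gate and output projections plus a plain two-matrix FFN, and biases, norms and the (low-rank) forget-gate projection are ignored. Function names are hypothetical.

```python
def gla_block_params(d_model, key_ratio=1, value_ratio=2, ffn_expansion=2):
    """Weight count of one GLA block under the RetNet-style allocation (assumed layout)."""
    d_key = key_ratio * d_model
    d_value = value_ratio * d_model
    qk = 2 * d_model * d_key                        # query and key projections
    v_gate = 2 * d_model * d_value                  # value projection and output gate (assumed)
    out = d_value * d_model                         # output projection back to d_model
    ffn = 2 * d_model * (ffn_expansion * d_model)   # up- and down-projection
    return qk + v_gate + out + ffn


def transformer_block_params(d_model, ffn_expansion=4):
    """Weight count of a standard Transformer block: Q, K, V, O plus a 4x MLP."""
    attention = 4 * d_model * d_model
    ffn = 2 * d_model * (ffn_expansion * d_model)
    return attention + ffn


if __name__ == "__main__":
    d = 512
    print(gla_block_params(d), transformer_block_params(d))
```

Under these assumptions both blocks come to 12 * d_model^2 weights (3,145,728 at d_model = 512): the larger value dimension and output gate are paid for by the smaller FFN.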

yzhangcs commented 8 months ago

@luchris429 Hi, please refer to this fix commit (https://github.com/sustcsonglin/flash-linear-attention/commit/84a3940ab99bccdf1e64e040b78f65d110516365). Sorry for wasting your time.

luchris429 commented 8 months ago

It's doing better than the Transformer now! Thanks so much, @yzhangcs and @sustcsonglin! Great work.

[Screenshot, 2024-03-04 8:16:27 PM]