idiap / sigma-gpt

σ-GPT: A New Approach to Autoregressive Models

Can we have flash attention? #2

Open chris-aeviator opened 2 months ago

chris-aeviator commented 2 months ago

Unlike nanoGPT, the sigma-GPT codebase uses the "unoptimized" manual attention implementation, which cannot be fused by torch.compile.

The original nanoGPT code replaces the manual attention computation with a call to scaled_dot_product_attention, which can in turn dispatch to a flash attention kernel.
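For reference, here is a minimal sketch of the two paths being compared, written in the spirit of nanoGPT's two branches but not copied from either repo: an explicit masked-softmax attention, and the same computation through `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True`. The function names and tensor sizes are illustrative assumptions, and dropout is omitted. During training (full, square sequences) the two should agree up to floating-point error.

```python
import math

import torch
import torch.nn.functional as F


def manual_causal_attention(q, k, v):
    """Explicit causal attention (illustrative, not the repo's code)."""
    seq_len = q.size(-2)
    # (B, H, T, T) attention scores scaled by 1/sqrt(head_dim)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    # Lower-triangular mask: position i may only attend to positions <= i
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    att = att.masked_fill(~mask, float("-inf"))
    att = F.softmax(att, dim=-1)
    return att @ v


def fused_causal_attention(q, k, v):
    """Same computation; PyTorch may select a fused (e.g. flash) kernel here."""
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


torch.manual_seed(0)
B, H, T, D = 2, 4, 8, 16  # batch, heads, sequence length, head dim (arbitrary)
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

out_manual = manual_causal_attention(q, k, v)
out_fused = fused_causal_attention(q, k, v)
print(torch.allclose(out_manual, out_fused, atol=1e-5))
```

Which backend actually runs (flash, memory-efficient, or plain math) depends on the device, dtype, and PyTorch version; the numerical result is the same either way.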

Is there a practical reason for not using it here?

Using flash attention would reduce memory requirements and potentially speed up training.

ArnaudPannatier commented 2 months ago

You can easily turn it back on by copying lines 62-71 of model.py into lines 245-250 of sigmagpt.py; it should work. I turned it off because, if I remember correctly, there was some interaction with the KV cache during generation that I didn't have time to debug.

I won't put it back in the repo until I've checked that everything works correctly. I might have time for this next week. I'll keep you posted. If you have time to look into this, let me know whether it works on your side with flash attention enabled.

ArnaudPannatier commented 2 months ago

I've done a bit of refactoring and testing (see: https://github.com/idiap/sigma-gpt/blob/main/text/tests/test_flash_attention.py). The problem is that during generation, scaled_dot_product_attention with is_causal=True uses a different mask from the one the KV cache needs in order to work.
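A small illustration of the mismatch, using general PyTorch behavior rather than the linked test: for a non-square problem (fewer queries than keys), PyTorch documents `is_causal=True` as an upper-left aligned causal mask. With a KV cache, the newly generated token is a single query against all cached keys, so the upper-left mask lets it see only the first cached position instead of all of them. The cache tensors below are hypothetical stand-ins.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 1, 2, 6, 8  # S = number of positions in a hypothetical KV cache
k_cache = torch.randn(B, H, S, D)
v_cache = torch.randn(B, H, S, D)
q_new = torch.randn(B, H, 1, D)  # query for the single newly generated token

# is_causal=True on a (1 x S) problem uses an upper-left aligned mask,
# so the new token only attends to the FIRST cached position.
out_causal = F.scaled_dot_product_attention(q_new, k_cache, v_cache, is_causal=True)

# What cached decoding needs: the newest token attends to every cached position.
out_full = F.scaled_dot_product_attention(q_new, k_cache, v_cache, is_causal=False)

# Softmax over a single unmasked score is 1, so the output is just v at position 0.
print(torch.allclose(out_causal, v_cache[:, :, :1, :], atol=1e-6))
print(torch.allclose(out_causal, out_full))
```

During training the query and key lengths are equal, so the upper-left and lower-right alignments coincide, which is consistent with the training path being unaffected.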

That being said, I think scaled_dot_product_attention can be used safely during training. The attn_mask argument could also be used to solve the KV-cache issue, but it would require some benchmarking to know what speedup it gives and which kernel it uses.
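One way the attn_mask route could look, as a sketch under assumptions (the helper names are hypothetical, not from the repo): build a boolean causal mask aligned to the bottom-right corner, so the i-th of `new` queries attends to every key up to its own absolute position `past + i`, and pass it as `attn_mask`. This does not account for anything specific to sigma-GPT's generation order. Note that, to my understanding, PyTorch's flash backend does not accept an arbitrary `attn_mask`, so this typically falls back to the memory-efficient or math kernel; that is exactly why benchmarking matters here.

```python
import math

import torch
import torch.nn.functional as F


def cached_causal_mask(num_queries, num_keys, device=None):
    """Hypothetical helper: boolean mask (True = may attend), bottom-right aligned.

    Query i sits at absolute position num_keys - num_queries + i, so it may
    attend to every key up to and including that position.
    """
    offset = num_keys - num_queries
    ones = torch.ones(num_queries, num_keys, dtype=torch.bool, device=device)
    return ones.tril(diagonal=offset)


def reference_attention(q, k, v, mask):
    """Plain softmax attention with an explicit boolean mask, for comparison."""
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    att = att.masked_fill(~mask, float("-inf"))
    return F.softmax(att, dim=-1) @ v


torch.manual_seed(0)
B, H, D = 1, 2, 8
past, new = 5, 3  # 5 positions already cached, 3 new tokens processed at once
k = torch.randn(B, H, past + new, D)
v = torch.randn(B, H, past + new, D)
q = torch.randn(B, H, new, D)

mask = cached_causal_mask(new, past + new)  # shape (3, 8), broadcast over B and H
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_ref = reference_attention(q, k, v, mask)
print(torch.allclose(out_sdpa, out_ref, atol=1e-5))
```

For a single new token the mask is all True, which is equivalent to calling with `is_causal=False` and no mask. Recent PyTorch versions also provide `torch.nn.attention.bias.causal_lower_right`, which expresses the same alignment and may let PyTorch choose a fused kernel; whether it is faster in practice would need the benchmarking mentioned above.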

I've marked it as an enhancement. I'll get to it when I have time, and I'll happily accept pull requests if someone wants to do it.