zhuzilin / ring-flash-attention

Ring attention implementation with flash attention
MIT License

Numerical errors in backward #42

Open grimulkan opened 4 months ago

grimulkan commented 4 months ago

Were you able to find out the reason for the small numerical errors in the backward pass with ring flash attention?

I found that the errors grow as the world size increases, so it does seem to be related to flash attention returning 16-bit tensors; even though we accumulate in a 32-bit buffer, that apparently isn't enough.

Maybe it would be an easy PR to flash attention to have it return raw fp32 outputs, or to do the accumulation upstream?
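
For illustration, here is a minimal sketch (illustrative names only, not this repo's code) of why that happens: each block's output comes back from flash attention already rounded to bf16, so even though the running accumulation is kept in fp32, the per-block rounding error compounds with every merge as the world size grows.

```python
import torch

# Emulate the online-softmax merge that ring attention performs across ranks:
# partial outputs are rescaled by their logsumexp (lse) and summed.
def merge(out_acc, lse_acc, out_blk, lse_blk):
    lse_new = torch.logaddexp(lse_acc, lse_blk)
    out_new = (
        torch.exp(lse_acc - lse_new)[..., None] * out_acc
        + torch.exp(lse_blk - lse_new)[..., None] * out_blk
    )
    return out_new, lse_new

torch.manual_seed(0)
world_size, seqlen, dim = 8, 128, 64
outs = [torch.randn(seqlen, dim) for _ in range(world_size)]
lses = [torch.randn(seqlen) for _ in range(world_size)]

# Reference: keep every partial output in fp32 throughout.
out_ref, lse_ref = outs[0], lses[0]
for o, l in zip(outs[1:], lses[1:]):
    out_ref, lse_ref = merge(out_ref, lse_ref, o, l)

# What actually happens: each block's output is bf16 (as flash attention
# returns it) before being accumulated into the fp32 buffer.
out_acc, lse_acc = outs[0].bfloat16().float(), lses[0]
for o, l in zip(outs[1:], lses[1:]):
    out_acc, lse_acc = merge(out_acc, lse_acc, o.bfloat16().float(), l)

print((out_acc - out_ref).abs().max())  # gap widens as world_size grows
```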

takfate commented 1 month ago

Do you have any idea how to implement it?

zhuzilin commented 1 month ago

tbh, not really... as far as I can tell, it won't be a simple PR with just a few changes to flash attention.

grimulkan commented 1 month ago

With Llama 405B there are many layers, and at ring sizes of 4 or 8 the numerical errors become catastrophic in the backward pass. The errors actually originate in the forward pass (in out and lse), but it is the backward grads that can blow up, which is deceptive because the forward loss value can still look quite reasonable.

A quick workaround is to use the context parallel (llama3) implementation in this repo for the forward pass alone. It is much more numerically stable, and if you have NVLink it is quite communication-efficient (if you don't, there's a penalty, but you could try overlapping compute and comm over the head strides).

The downsides of this repo's llama3 implementation are that it doesn't support zigzag and doesn't have an explicit non-varlen interface, but both are actually easy to address.

The memory overhead is not that bad: you no longer need the 32-bit buffers used in the ring forward, and you could try to minimize the overhead with head stride = 1.

The backward pass can remain the usual zigzag ring implementation. This combination let me scale the ring size even with very large models without huge numerical errors.
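
A rough sketch of that mix, assuming you can reach into the two implementations (`llama3_forward` and `zigzag_ring_backward` below are hypothetical placeholders, not this repo's public API), is a custom autograd function that pairs the all-gather forward with the zigzag ring backward:

```python
import torch

class MixedRingAttn(torch.autograd.Function):
    """Hypothetical sketch: llama3-style (all-gather KV) forward for numerical
    stability, zigzag ring backward for memory/communication balance."""

    @staticmethod
    def forward(ctx, q, k, v, process_group):
        # All-gather K/V (optionally per head stride) and run flash attention
        # once on this rank's query shard -- no repeated out/lse rescaling.
        out, lse = llama3_forward(q, k, v, group=process_group)  # hypothetical
        ctx.save_for_backward(q, k, v, out, lse)
        ctx.group = process_group
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v, out, lse = ctx.saved_tensors
        # Usual zigzag ring backward: ring-pass K/V blocks and accumulate
        # dq/dk/dv in fp32 buffers.
        dq, dk, dv = zigzag_ring_backward(  # hypothetical
            dout, q, k, v, out, lse, group=ctx.group
        )
        return dq, dk, dv, None
```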

An alternative solution would be to replace the forward pass with this Triton version, which supports rings and basically does what this issue asks for (no external out/lse accumulation): https://github.com/lucidrains/ring-attention-pytorch/blob/main/ring_attention_pytorch/triton_flash_attn.py. However, some modifications would be needed, it only supports striped attention (which is a bit worse than this repo's zigzag implementation), and I am not sure it actually takes advantage of GQA (it seems to just replicate KV heads in VRAM).

zhuzilin commented 4 weeks ago

Thank you for the suggestion, I'll try to implement a llama3-style forward with zigzag when I have a moment :).