Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Flash attention gives different results than reference attention #145

Closed. dvruette closed this issue 1 year ago.

dvruette commented 1 year ago

Comparing with the reference self-attention implementation from the flash_attn module, I find that flash attention gives significantly different results:

import torch
from flash_attn.modules.mha import FlashSelfAttention, SelfAttention

flash_attn_module = FlashSelfAttention(causal=True)
self_attn_module = SelfAttention(causal=True)

batch_size = 4
n_heads = 8
seq_len = 10
head_size = 64

dtype = torch.float16
device = torch.device("cuda")

# Packed QKV tensor of shape (batch_size, seq_len, 3, n_heads, head_size)
qkv = torch.randn(batch_size, seq_len, 3, n_heads, head_size, dtype=dtype, device=device)

out1 = flash_attn_module(qkv)
out2 = self_attn_module(qkv)

print((out1 - out2).abs().sum())
# >>> tensor(2.1094, device='cuda:0', dtype=torch.float16)

This is rather unexpected and can lead to accumulated error, giving vastly different results for a deep model depending on whether flash attention is used or not. Is this expected? Am I missing something?

PyTorch version: 1.13.1
Flash attention version: 0.2.8
CUDA version: 11.7

tridao commented 1 year ago

The right comparison is: (FlashAttention in fp16 - Standard attention in fp32) vs (Standard attention in fp16 - Standard attention in fp32).

These two errors are usually comparable.
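The comparison above can be illustrated on CPU with a NumPy sketch: a float64 reference, a "naive" attention where every intermediate is rounded to float16, and a FlashAttention-style tiled attention with an online softmax. It assumes, as a simplification, that the tiled version keeps its running max, normalizer, and output accumulator in float32 and rounds to float16 only once at the end; the function names and the block size are hypothetical, causal masking is omitted, and this is not the actual CUDA kernel.

```python
import numpy as np


def reference_attention(q, k, v):
    """Standard (non-causal) softmax attention computed in float64 as ground truth."""
    q, k, v = (x.astype(np.float64) for x in (q, k, v))
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v


def naive_fp16_attention(q, k, v):
    """Every intermediate (scores, softmax, output) is rounded to float16."""
    scale = np.float16(1.0 / np.sqrt(q.shape[-1]))
    scores = (q @ k.T) * scale
    scores = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p = p / p.sum(axis=-1, keepdims=True)
    return (p @ v).astype(np.float16)


def tiled_online_softmax_attention(q, k, v, block_size=16):
    """Stream over key/value blocks with an online softmax, keeping running
    statistics and the accumulator in float32 (a simplifying assumption)."""
    n_q, d = q.shape
    scale = np.float32(1.0 / np.sqrt(d))
    q32 = q.astype(np.float32)
    row_max = np.full((n_q, 1), -np.inf, dtype=np.float32)  # running max per query
    row_sum = np.zeros((n_q, 1), dtype=np.float32)          # running softmax normalizer
    acc = np.zeros((n_q, v.shape[1]), dtype=np.float32)     # unnormalized output
    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size].astype(np.float32)
        v_blk = v[start:start + block_size].astype(np.float32)
        s = (q32 @ k_blk.T) * scale
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)  # rescale what was accumulated so far
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v_blk
        row_max = new_max
    return (acc / row_sum).astype(np.float16)  # round to float16 only once, at the end


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 64)).astype(np.float16) for _ in range(3))
ref = reference_attention(q, k, v)
err_naive = np.abs(naive_fp16_attention(q, k, v) - ref).mean()
err_tiled = np.abs(tiled_online_softmax_attention(q, k, v) - ref).mean()
print(f"naive fp16: {err_naive:.2e}  tiled fp32-accumulate: {err_tiled:.2e}")
```

The two float16 outputs will generally not be bit-identical, which is the kind of difference reported in the original post; measured against the high-precision reference, both errors stay small.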

Here I'm printing the mean absolute error instead of the sum (i.e., the average error per element):

out1 = flash_attn_module(qkv)
out2 = self_attn_module(qkv)
out_ref = self_attn_module(qkv.float())

print((out1 - out_ref).abs().mean())
# tensor(8.4682e-05, device='cuda:0')
print((out2 - out_ref).abs().mean())
# tensor(0.0001, device='cuda:0')

In this particular case FlashAttention in fp16 is more accurate than the standard implementation in fp16.
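For scale, the summed difference of 2.1094 in the original post can be converted to a per-element figure, assuming the attention output has shape (batch_size, seq_len, n_heads, head_size). This is just arithmetic on the numbers already posted (the fp16 sum itself is only approximate):

```python
batch_size, seq_len, n_heads, head_size = 4, 10, 8, 64
n_elements = batch_size * seq_len * n_heads * head_size  # 20480 output elements
summed_abs_diff = 2.1094  # value printed in the original post
per_element = summed_abs_diff / n_elements
print(per_element)  # roughly 1.03e-4 per element
```

That is on the order of 1e-4 per element, the same order as the per-element errors above.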

dvruette commented 1 year ago

Ah yes, this makes a lot of sense. My problem is that the error accumulates in a large model (GPTNeoX-20B) that was trained with a "less precise" attention implementation: depending on whether flash attention is used or not, the logits differ by relative errors of > 1 and absolute errors of > 0.5. Whether this is a good or a bad thing is an open question, but in my experiments the model performs worse (in terms of fine-tuning loss) with flash attention than with the regular attention implementation. Any insight on how to remedy this?
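One way to quantify this kind of drift is to compare the logits from the two attention implementations element-wise. A minimal NumPy sketch; the function name, the `eps` guard against division by near-zero logits, and the top-1 agreement metric are illustrative assumptions, not part of any library API:

```python
import numpy as np


def logit_divergence(logits_a, logits_b, eps=1e-6):
    """Return (max absolute error, max relative error, top-1 agreement rate)
    between two logit arrays of shape (..., vocab_size)."""
    a = np.asarray(logits_a, dtype=np.float64)
    b = np.asarray(logits_b, dtype=np.float64)
    abs_err = np.abs(a - b)
    rel_err = abs_err / np.maximum(np.abs(b), eps)  # relative to logits_b
    top1_agree = np.mean(a.argmax(axis=-1) == b.argmax(axis=-1))
    return abs_err.max(), rel_err.max(), top1_agree


a = np.array([[2.0, 0.5, -1.0]])
b = np.array([[1.5, 0.6, -1.0]])
result = logit_divergence(a, b)
print(result)  # max abs 0.5, max rel 1/3, top-1 agreement 1.0
```

Relative error can be very large for logits close to zero, so it is worth reporting alongside the absolute error.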

tridao commented 1 year ago

One sanity check: if you finetune with attention in fp32, does that also lead to worse results? In other words, must the attention implementation used during finetuning match the one used during pretraining?

If that's the case, then you should use standard attention in fp16 during finetuning. I don't think there's a way to make FlashAttention "less precise" so that it exactly matches standard attention in fp16. Floating-point operations are not associative, so any change in the order of operations will result in (small) numerical differences.
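A small illustration of non-associativity, first with Python floats (IEEE 754 float64) and then with NumPy's float16, where the coarser precision makes the effect easy to see:

```python
import numpy as np

# Python floats are IEEE 754 float64.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c   # 0.6000000000000001
right = a + (b + c)  # 0.6
print(left == right)  # False

# float16 has 10 explicit mantissa bits, so representable values near 1024 are 1.0 apart.
x, y = np.float16(1024), np.float16(0.5)
h_left = (x + y) + y   # 1024.5 rounds (ties to even) back to 1024 at each step
h_right = x + (y + y)  # 0.5 + 0.5 = 1.0 exactly, then 1024 + 1 = 1025
print(h_left, h_right)  # 1024.0 1025.0
```

Reordering the reductions inside an attention kernel (e.g. tiling the softmax over blocks) changes results for the same reason.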

dvruette commented 1 year ago

That's a good point. I will see if I can manage to fit an fp32 run for comparison.

One thing worth noting is that the xformers memory-efficient implementation seems to yield exactly the same result as flash attention (0.0 absolute error), so if anything, it seems like the GPTNeoX attention implementation is the non-standard one.

tridao commented 1 year ago

One thing worth noting is that xformers memory efficient implementation seems to yield the exact same result as flash attention (0.0 absolute error).

That's because xformers runs the exact same code here. They compile FlashAttention as a submodule and dispatch to this implementation by default.

dvruette commented 1 year ago

It looks like fp32 also leads to bad results (even worse than flash attention, for some reason), so GPTNeoX-20B simply doesn't seem to be compatible with flash attention.

tridao commented 1 year ago

I don't have first-hand experience here, though I would have expected softmax in fp32 (and FlashAttention) to work. If even softmax in fp32 doesn't work, then it sounds like you should stick to standard softmax in fp16. I'm closing the issue for now.
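To make the "softmax in fp32" vs "softmax in fp16" distinction concrete, here is a NumPy sketch of causal attention on float16 inputs with a switch for the softmax precision. The function and the `softmax_fp32` flag are hypothetical; this is not the GPT-NeoX or flash_attn code, just one common way of upcasting the scores for the softmax and casting the probabilities back before the product with V:

```python
import numpy as np


def causal_attention(q, k, v, softmax_fp32=True):
    """Causal softmax attention on float16 inputs of shape (seq_len, head_dim).

    softmax_fp32=True upcasts the scores to float32 for the softmax;
    softmax_fp32=False keeps the softmax in float16.
    """
    seq_len, head_dim = q.shape
    scores = (q @ k.T) * np.float16(1.0 / np.sqrt(head_dim))  # float16 scores
    # Mask out future positions (strict upper triangle) before the softmax.
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(mask, np.float16(-np.inf), scores)
    if softmax_fp32:
        scores = scores.astype(np.float32)
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs.astype(v.dtype) @ v  # back to float16 for the PV product


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 64)).astype(np.float16) for _ in range(3))
out32 = causal_attention(q, k, v, softmax_fp32=True)
out16 = causal_attention(q, k, v, softmax_fp32=False)
print(np.abs(out32.astype(np.float32) - out16.astype(np.float32)).mean())
```

Because of the causal mask, the first query attends only to the first key, so row 0 of either output is exactly `v[0]`; later rows differ slightly between the two settings.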