Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

BF16 Flash Attention producing incorrect values compared to FP16 Flash Attention on A100 #1071

Open JerrickLiu opened 1 month ago

JerrickLiu commented 1 month ago

Repro

import flash_attn
import torch
from einops import rearrange

def snr(a: torch.Tensor, b: torch.Tensor) -> float:
    """Signal-to-noise ratio in dB of a relative to the reference b."""
    if torch.equal(a, b):
        return float("inf")

    if a.dtype == torch.bool and b.dtype == torch.bool:
        return 0.0  # unequal boolean tensors
    # Compute the metric in fp32 so it is not itself limited by fp16/bf16 precision
    diff = a.float() - b.float()
    l2 = torch.sum(diff ** 2) / torch.sum(b.float() ** 2)
    snr_val = 10 * torch.log10(1 / l2)
    return snr_val.item()

batch_size = 2
seq_len = 256
num_heads = 2
embed_dim = 64
dtype = torch.float16  # switch to torch.bfloat16 to reproduce the mismatch

Q = torch.randn(batch_size, num_heads, seq_len, embed_dim, dtype=dtype, device="cuda")
K = torch.randn(batch_size, num_heads, seq_len, embed_dim, dtype=dtype, device="cuda")
V = torch.randn(batch_size, num_heads, seq_len, embed_dim, dtype=dtype, device="cuda")

# Reference attention, computed entirely in `dtype`.
# A plain Python float, since flash_attn_func takes softmax_scale as a float.
scale = embed_dim ** -0.5
s = Q @ K.transpose(-2, -1) * scale
s_weights = torch.softmax(s, dim=-1)
out = s_weights @ V

# flash_attn_func expects q, k, v laid out as (batch, seqlen, nheads, headdim)
rearrange_pattern = "b n s d -> b s n d"
Q = rearrange(Q, rearrange_pattern)
K = rearrange(K, rearrange_pattern)
V = rearrange(V, rearrange_pattern)

out_flash = flash_attn.flash_attn_func(Q, K, V, softmax_scale=scale)
out_flash = rearrange(out_flash, "b s n d -> b n s d")

print("SNR: ", snr(out_flash, out))
print("Flash Attention matches normal: ", torch.allclose(out, out_flash, atol=1e-3))

With float16 this prints an infinite SNR, and out and out_flash are allclose:

SNR:  inf
Flash Attention matches normal:  True

However, with the dtype changed to torch.bfloat16, the SNR drops to about 46.5 dB and the tensors no longer match:

SNR:  46.5
Flash Attention matches normal:  False
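A rough way to read that number: snr() returns 10·log10 of the reference energy over the error energy, so it maps directly to a relative L2 error. The conversion below only rearranges that definition; comparing the result with bf16's unit roundoff (2^-8, half of its machine epsilon 2^-7) is a back-of-the-envelope check, not a formal error bound.

$$
\mathrm{SNR} = 10\log_{10}\frac{\sum_i b_i^2}{\sum_i (a_i - b_i)^2}
\;\Longrightarrow\;
\frac{\lVert a - b \rVert_2}{\lVert b \rVert_2} = 10^{-\mathrm{SNR}/20}
$$

$$
10^{-46.5/20} = 10^{-2.325} \approx 4.7 \times 10^{-3},
\qquad
u_{\mathrm{bf16}} = 2^{-8} \approx 3.9 \times 10^{-3}
$$

So a 46.5 dB SNR means the outputs differ by roughly one bf16 rounding step relative to their size, which on its own is consistent with bf16 rounding rather than evidence of a kernel bug.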

Running on an 80 GB A100 SXM4:

Thu Jul 18 14:02:40 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:07:00.0 Off |                    0 |
| N/A   32C    P0              63W / 400W |      2MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
tianleiwu commented 1 month ago

BF16 has lower precision than FP16 (8 significand bits versus 11). I normally use 0.001 as the absolute tolerance for float16 and 0.01 for bfloat16.

>>> import torch
>>> x = torch.rand(1000, 1000)
>>> ((x + 1e-6).to(torch.bfloat16) - x.to(torch.bfloat16)).abs().max()
tensor(0.0039, dtype=torch.bfloat16)
>>> ((x + 1e-6).to(torch.float16) - x.to(torch.float16)).abs().max()
tensor(0.0005, dtype=torch.float16)
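The precision gap above can be checked without a GPU. The sketch below is a stdlib-only emulation (an assumption for illustration, not how PyTorch implements these types): it rounds Python floats to bfloat16 by rounding the float32 bit pattern to its top 16 bits (round half to even), and to float16 via struct's "e" format. The helper names round_to_bf16, round_to_fp16 and machine_eps are hypothetical. It shows machine epsilon is 2^-7 for bf16 and 2^-10 for fp16, so neighbouring values in [0.5, 1) are 2^-8 ≈ 0.0039 and 2^-11 ≈ 0.00049 apart, matching the 0.0039 and 0.0005 in the REPL output.

```python
import struct


def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even).

    bfloat16 is the upper 16 bits of an IEEE-754 float32, so we round the
    float32 bit pattern to a multiple of 2**16. NaN/inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even on bit 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def round_to_fp16(x: float) -> float:
    """Round x to the nearest IEEE-754 float16 value via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def machine_eps(round_fn) -> float:
    """Gap between 1.0 and the next representable value in the emulated format."""
    eps = 1.0
    while round_fn(1.0 + eps / 2) != 1.0:
        eps /= 2
    return eps


eps_bf16 = machine_eps(round_to_bf16)  # 2**-7: 8 significand bits
eps_fp16 = machine_eps(round_to_fp16)  # 2**-10: 11 significand bits

# Neighbouring values in [0.5, 1) are eps / 2 apart; a 1e-6 nudge can flip
# a rounded value by at most one such step, as in the REPL experiment.
print(f"bf16: eps={eps_bf16:.6f}, spacing in [0.5, 1)={eps_bf16 / 2:.6f}")
print(f"fp16: eps={eps_fp16:.6f}, spacing in [0.5, 1)={eps_fp16 / 2:.6f}")
```

The suggested tolerances (0.01 for bf16, 0.001 for fp16) are roughly one machine epsilon of each format. An absolute tolerance does not scale with magnitude, though: for values between 2 and 4, a single bf16 rounding can move a value by up to 2^-7 ≈ 0.0078, so outputs built from several rounded steps can exceed atol=1e-2 without anything being wrong.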
JerrickLiu commented 1 month ago

Right, but I'm still seeing some mismatches even with atol=1e-2.

tridao commented 1 month ago

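Please compare (FlashAttention in bf16 - reference attention in fp32) against (reference attention in bf16 - reference attention in fp32), i.e. check whether FlashAttention's error relative to an fp32 reference is comparable to that of a plain bf16 implementation.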
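The comparison suggested above can be illustrated with a small CPU-only emulation. This is a sketch under stated assumptions, not FlashAttention itself: Python floats (float64) stand in for the fp32 reference; attention() with round_to_bf16 models an eager bf16 reference by accumulating each op in high precision and rounding only its output; and fused_style_attention() is a hypothetical candidate that keeps scores and softmax in high precision and rounds only the probabilities and the output. All function names here are made up for the example. The point is the measurement: each bf16 result is compared against the high-precision reference, and the candidate's error is judged against the bf16 reference's own error rather than by comparing the two bf16 results to each other with a tight atol.

```python
import math
import random
import struct


def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); NaN/inf are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def no_rounding(x: float) -> float:
    """Identity: Python floats (float64) act as the high-precision reference."""
    return x


def softmax(row):
    m = max(row)
    e = [math.exp(s - m) for s in row]
    total = sum(e)
    return [x / total for x in e]


def attention(q, k, v, rnd):
    """Naive single-head attention on lists of rows.

    Each op accumulates in float64 and rounds only its output with rnd,
    a simplified model of an eager reference run in a low-precision dtype.
    """
    scale = len(q[0]) ** -0.5
    out = []
    for qi in q:
        scores = [rnd(rnd(sum(a * b for a, b in zip(qi, kj))) * scale) for kj in k]
        p = [rnd(x) for x in softmax(scores)]
        out.append([rnd(sum(pj * vj[c] for pj, vj in zip(p, v))) for c in range(len(v[0]))])
    return out


def fused_style_attention(q, k, v):
    """Hypothetical candidate kernel: scores and softmax stay in float64;
    only the probabilities (before P @ V) and the output are rounded to bf16."""
    scale = len(q[0]) ** -0.5
    out = []
    for qi in q:
        p = softmax([sum(a * b for a, b in zip(qi, kj)) * scale for kj in k])
        p = [round_to_bf16(x) for x in p]
        out.append([round_to_bf16(sum(pj * vj[c] for pj, vj in zip(p, v))) for c in range(len(v[0]))])
    return out


def max_abs_err(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


def rand_matrix(rows, cols):
    # Inputs are rounded to bf16 up front so every variant sees identical values.
    return [[round_to_bf16(random.gauss(0.0, 1.0)) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
seq_len, head_dim = 32, 64
q, k, v = [rand_matrix(seq_len, head_dim) for _ in range(3)]

ref = attention(q, k, v, no_rounding)         # high-precision reference
ref_bf16 = attention(q, k, v, round_to_bf16)  # reference attention in bf16
candidate = fused_style_attention(q, k, v)    # stand-in for the kernel under test

baseline = max_abs_err(ref_bf16, ref)   # reference in bf16 - high-precision reference
cand_err = max_abs_err(candidate, ref)  # candidate in bf16 - high-precision reference
print(f"bf16 reference vs high precision: {baseline:.2e}")
print(f"candidate vs high precision:      {cand_err:.2e}")
print(f"candidate vs bf16 reference:      {max_abs_err(candidate, ref_bf16):.2e}")
```

With real tensors the same pattern would compute the reference once in fp32 from upcast inputs, then compare the error of out_flash.float() with the error of the bf16 reference output. The numbers this sketch prints depend on the seed and the simplified rounding model; they are not FlashAttention measurements.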