Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Flash attn 3 has large numerical mismatches with torch sdpa #1128

Open Fuzzkatt opened 4 months ago

Fuzzkatt commented 4 months ago

I'm seeing large mismatches between the results of fa3's flash_attn_func and torch.nn.functional.scaled_dot_product_attention(). Given the following minimal repro:

import pytest
import torch

# flash attn 3
try:
    from flash_attn_interface import flash_attn_func

    HAS_FA3 = True
except ImportError:
    HAS_FA3 = False

def test_fa3():
    if not HAS_FA3:
        pytest.skip("fa3 not built")

    batch = 4
    seq_len = 128
    num_heads = 6
    dim_per_head = 64
    device = "cuda"
    dtype = torch.float16

    q = torch.randn([batch, seq_len, num_heads, dim_per_head], device=device, dtype=dtype, requires_grad=True)
    k = torch.randn([batch, seq_len, num_heads, dim_per_head], device=device, dtype=dtype, requires_grad=True)
    v = torch.randn([batch, seq_len, num_heads, dim_per_head], device=device, dtype=dtype, requires_grad=True)

    def fn(query, key, value):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)

    # Verifies the result is close to PyTorch
    out = flash_attn_func(q, k, v)[0]
    out_ref = fn(q, k, v)
    torch.testing.assert_close(out, out_ref)

I'm seeing

>       torch.testing.assert_close(out, out_ref)
E       AssertionError: Tensor-likes are not close!
E       
E       Mismatched elements: 196577 / 196608 (100.0%)
E       Greatest absolute difference: 3.39453125 at index (1, 28, 5, 49) (up to 1e-05 allowed)
E       Greatest relative difference: 23600.0 at index (2, 100, 3, 37) (up to 0.001 allowed)

accuracy_test.py:34: AssertionError

Just wanted to confirm whether this is expected, and if so, why fa3 is expected to be so numerically different from torch's native sdpa.

tridao commented 4 months ago

FA uses (batch, seqlen, nheads, headdim). Torch sdpa expects (batch, nheads, seqlen, headdim).
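
A minimal sketch of the comparison with the layouts reconciled, assuming the q, k, v tensors and flash_attn_func interface from the repro above (the loosened tolerances are illustrative only):

# Sketch, reusing the q, k, v tensors from the repro above.
# FA3 consumes (batch, seqlen, nheads, headdim); sdpa consumes (batch, nheads, seqlen, headdim).
out_fa3 = flash_attn_func(q, k, v)[0]                                # (B, S, H, D)
out_sdpa = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)         # -> (B, H, S, D)
).transpose(1, 2)                                                    # back to (B, S, H, D)
torch.testing.assert_close(out_fa3, out_sdpa, atol=1e-3, rtol=1e-3)  # illustrative tolerances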

Fuzzkatt commented 4 months ago

Thanks for the clarification! I rewrote the minimal repro to use the correct dimensions, but I'm still seeing some fairly minor mismatches:

E       AssertionError: Tensor-likes are not close!
E       
E       Mismatched elements: 6 / 196608 (0.0%)
E       Greatest absolute difference: 5.340576171875e-05 at index (3, 99, 4, 61) (up to 1e-05 allowed)
E       Greatest relative difference: 0.015869140625 at index (3, 113, 4, 30) (up to 0.001 allowed)

Do you have any thoughts on what the expected mismatch between torch sdpa and fa3 is? Just so I can adjust the allclose tolerances to the "recommended" level in my unit testing.

tridao commented 4 months ago

sdpa is probably just running FA2 :D
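
For what it's worth, one way to check which kernel sdpa is dispatching to: recent PyTorch versions (roughly 2.3+) expose torch.nn.attention.sdpa_kernel for pinning the backend. A sketch, with tensors created directly in sdpa's (batch, nheads, seqlen, headdim) layout:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Tensors in sdpa's (batch, nheads, seqlen, headdim) layout.
q, k, v = (torch.randn(4, 6, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Pin sdpa to a specific backend to see which kernel FA3 is actually being compared against.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = torch.nn.functional.scaled_dot_product_attention(q, k, v)

with sdpa_kernel(SDPBackend.MATH):
    out_math = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print((out_flash - out_math).abs().max().item())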

tridao commented 4 months ago

As always, you want to check against a reference implementation: (flashattention in bf16 - reference impl in fp32) vs (reference impl in bf16 - reference impl in fp32).
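
A sketch of that check under the shapes from the repro, assuming the flash_attn_interface import from above; ref_attn is a hypothetical plain-softmax reference run once in fp32 and once in bf16:

import torch
from flash_attn_interface import flash_attn_func

torch.manual_seed(0)
B, S, H, D = 4, 128, 6, 64
q, k, v = (torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))

def ref_attn(q, k, v):
    # Plain attention in (batch, seqlen, nheads, headdim) layout.
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))   # -> (B, H, S, D)
    scores = (qt @ kt.transpose(-2, -1)) / (D ** 0.5)
    return (scores.softmax(dim=-1) @ vt).transpose(1, 2)  # -> (B, S, H, D)

out_ref_fp32 = ref_attn(q.float(), k.float(), v.float())
out_ref_bf16 = ref_attn(q, k, v).float()
out_fa3 = flash_attn_func(q, k, v)[0].float()

# FA3's error against the fp32 reference should be comparable to (or smaller than)
# the bf16 reference's own error against the fp32 reference.
print("fa3  vs fp32 ref:", (out_fa3 - out_ref_fp32).abs().max().item())
print("bf16 vs fp32 ref:", (out_ref_bf16 - out_ref_fp32).abs().max().item())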

Hijdk commented 3 months ago

> As always, you want to check against a reference implementation: (flashattention in bf16 - reference impl in fp32) vs (reference impl in bf16 - reference impl in fp32).

@tridao I found that the activations of fa2 and fa3 are bitwise identical for the first 64 tokens, but the activations of tokens beyond the first 64 are not.

tridao commented 3 months ago

There's no guarantee of bitwise identical results from two different implementations, since floating-point math is not associative:

In [1]: import torch

In [2]: a = torch.randn(10, dtype=torch.bfloat16, device='cuda')

In [3]: (a + 0.3 - 0.3 - a).abs().max().item()
Out[3]: 0.00390625

Hijdk commented 3 months ago

> There's no guarantee of bitwise identical results from two different implementations, since floating-point math is not associative:
>
> In [1]: import torch
>
> In [2]: a = torch.randn(10, dtype=torch.bfloat16, device='cuda')
>
> In [3]: (a + 0.3 - 0.3 - a).abs().max().item()
> Out[3]: 0.00390625

@tridao However, the attention activations for the first 64 tokens are bitwise identical between the two implementations.

nighting0le01 commented 3 weeks ago

Any resolution on this, @Hijdk @tridao? I see numerical mismatches and broken inference results when I swap sdpa out for FA-3.