Fuzzkatt opened this issue 4 months ago
FA uses (batch, seqlen, nheads, headdim). Torch sdpa expects (batch, nheads, seqlen, headdim).
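For reference, here's a rough sketch of the layout handling when comparing the two (shapes here are made up; this assumes the FA2-style `flash_attn_func` that returns a single tensor — the fa3 interface may also return the LSE):

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # FA2-style API; for fa3, swap in its interface

torch.manual_seed(0)
# flash-attn layout: (batch, seqlen, nheads, headdim)
q, k, v = [torch.randn(4, 128, 8, 64, dtype=torch.bfloat16, device="cuda") for _ in range(3)]

out_fa = flash_attn_func(q, k, v)  # (batch, seqlen, nheads, headdim)

# torch sdpa layout: (batch, nheads, seqlen, headdim) -> swap the seqlen and nheads dims
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)  # back to (batch, seqlen, nheads, headdim) for comparison

print((out_fa.float() - out_sdpa.float()).abs().max().item())
```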
Thanks for the clarification! I rewrote the minimal repro to use the correct dimensions, but I'm still seeing some fairly minor mismatches:
```
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 6 / 196608 (0.0%)
E Greatest absolute difference: 5.340576171875e-05 at index (3, 99, 4, 61) (up to 1e-05 allowed)
E Greatest relative difference: 0.015869140625 at index (3, 113, 4, 30) (up to 0.001 allowed)
```
Do you have any thoughts on what the expected mismatch between torch sdpa and fa3 is? Just so I can adjust the allclose tolerances to the "recommended" level in my unit testing.
sdpa is probably just running FA2 :D
As always, you want to check against a reference implementation: (flashattention in bf16 - reference impl in fp32) vs (reference impl in bf16 - reference impl in fp32).
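A minimal sketch of that check (`attention_ref` here is just a naive softmax(QKᵀ/√d)·V reference; the names, shapes, and FA2-style import are illustrative):

```python
import torch
from flash_attn import flash_attn_func  # FA2-style API; adjust the import for fa3

def attention_ref(q, k, v):
    # naive attention on (batch, seqlen, nheads, headdim) inputs
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))      # -> (batch, nheads, seqlen, headdim)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return (scores.softmax(dim=-1) @ v).transpose(1, 2)   # -> (batch, seqlen, nheads, headdim)

torch.manual_seed(0)
q, k, v = [torch.randn(4, 128, 8, 64, dtype=torch.bfloat16, device="cuda") for _ in range(3)]

out_ref_fp32 = attention_ref(q.float(), k.float(), v.float())
err_fa = (flash_attn_func(q, k, v).float() - out_ref_fp32).abs().max().item()
err_ref = (attention_ref(q, k, v).float() - out_ref_fp32).abs().max().item()
print(f"flash-attn bf16 vs fp32 ref: {err_fa:.3e}")
print(f"naive bf16 vs fp32 ref:      {err_ref:.3e}")
# the flash-attn error should be on the same order as the bf16 reference error
```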
@tridao I found that the activation values of fa2 and fa3 are bitwise identical for the first 64 tokens, but not for tokens beyond that.
There's no guarantee of bitwise identical results from two different implementations, since floating-point math is not associative:
```python
In [1]: import torch
In [2]: a = torch.randn(10, dtype=torch.bfloat16, device='cuda')
In [3]: (a + 0.3 - 0.3 - a).abs().max().item()
Out[3]: 0.00390625
```
@tridao However, the attention activation values for the first 64 tokens are bitwise identical.
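For anyone checking the same thing, here's a sketch of how to find which token positions still match bitwise (`out_fa2` / `out_fa3` are placeholders for the two outputs in (batch, seqlen, nheads, headdim) layout):

```python
import torch

def first_bitwise_mismatch(out_fa2: torch.Tensor, out_fa3: torch.Tensor) -> int:
    # both outputs in (batch, seqlen, nheads, headdim) layout, same dtype
    seqlen = out_fa2.shape[1]
    # collapse everything except the token axis, then check bitwise equality per token
    same = (out_fa2 == out_fa3).permute(1, 0, 2, 3).reshape(seqlen, -1).all(dim=-1)
    mismatched = (~same).nonzero()
    return int(mismatched[0]) if mismatched.numel() else -1  # -1 means fully bitwise identical
```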
Any resolution on this, @Hijdk @tridao? I see a numerical mismatch and broken inference results when I swap sdpa with FA-3.
I saw that the outputs of fa3's flash_attn_func and torch.nn.functional.scaled_dot_product_attention() don't match. With a minimal repro comparing the two, I'm seeing a numerical mismatch between the results. Just wanted to confirm whether this is expected, and if so, why fa3 is expected to be so numerically different from torch native sdpa.