zaptrem opened this issue 1 year ago
The right comparison is (FlashAttention in fp16/bf16 - standard attention in fp32) vs (standard attention in fp16/bf16 - standard attention in fp32).
What are the two differences ^^^ in your use case?
Sorry I'm not quite sure what you're hinting at. I'm comparing FlashAttention 2 FP16 vs Torch's Memory Efficient SDP Attention FP16.
I think for scaled_dot_product_attention you need to transpose the second and third axes of q, k, and v; it expects a different layout.
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

# flash_attn_func layout: (batch, seqlen, nheads, headdim)
query = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
key = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
value = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
is_causal = True

with torch.no_grad():
    with torch.backends.cuda.sdp_kernel(enable_math=False):
        # scaled_dot_product_attention layout: (batch, nheads, seqlen, headdim)
        ref = F.scaled_dot_product_attention(torch.permute(query, [0, 2, 1, 3]),
                                             torch.permute(key, [0, 2, 1, 3]),
                                             torch.permute(value, [0, 2, 1, 3]),
                                             is_causal=is_causal)
        ref = torch.permute(ref, [0, 2, 1, 3])
        out = flash_attn_func(query, key, value, causal=is_causal)
        print(torch.max(torch.abs(out - ref)), torch.mean(torch.abs(out - ref)))
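To also run the comparison from the first reply (FlashAttention in fp16 vs. an fp32 reference, and standard attention in fp16 vs. the same reference), a minimal sketch along these lines should work, assuming the math backend in fp32 as the reference:

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

torch.manual_seed(0)
# (batch, seqlen, nheads, headdim), the layout flash_attn_func expects
q = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
k = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
v = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")

def sdpa(q, k, v, **kw):
    # scaled_dot_product_attention expects (batch, nheads, seqlen, headdim)
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), **kw)
    return out.transpose(1, 2)

with torch.no_grad():
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        ref_fp32 = sdpa(q.float(), k.float(), v.float(), is_causal=True)   # fp32 reference
        std_fp16 = sdpa(q, k, v, is_causal=True)                           # standard attention in fp16
    flash_fp16 = flash_attn_func(q, k, v, causal=True)                     # FlashAttention in fp16

    print("flash fp16 vs fp32 ref:", (flash_fp16.float() - ref_fp32).abs().max())
    print("std   fp16 vs fp32 ref:", (std_fp16.float() - ref_fp32).abs().max())

If the two printed errors are of the same order, the fp16 FlashAttention output is as close to the fp32 result as standard fp16 attention is.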
@tridao I've checked the differences you suggested, and they are of the same order as you predicted. Still, could you explain where the difference between FlashAttention in fp16 and standard attention in fp16 comes from? Cheers
Floating-point operations are not associative, so changing the order of operations changes the output, up to numerical precision. Example:
In [1]: import torch
In [2]: a = torch.randn(1024, dtype=torch.float16, device="cuda")
In [3]: out1 = a + 0.3 - 0.2
In [4]: out2 = a - 0.2 + 0.3
In [5]: (out1 - out2).abs().max()
Out[5]: tensor(0.0020, device='cuda:0', dtype=torch.float16)
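The same effect shows up in attention because FlashAttention computes the softmax and the output blockwise, i.e. it combines partial results in a different order than a standard kernel. As a rough illustration (not a model of the actual kernel), summing the same fp16 values in blocks already gives a slightly different answer than a single pass:

import torch

torch.manual_seed(0)
a = torch.randn(4096, dtype=torch.float16, device="cuda")

single_pass = a.sum()                                   # one reduction over all elements
blockwise = sum(chunk.sum() for chunk in a.split(256))  # 16 partial sums combined afterwards
print((single_pass - blockwise).abs())                  # small but typically nonzero in fp16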
import math

import torch
import torch.nn.functional as F

B = 1; T = 3; nh = 32; C = 64
# q = torch.arange(B * T * nh * C).reshape([B, T, nh, C]).float()
# k = torch.arange(B * T * nh * C).flip(dims=(-1,)).reshape([B, T, nh, C]).float()
# v = torch.arange(B * T * nh * C).reshape([B, T, nh, C]).float() ** 1.5
q = torch.randn([B, T, nh, C])
k = torch.randn([B, T, nh, C])
v = torch.randn([B, T, nh, C])

out = F.scaled_dot_product_attention(q.permute([0, 2, 1, 3]).cuda(),
                                     k.permute([0, 2, 1, 3]).cuda(),
                                     v.permute([0, 2, 1, 3]).cuda(),
                                     dropout_p=0.0, is_causal=False).permute([0, 2, 1, 3])

def func(q, k, v, use_causal=False):
    # (B, T, nh, C) -> q, v: (B, nh, T, C); k: (B, nh, C, T)
    q, k, v = q.permute([0, 2, 1, 3]), k.permute([0, 2, 3, 1]), v.permute([0, 2, 1, 3])
    # note: after the permute above, k.shape[-1] is T, not the head dim
    attn = (q @ k) * (1.0 / math.sqrt(k.shape[-1]))
    if use_causal:
        causal_mask = torch.tril(torch.ones(T, T, device=attn.device)).view(1, 1, T, T)
        attn = attn.masked_fill(causal_mask == 0, float('-inf'))
    attn = F.softmax(attn, dim=-1)
    out = attn @ v                   # (B, nh, T, T) x (B, nh, T, C) -> (B, nh, T, C)
    out = out.permute([0, 2, 1, 3])  # (B, T, nh, C)
    return out

out2 = func(q.cuda(), k.cuda(), v.cuda(), use_causal=False)
print((out2 - out).abs().mean())  # about 0.3
I use torch.float32 in a unit test and find that the difference between normal attention and flash attention is about 0.3. I would expect fp32 precision to be enough, but the difference is large. Could someone tell me why?
With torch.float32, F.scaled_dot_product_attention does not call FlashAttention (it is only implemented for fp16 and bf16). You can ask in the PyTorch GitHub.
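One way to check which backend actually runs (behavior may vary across PyTorch versions) is to restrict sdp_kernel to the flash backend only and see whether the call goes through; with fp32 inputs it is expected to fail:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, device="cuda")  # fp32, (batch, nheads, seqlen, headdim)
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the flash backend; with fp32 inputs no kernel should be available,
# so this call is expected to raise rather than silently fall back.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as err:
        print("flash backend not available for fp32:", err)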
Should I be worried about the difference?
I swapped out the Torch attention function for FlashAttention 2 in the MusicGen project here, like so:
but the resulting tensors were not equal during inference (my assertion halted execution), like so:
Additionally, setting aside the significant divergence above (which affects model output), I also run into the following issue during inference:
It happens because num_heads stays at 1 while the k num_heads scales with the number of tokens in the context window. However, I assume this will be solved by the forthcoming inference optimizations mentioned here.