Dao-AILab / flash-attention

Fast and memory-efficient exact attention

Flash Attention 2 Output not Equal to PyTorch scaled_dot_product_attention in MusicGen Inference #383

Open zaptrem opened 1 year ago

zaptrem commented 1 year ago

I swapped out the Torch attention function for Flash Attention 2 in the MusicGen project here like so:

if self.memory_efficient:
    p = self.dropout if self.training else 0
    if _efficient_attention_backend == 'torch':
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=attn_mask is not None, dropout_p=p)
        y = flash_attn_func(q, k, v, causal=attn_mask is not None, dropout_p=p)

        assert torch.allclose(x, y, atol=1e-5), "flash_attn_func and scaled_dot_product_attention are not equal"
    else:
        x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)

but the resulting tensors were not equal (my assertion halted execution) during inference like so:

import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

model = MusicGen.get_pretrained('melody')
model.set_generation_params(duration=8)  # generate 8 seconds.
descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
wav = model.generate(descriptions)  # generates 3 samples.

for idx, one_wav in enumerate(wav):
    # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
    audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)

Additionally, setting aside the significant divergence above (which affects model output), I also run into the following issue during inference:

Number of heads in key/value must divide number of heads in query

This is because num_heads stays at 1 while the k num_heads scales with the number of tokens in the context window. However, I assume this will be solved by the forthcoming inference optimizations mentioned here.
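
For reference, flash_attn_func supports MQA/GQA only when the number of key/value heads divides the number of query heads, which is the constraint behind the error above. A minimal sketch of that constraint (shapes are illustrative, not from MusicGen; flash_attn_func uses the (batch, seqlen, nheads, headdim) layout):

import torch
from flash_attn import flash_attn_func

q = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")  # 8 query heads
k = torch.randn(1, 16, 2, 64, dtype=torch.float16, device="cuda")  # 2 KV heads: 2 divides 8, so GQA works
v = torch.randn(1, 16, 2, 64, dtype=torch.float16, device="cuda")
out = flash_attn_func(q, k, v, causal=True)

k_bad = torch.randn(1, 16, 3, 64, dtype=torch.float16, device="cuda")  # 3 does not divide 8
v_bad = torch.randn(1, 16, 3, 64, dtype=torch.float16, device="cuda")
# flash_attn_func(q, k_bad, v_bad)  # raises the "must divide" error quoted above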

tridao commented 1 year ago

The right comparison is (FlashAttention in fp16/bf16 - standard attention in fp32) vs (standard attention in fp16/bf16 - standard attention in fp32).

What are the two differences ^^^ in your use case?
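
A minimal sketch of the comparison tridao describes, assuming random fp16 inputs and the (batch, seqlen, nheads, headdim) layout that flash_attn_func expects (the ref_attention helper is illustrative, not part of either library):

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

torch.manual_seed(0)
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")

def ref_attention(q, k, v):
    # standard attention, computed in the (batch, nheads, seqlen, headdim) layout
    q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
    attn = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    out = attn.softmax(dim=-1) @ v
    return out.permute(0, 2, 1, 3)

ref_fp32 = ref_attention(q.float(), k.float(), v.float())  # fp32 "ground truth"
ref_fp16 = ref_attention(q, k, v).float()                  # standard attention in fp16
out_fa = flash_attn_func(q, k, v).float()                   # FlashAttention in fp16

print((out_fa - ref_fp32).abs().max())    # FlashAttention fp16 vs fp32 reference
print((ref_fp16 - ref_fp32).abs().max())  # standard fp16 vs fp32 reference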

zaptrem commented 1 year ago

Sorry, I'm not quite sure what you're hinting at. I'm comparing FlashAttention 2 in fp16 vs Torch's memory-efficient SDP attention in fp16.

masahi commented 1 year ago

I think for scaled_dot_product_attention you need to transpose the second and the third axes of qkv. It requires a different layout.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

query = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
key = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
value = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")

is_causal = True

with torch.no_grad():
    with torch.backends.cuda.sdp_kernel(enable_math=False):
        ref = F.scaled_dot_product_attention(torch.permute(query, [0, 2, 1, 3]),
                                             torch.permute(key, [0, 2, 1, 3]),
                                             torch.permute(value, [0, 2, 1, 3]),
                                             is_causal=is_causal)
        ref = torch.permute(ref, [0, 2, 1, 3])

    out = flash_attn_func(query, key, value, causal=is_causal)

print(torch.max(torch.abs(out - ref)), torch.mean(torch.abs(out - ref)))

rems75 commented 1 year ago

@tridao I've checked the differences you suggested, and they are of the same order as you expected.

Still, could you explain where the difference between FlashAttention in fp16 and standard attention in fp16 comes from? Cheers

tridao commented 1 year ago

Floating point operations are not associative. Changing the order of the operations will change the output, up to numerical precision. For example:

In [1]: import torch

In [2]: a = torch.randn(1024, dtype=torch.float16, device="cuda")

In [3]: out1 = a + 0.3 - 0.2

In [4]: out2 = a - 0.2 + 0.3

In [5]: (out1 - out2).abs().max()
Out[5]: tensor(0.0020, device='cuda:0', dtype=torch.float16)
kaixinbear commented 10 months ago

import torch
import torch.nn.functional as F

B = 1; T = 3; nh = 32; C = 64

# q = torch.arange(B * T *  nh * C).reshape([B, T, nh, C]).float(); 
# k = torch.arange(B * T *  nh * C).flip(dims=(-1,)).reshape([B, T, nh, C]).float(); 
# v= torch.arange(B * T *  nh * C).reshape([B, T, nh, C]).float() ** 1.5
q = torch.randn([B, T, nh, C])
k = torch.randn([B, T, nh, C])
v = torch.randn([B, T, nh, C])

out = F.scaled_dot_product_attention(q.permute([0, 2, 1, 3]).cuda(), k.permute([0, 2, 1, 3]).cuda(), v.permute([0, 2, 1, 3]).cuda(), dropout_p=0.0, is_causal=False).permute([0, 2, 1, 3])

import math
def func(q, k, v, use_causal=False):
    q, k, v = q.permute([0, 2, 1, 3]), k.permute([0, 2, 3, 1]), v.permute([0, 2, 1, 3])
    # attn = (q @ k.transpose([0, 1, 3, 2])) * (1.0 + math.sqrt(k.shape[-1]))
    attn = (q @ k) * (1.0 / math.sqrt(k.shape[-1]))
    if use_causal:
        # build a causal mask on the fly (the original referenced an undefined self.causal_mask)
        causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=attn.device))
        attn = attn.masked_fill(~causal_mask, float('-inf'))
    attn = F.softmax(attn, dim=-1)
    out = attn @ v # (B, nh, T, T) x (B, nh, T, c) -> (B, nh, T, c)
    out = out.permute([0, 2, 1, 3]) # (B, T, nh, c)
    return out

out2 = func(q.cuda(), k.cuda(), v.cuda(), use_causal=False)

print((out2 - out).abs().mean()) # about 0.3

I use torch.float32 in this unit test and find that the difference between normal attention and flash attention is about 0.3. I think the precision should be enough, but the difference is large. Could someone tell me why?

tridao commented 10 months ago

With torch.float32, F.scaled_dot_product_attention does not call FlashAttention (it is only implemented for fp16 and bf16). You can ask on the PyTorch GitHub.
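
As a side note, the sdp_kernel context manager used earlier in this thread can restrict which backend F.scaled_dot_product_attention dispatches to; allowing only the flash backend with fp32 inputs is expected to fail, since FlashAttention only covers fp16/bf16 (the exact error and the availability of this context manager depend on the PyTorch version):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, device="cuda")  # float32, (batch, nheads, seqlen, headdim)
k = torch.randn(1, 8, 128, 64, device="cuda")
v = torch.randn(1, 8, 128, 64, device="cuda")

# Allow only the FlashAttention backend; with fp32 inputs this should raise
# a "no available kernel" RuntimeError rather than silently fall back to math.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as e:
        print("fp32:", e)
    # fp16 inputs let the flash backend run
    out = F.scaled_dot_product_attention(q.half(), k.half(), v.half())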

chenhuiapp commented 6 months ago

Should I be worried about the difference?