Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Output Discrepancy Between FlashAttention and PyTorch Attention #1359

Open pengzhangzhi opened 1 day ago

pengzhangzhi commented 1 day ago

I recently benchmarked FlashAttention against PyTorch’s scaled_dot_product_attention using a custom script and observed a significant discrepancy in the output. Below are the details of the issue:

Benchmark Summary:

•   Setup:
    •   PyTorch version: 2.5.0+cu121
    •   FlashAttention version: 2.7.0.post2
    •   Device: NVIDIA A100 GPU
    •   Data type: torch.float16
•   Observation:
    •   Max absolute difference: 5.265625
    •   Mean absolute difference: 0.7984185814857483

Steps to Reproduce:

Here is the benchmark code I used for testing:

```python
import torch
import torch.nn.functional as F
import math

def PyTorchAttention(query_layer, key_layer, value_layer, attention_mask):
    """
    Attention implementation using PyTorch's scaled_dot_product_attention.
    """
    B, H, L, D = query_layer.shape
    # Compute attention using PyTorch's built-in function
    context_layer = F.scaled_dot_product_attention(
        query_layer, key_layer, value_layer,
        attn_mask=attention_mask,
        is_causal=False,
        scale=1,
    )
    # Rearrange and reshape the context layer
    context_layer = context_layer.reshape(B, L, H*D)
    return context_layer

def FlashAttention(query_layer, key_layer, value_layer, attention_mask):
    """
    Attention implementation using FlashAttention.
    """
    B, H, L, D = query_layer.shape
    qkv = torch.stack((query_layer, key_layer, value_layer), dim=1).reshape(B, L, 3, H, D)
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn import flash_attn_varlen_qkvpacked_func
    # Unpad the input sequences based on the attention mask
    qkv_unpadded, indices, cu_seqlens, max_seqlen, _ = unpad_input(
        hidden_states=qkv, attention_mask=attention_mask.reshape(B, L)
    )
    # Apply FlashAttention
    fa_out = flash_attn_varlen_qkvpacked_func(
        qkv_unpadded,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        softmax_scale=1.0,
        causal=False,
    )
    # Pad the output back to the original sequence length
    fa_out = pad_input(fa_out, indices, B, L).to(torch.float32)
    # Reshape and rearrange the output to match the expected shape
    fa_out = fa_out.reshape(B, L, H*D)

    return fa_out

def compare_attention_implementations(batch_size, num_heads, seq_length, head_dim, dtype=torch.float16):
    """
    Generates random inputs and compares the outputs of the two attention implementations.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Generate random inputs
    query_layer = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=dtype, device=device, requires_grad=False)
    key_layer = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=dtype, device=device, requires_grad=False)
    value_layer = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=dtype, device=device, requires_grad=False)

    # Generate a random attention mask
    attention_mask = torch.randint(0, 2, (batch_size, seq_length), dtype=torch.bool, device=device)

    attn_mask = attention_mask[:, None, None, :]  # [batch_size, 1, 1, seq_length]

    # Compute outputs from both implementations
    output1 = PyTorchAttention(query_layer, key_layer, value_layer, attn_mask)
    output2 = FlashAttention(query_layer, key_layer, value_layer, attention_mask)

    # Compare outputs
    difference = torch.abs(output1 - output2)
    max_diff = difference.max().item()
    mean_diff = difference.mean().item()

    print(f'Data type: {dtype}')
    print(f'Max absolute difference: {max_diff}')
    print(f'Mean absolute difference: {mean_diff}')

def seed_all(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

if __name__ == '__main__':
    # Set parameters
    batch_size = 2
    num_heads = 4
    seq_length = 128
    head_dim = 64  # Dimension per head
    dtype = torch.float16  # Change to torch.float32 or torch.float64 as needed
    seed_all(42)
    # Compare implementations
    from torch.nn.attention import sdpa_kernel, SDPBackend
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        compare_attention_implementations(batch_size, num_heads, seq_length, head_dim, dtype)
```

The discrepancy is consistent across multiple runs, even with deterministic settings (e.g., fixed seeds).

Questions and Notes:

1.  Is this discrepancy expected under certain conditions (e.g., data type torch.float16, attention mask handling)?
2.  If not, could this indicate a potential bug or precision issue in my benchmark implementation?
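On the mask-handling question: the `[batch_size, 1, 1, seq_length]` boolean mask passed to SDPA masks keys only, so SDPA still returns an output for padded query positions (a weighted average over the valid keys), while `pad_input` in the FlashAttention path appears to leave those rows as zeros. A full-tensor difference can therefore be nonzero even when both paths are otherwise correct. A minimal sketch of restricting the comparison to real tokens, assuming outputs shaped `(B, L, H*D)` and a `(B, L)` boolean mask; `masked_abs_diff_stats` is a hypothetical helper:

```python
import torch

def masked_abs_diff_stats(out_a, out_b, mask):
    """Hypothetical helper. out_a, out_b: (B, L, H*D); mask: (B, L) bool, True = real token.
    Returns (max, mean) absolute difference over real-token rows only."""
    # Boolean indexing with a (B, L) mask keeps only real-token rows: (num_real_tokens, H*D)
    diff = (out_a.float() - out_b.float()).abs()[mask]
    return diff.max().item(), diff.mean().item()
```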

Thank you for your time and for developing such an efficient attention mechanism! Please let me know if you need further details or additional benchmarks.

Best regards,

tridao commented 1 day ago

This reshape is not what you want:

```python
qkv = torch.stack((query_layer, key_layer, value_layer), dim=1).reshape(B, L, 3, H, D)
```

Please see the docstring or tests here (https://github.com/Dao-AILab/flash-attention/blob/7153673c1a3c7753c38e4c10ef2c98a02be5f778/tests/test_flash_attn.py#L586) to see what input layout the function expects.
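Concretely, `query_layer`, `key_layer`, and `value_layer` are `(B, H, L, D)`, so stacking on `dim=1` gives `(B, 3, H, L, D)`, and `.reshape(B, L, 3, H, D)` only reinterprets that memory instead of moving the axes. As we read the docstring and tests, the packed tensor handed to `unpad_input` should be `(B, L, 3, H, D)`, which becomes `(total_tokens, 3, H, D)` after unpadding. Separately, the PyTorch path in the script reshapes SDPA's `(B, H, L, D)` output straight to `(B, L, H*D)`, which has the same problem. A minimal sketch of both fixes; `pack_qkv_for_flash_attn` and `sdpa_output_to_bld` are hypothetical helper names:

```python
import torch

def pack_qkv_for_flash_attn(q, k, v):
    """Hypothetical helper. q, k, v: (B, H, L, D). Returns (B, L, 3, H, D)."""
    # Move the head axis behind the sequence axis first, then stack on a new axis 2.
    return torch.stack((q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)), dim=2)

def sdpa_output_to_bld(out):
    """Hypothetical helper. SDPA returns (B, H, L, D); transpose before merging heads."""
    B, H, L, D = out.shape
    return out.transpose(1, 2).reshape(B, L, H * D)
```

In the script, the stack line would become `qkv = pack_qkv_for_flash_attn(query_layer, key_layer, value_layer)`, and `PyTorchAttention` would return `sdpa_output_to_bld(context_layer)`. The explicit `transpose` is the important part: `reshape` alone never reorders axes.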