NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] fused CUDNN attention kernel and sliding window attention #1197

Open Marks101 opened 5 days ago

Marks101 commented 5 days ago

Hello team,

We have been noticing some fairly large deviations between the attention outputs of the flash/unfused attention backends and the fused attention kernel when sliding window attention is active. The following sample illustrates this:

import torch

from transformer_engine.pytorch.attention import FlashAttention, FusedAttention, UnfusedDotProductAttention, get_swa_mask
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine_torch as tex

window_size = (1024, 0)  # (left, right) sliding window; right = 0 means no right context
seqlen, num_heads, kv_channels = 2048, 64, 64

# q/k/v in sbhd layout: (sequence, batch, heads, head_dim), batch size 1
q, k, v = [torch.randn(seqlen, 1, num_heads, kv_channels, dtype=torch.float16, device="cuda") for _ in range(3)]

# The three dot-product attention backends under comparison
flash_attn = FlashAttention(1.0)
fused_attn = FusedAttention(1.0)
unfused_attn = UnfusedDotProductAttention(1.0)

output_flash = flash_attn(q, k, v, "sbhd_sbhd_sbhd", window_size=window_size)
output_fused = fused_attn(q, k, v, "sbhd_sbhd_sbhd",
                          fused_attention_backend=tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                          window_size=window_size,
                          fp8_meta=dict(recipe=DelayedScaling()))

# The unfused backend expects the sliding window as an explicit attention mask
attention_mask = torch.ones(1, 1, seqlen, seqlen, dtype=torch.bool, device="cuda")
attn_mask_type, attention_mask = get_swa_mask(window_size, seqlen, seqlen, "causal", attention_mask)
output_unfused = unfused_attn(q, k, v, attn_mask_type=attn_mask_type, attention_mask=attention_mask)

print("diff flash vs unfused:", torch.max(torch.abs(output_flash - output_unfused)).item())
print("diff fused vs unfused:", torch.max(torch.abs(output_fused - output_unfused)).item())

The output we see on an H100 with CUDA 12.5 and cuDNN 9.2.1 is:

diff flash vs unfused: 0.03076171875
diff fused vs unfused: 4.8828125

The latter seems rather large. Can you reproduce these results?

ksivaman commented 5 days ago

@cyanguwa Do you know what could be causing this?

cyanguwa commented 21 hours ago

Hi @Marks101,

Thanks for raising this issue. I seem to have overlooked the different window_size definition in cuDNN. cuDNN defines the sliding window as (i - window_size_left, i], i.e. exclusive of the i - window_size_left element, whereas the original paper, flash-attn, and TE's unfused DPA use [i - window_size_left, i + window_size_right], which includes the boundary elements. Please give #1212 a try and let me know if there are still any issues. Thanks!
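
To make the off-by-one concrete, here is a minimal, self-contained sketch (not TE or cuDNN code; swa_mask and exclusive_left are just illustrative names) of the two window conventions for window_size = (w, 0):

import torch

def swa_mask(seqlen, w, exclusive_left):
    # True = key position j may be attended to from query position i
    i = torch.arange(seqlen).unsqueeze(1)  # query positions (column)
    j = torch.arange(seqlen).unsqueeze(0)  # key positions (row)
    causal = j <= i
    window = (j > i - w) if exclusive_left else (j >= i - w)
    return causal & window

flash_style = swa_mask(8, 4, exclusive_left=False)  # [i - w, i], flash-attn / unfused DPA
cudnn_style = swa_mask(8, 4, exclusive_left=True)   # (i - w, i], cuDNN
# The two masks disagree exactly on the boundary diagonal j == i - w
print((flash_style ^ cudnn_style).nonzero())

Because the backends disagreed on whether that boundary element is attended to, each query row saw one extra or one missing key, which accounts for the large max-abs deviation reported above.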

Results:

diff flash vs unfused: 0.0330810546875
diff fused vs unfused: 0.033203125
diff flash vs   fused: 0.001953125