NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
Apache License 2.0

[PyTorch] fused CUDNN attention kernel and sliding window attention #1197

Open Marks101 opened 5 days ago

Marks101 commented 5 days ago

Hello team,

We have been noticing fairly large deviations between the attention output of flash/unfused attention and that of the fused attention kernels when sliding window attention is active. The following sample illustrates this:

import torch

from transformer_engine.pytorch.attention import FlashAttention, FusedAttention, UnfusedDotProductAttention, get_swa_mask
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine_torch as tex

window_size = (1024, 0)
seqlen, num_heads, kv_channels = 2048, 64, 64

q, k, v = [torch.randn(seqlen, 1, num_heads, kv_channels, dtype=torch.float16, device="cuda") for _ in range(3)]

flash_attn = FlashAttention(1.0)
fused_attn = FusedAttention(1.0)
unfused_attn = UnfusedDotProductAttention(1.0)

output_flash = flash_attn(q, k, v, "sbhd_sbhd_sbhd", window_size=window_size)
output_fused = fused_attn(q, k, v, "sbhd_sbhd_sbhd",

attention_mask = torch.ones(1, 1, seqlen, seqlen, dtype=torch.bool, device="cuda")
attn_mask_type, attention_mask = get_swa_mask(window_size, seqlen, seqlen, "causal", attention_mask)
output_unfused = unfused_attn(q, k, v, attn_mask_type=attn_mask_type, attention_mask=attention_mask)

print("diff flash vs unfused:", torch.max(torch.abs(output_flash - output_unfused)).item())
print("diff fused vs unfused:", torch.max(torch.abs(output_fused - output_unfused)).item())

The output we see on H100 and CUDA 12.5 with CUDNN 9.2.1 is:

diff flash vs unfused: 0.03076171875
diff fused vs unfused: 4.8828125

The latter one seems rather large. Can you reproduce these results?

ksivaman commented 5 days ago

@cyanguwa Do you know what could be causing this?

cyanguwa commented 21 hours ago

Hi @Marks101 ,

Thanks for raising this issue. I seem to have overlooked the different window_size definition in cuDNN. cuDNN defines the sliding window as (i - window_size_left, i], which excludes the i - window_size_left element, whereas the original paper, flash-attn and TE unfused DPA use [i - window_size_left, i + window_size_right], which includes both boundary elements. Please give #1212 a try and let me know if there are still any issues. Thanks!


diff flash vs unfused: 0.0330810546875
diff fused vs unfused: 0.033203125
diff flash vs   fused: 0.001953125
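
The difference between the two window definitions discussed in this thread can be made concrete with a minimal pure-Python sketch. It is not Transformer Engine or cuDNN code: the helper names `inclusive_window` and `left_exclusive_window` are hypothetical, and the small `seqlen` and `left` values are chosen only for readability (the issue uses window_size = (1024, 0) with seqlen 2048). It lists which key positions each query row may attend to under each convention.

```python
def inclusive_window(i, left, right, seqlen):
    """Keys visible to query i under [i - left, i + right], both ends included."""
    lo = max(i - left, 0)
    hi = min(i + right, seqlen - 1)
    return list(range(lo, hi + 1))


def left_exclusive_window(i, left, seqlen):
    """Keys visible to query i under (i - left, i], left end excluded."""
    lo = max(i - left + 1, 0)
    hi = min(i, seqlen - 1)
    return list(range(lo, hi + 1))


seqlen = 8  # illustrative value only
left = 3    # illustrative value only; right window is 0, as in the issue

# Rows where the two conventions select different key sets.
mismatched_rows = [
    i
    for i in range(seqlen)
    if inclusive_window(i, left, 0, seqlen) != left_exclusive_window(i, left, seqlen)
]

# Passing left + 1 to the left-exclusive convention recovers the inclusive window.
shifted_matches = all(
    inclusive_window(i, left, 0, seqlen) == left_exclusive_window(i, left + 1, seqlen)
    for i in range(seqlen)
)

for i in range(seqlen):
    print(i, inclusive_window(i, left, 0, seqlen), left_exclusive_window(i, left, seqlen))
print("rows that differ:", mismatched_rows)
print("left + 1 reconciles the conventions:", shifted_matches)
```

In this sketch, every query row with i >= left loses exactly one key (the one at i - left) under the left-exclusive convention, so the softmax is taken over a different set of scores on those rows. Shifting the left window by one reconciles the two definitions here; whether #1212 does this is not stated in the thread.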