NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] fused CUDNN attention kernel not properly handling strides #1195

Closed Marks101 closed 1 month ago

Marks101 commented 1 month ago

Hello team,

We noticed discrepancies when using transformer_engine.pytorch.TransformerLayer in combination with fused attention kernels, multi-/group-query attention, fuse_qkv_params, and qkv_weight_interleaved. The problem boils down to the following code snippet:

import torch

from transformer_engine.pytorch.attention import FlashAttention, FusedAttention
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine_torch as tex

seqlen, q_heads, kv_heads, kv_channels = 2048, 16, 1, 64

qkv = torch.randn(seqlen, 1, q_heads + 2 * kv_heads, kv_channels, dtype=torch.float16, device="cuda")

q, k, v = qkv.split([q_heads, kv_heads, kv_heads], dim=2)
#q, k, v = [t.contiguous() for t in (q, k, v)]

flash_attn = FlashAttention(1.0)
fused_attn = FusedAttention(1.0)

output_flash = flash_attn(q, k, v, "sbhd_sbhd_sbhd", window_size=(-1, 0))
output_fused = fused_attn(q, k, v, "sbhd_sbhd_sbhd",
                          fused_attention_backend=tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                          window_size=(-1, 0),
                          fp8_meta=dict(recipe=DelayedScaling()))

print("diff:", torch.max(torch.abs(output_fused - output_flash)).item())

On H100 with CUDA 12.5 and cuDNN 9.2.1 we get a max error of 6.1953125. If I uncomment the line that applies .contiguous(), the error drops to 0.001953125, so I suspect that the issue is related to how strides are handled in the fused attention backend.
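
For reference, the view/stride situation is easy to inspect directly from the snippet above (a small diagnostic sketch using the same shapes; nothing TransformerEngine-specific):

import torch

seqlen, q_heads, kv_heads, kv_channels = 2048, 16, 1, 64

# Same packed QKV buffer as in the reproducer above.
qkv = torch.randn(seqlen, 1, q_heads + 2 * kv_heads, kv_channels, dtype=torch.float16, device="cuda")
q, k, v = qkv.split([q_heads, kv_heads, kv_heads], dim=2)

# split() returns views into the packed buffer: all three tensors keep the
# strides of qkv, and none of them is contiguous on its own.
for name, t in (("q", q), ("k", k), ("v", v)):
    print(name, tuple(t.shape), t.stride(), t.is_contiguous())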

Could you please have a look at this and try to reproduce my results? Thanks!

yaox12 commented 1 month ago

I think this issue is caused by the fact that we do stride checks in get_qkv_layout() in DotProductAttention: https://github.com/NVIDIA/TransformerEngine/blob/a68acd71d3500d41b2d75caf249a1a4f4445ba69/transformer_engine/pytorch/attention.py#L4675. If you call FusedAttention directly, it does not perform these checks, while flash-attn does them in https://github.com/Dao-AILab/flash-attention/blob/53a4f341634fcbc96bb999a3c804c192ea14f2ea/flash_attn/flash_attn_interface.py#L90.

cc @cyanguwa
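
For comparison, the guard on the flash-attn side that the link above points to is essentially a last-stride check; paraphrased (not a verbatim copy of the flash-attn source), it looks roughly like this:

import torch

# Paraphrase of the guard in flash_attn_interface.py: any tensor whose last
# stride is not 1 is made contiguous before being handed to the kernel.
def maybe_contiguous(x: torch.Tensor) -> torch.Tensor:
    return x.contiguous() if x.stride(-1) != 1 else x

# q, k, v = [maybe_contiguous(t) for t in (q, k, v)]

In the reproducer above the last stride is already 1, so this particular guard would be a no-op there.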

Marks101 commented 1 month ago

I just had a second look at the q, k, v tensor strides in my example. The strides are (1152, 1152, 64, 1) for all tensors, so at least the requirement that the last stride is 1 is fulfilled 😉

I also added a get_qkv_layout() call to my sample, but it did not fail and returned the sbhd_sbhd_sbhd layout.
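
For what it's worth, those strides follow directly from the packed buffer, which is also what distinguishes the views from truly separate sbhd tensors (a quick sketch of the arithmetic, using the shapes from the reproducer):

import torch

seqlen, q_heads, kv_heads, kv_channels = 2048, 16, 1, 64

# Packed buffer: shape (2048, 1, 18, 64), so contiguous strides are
# (1 * 18 * 64, 18 * 64, 64, 1) = (1152, 1152, 64, 1).
qkv = torch.randn(seqlen, 1, q_heads + 2 * kv_heads, kv_channels, dtype=torch.float16, device="cuda")
assert qkv.stride() == (1152, 1152, 64, 1)

q, k, v = qkv.split([q_heads, kv_heads, kv_heads], dim=2)

# The split views inherit the packed strides ...
assert q.stride() == k.stride() == v.stride() == (1152, 1152, 64, 1)

# ... while a standalone sbhd tensor with q's shape (2048, 1, 16, 64) would
# have strides (16 * 64, 16 * 64, 64, 1) = (1024, 1024, 64, 1).
assert q.contiguous().stride() == (1024, 1024, 64, 1)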

cyanguwa commented 1 month ago

The requirement for the last dimension's stride to be 1 is indeed satisfied here. But I think the problem is that the provided q, k, v tensors do not actually have a sbhd_sbhd_sbhd layout; it's more like a sb(h+2h_g)d layout :). The .contiguous() call forces them into sbhd_sbhd_sbhd, which is why the test passes then.

The DotProductAttention module calls get_qkv_layout() to run some checks on user inputs and convert them if necessary (in some limited capacity), but the FlashAttention and FusedAttention modules don't do those checks.

The get_qkv_layout() call returns sbhd_sbhd_sbhd because it misses a case in its final elif branch. I've created PR #1214 to improve the logic. For this specific case, the function will now run run_iteratively() twice and force a .contiguous() on the tensors before passing them to FlashAttention or FusedAttention.

import os
import torch

from transformer_engine.pytorch.attention import FlashAttention, FusedAttention, DotProductAttention, _attention_backends
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine_torch as tex

seqlen, q_heads, kv_heads, kv_channels = 2048, 16, 1, 64
seqlen_kv = 1024

qkv = torch.randn(seqlen, 1, q_heads + 2 * kv_heads, kv_channels, dtype=torch.float16, device="cuda")

q, k, v = qkv.split([q_heads, kv_heads, kv_heads], dim=2)
#q, k, v = [t.contiguous() for t in (q, k, v)]

flash_attn = DotProductAttention(q_heads, kv_channels=kv_channels, num_gqa_groups=kv_heads, softmax_scale=0.1)
fused_attn = DotProductAttention(q_heads, kv_channels=kv_channels, num_gqa_groups=kv_heads, softmax_scale=0.1)

os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
output_flash = flash_attn(q, k, v)

os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
output_fused = fused_attn(q, k, v)

print("diff:", torch.max(torch.abs(output_fused - output_flash)).item())

Results:

NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python test.py
[INFO     | DotProductAttention]: Running with FlashAttention backend (version 2.4.2)
[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)
diff: 0.0009765625
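
Until the PR is merged, a minimal workaround for code that drives FusedAttention directly is to materialize the split views first, exactly as in the commented-out line of the original reproducer:

# Break the views of the packed QKV buffer before handing them to
# FlashAttention / FusedAttention directly; DotProductAttention performs the
# equivalent check/conversion itself with the improved get_qkv_layout() logic.
q, k, v = [t.contiguous() for t in qkv.split([q_heads, kv_heads, kv_heads], dim=2)]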

Marks101 commented 1 month ago

Great, thanks for the solution and the detailed explanations. I just tested it and it works in our training code now as well 😃