I think this issue is caused by the stride checks we do in get_qkv_layout() in DotProductAttention: https://github.com/NVIDIA/TransformerEngine/blob/a68acd71d3500d41b2d75caf249a1a4f4445ba69/transformer_engine/pytorch/attention.py#L4675
If you call FusedAttention directly, it does not perform such checks, while flash-attn does them in https://github.com/Dao-AILab/flash-attention/blob/53a4f341634fcbc96bb999a3c804c192ea14f2ea/flash_attn/flash_attn_interface.py#L90.
cc @cyanguwa
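For reference, flash-attn's handling amounts to something like the following (a paraphrased sketch, not the verbatim code behind the link above):

import torch

def maybe_make_contiguous(x: torch.Tensor) -> torch.Tensor:
    # flash-attn only requires the head (last) dimension to have stride 1;
    # otherwise it copies the tensor before launching the kernel.
    return x if x.stride(-1) == 1 else x.contiguous()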
I just had a second look at the q, k, v tensor strides in my example. The strides are (1152, 1152, 64, 1) for all three tensors, so at least the requirement that the last stride is 1 is fulfilled 😉
I also added a get_qkv_layout() call to my sample, but it did not fail and returned the sbhd_sbhd_sbhd layout.
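For context, those strides fall directly out of splitting the packed buffer; a minimal standalone sketch (shapes copied from the reproducer below, CPU is enough since only strides matter):

import torch

seqlen, q_heads, kv_heads, kv_channels = 2048, 16, 1, 64
qkv = torch.randn(seqlen, 1, q_heads + 2 * kv_heads, kv_channels)  # packed sb(h+2h_g)d buffer
q, k, v = qkv.split([q_heads, kv_heads, kv_heads], dim=2)

# Each split is a view that keeps the packed buffer's strides:
print(q.stride())         # (1152, 1152, 64, 1) -> 18 heads * 64 channels = 1152, last stride is 1
print(q.is_contiguous())  # False: q is not an independent sbhd tensor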
The requirement that the last dimension's stride is 1 is indeed fulfilled here. But I think the problem is that the provided q, k, v tensors do not actually have an sbhd_sbhd_sbhd layout; it is really an sb(h+2h_g)d layout :). The .contiguous() call forces them into sbhd_sbhd_sbhd, which is why the test passes then.
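To make that distinction concrete: after the split, q, k and v are views into one shared storage (the packed sb(h+2h_g)d buffer), and only .contiguous() turns them into independent sbhd tensors. A minimal illustration (not TE code, same shapes as above):

import torch

qkv = torch.randn(2048, 1, 18, 64)   # 16 q heads + 2 * 1 kv head, packed
q, k, v = qkv.split([16, 1, 1], dim=2)

print(q.untyped_storage().data_ptr() == k.untyped_storage().data_ptr())  # True: same packed buffer
print(q.stride())               # (1152, 1152, 64, 1): strides of the packed buffer
print(q.contiguous().stride())  # (1024, 1024, 64, 1): a genuine standalone sbhd tensor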
The DotProductAttention module calls get_qkv_layout() to run some checks on user inputs and convert them if necessary (in some limited capacity), but the FlashAttention and FusedAttention modules don't do those checks.
The get_qkv_layout() call returns sbhd_sbhd_sbhd because it missed this case in the last elif branch. I've created PR #1214 to improve the logic. For this specific case, the function will run run_iteratively() twice and force a .contiguous() on the tensors before passing them to FlashAttention or FusedAttention.
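To sketch the intent of that fallback (a deliberately simplified illustration, not the actual get_qkv_layout()/run_iteratively() implementation, which handles many more layouts):

import torch

def get_layout_with_fallback(q, k, v):
    # Simplified: a tensor only counts as an independent sbhd tensor if it is
    # contiguous with last-dimension stride 1.
    def looks_like_separate_sbhd(t):
        return t.stride(-1) == 1 and t.is_contiguous()

    for _ in range(2):  # "run iteratively twice"
        if all(looks_like_separate_sbhd(t) for t in (q, k, v)):
            return q, k, v, "sbhd_sbhd_sbhd"
        # First pass failed: the tensors are strided views into a packed buffer,
        # so force a copy and re-run the check.
        q, k, v = (t.contiguous() for t in (q, k, v))
    raise ValueError("unsupported q/k/v layout")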
import os
import torch
from transformer_engine.pytorch.attention import FlashAttention, FusedAttention, DotProductAttention, _attention_backends
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine_torch as tex

seqlen, q_heads, kv_heads, kv_channels = 2048, 16, 1, 64
seqlen_kv = 1024

# Packed QKV buffer in sb(h+2h_g)d format; splitting it yields strided views,
# not three independent contiguous sbhd tensors.
qkv = torch.randn(seqlen, 1, q_heads + 2 * kv_heads, kv_channels, dtype=torch.float16, device="cuda")
q, k, v = qkv.split([q_heads, kv_heads, kv_heads], dim=2)
# Uncommenting the next line makes both backends agree:
#q, k, v = [t.contiguous() for t in (q, k, v)]

flash_attn = DotProductAttention(q_heads, kv_channels=kv_channels, num_gqa_groups=kv_heads, softmax_scale=0.1)
fused_attn = DotProductAttention(q_heads, kv_channels=kv_channels, num_gqa_groups=kv_heads, softmax_scale=0.1)

# Force the FlashAttention backend for the first call ...
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
output_flash = flash_attn(q, k, v)

# ... and the FusedAttention backend for the second call, telling TE to
# re-evaluate the backend selection.
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
output_fused = fused_attn(q, k, v)

print("diff:", torch.max(torch.abs(output_fused - output_flash)).item())
Results:
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python test.py
[INFO | DotProductAttention]: Running with FlashAttention backend (version 2.4.2)
[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)
diff: 0.0009765625
Great, thanks for the solution and the detailed explanations. I just tested it and it works in our training code now as well 😃
Hello team,
we noticed discrepancies when using transformer_engine.pytorch.TransformerLayer in combination with fused attention kernels, multi-/group-query attention, fuse_qkv_params, and qkv_weight_interleaved. All in all, the problem boils down to the code snippet above. On H100 with CUDA 12.5 and cuDNN 9.2.1 we get a max error of 6.1953125. If I uncomment the line that applies .contiguous(), the error goes down to 0.001953125, so I suspect that the issue is related to the handling of strides in the fused attention backend. Could you please have a look at this and try to reproduce my results? Thanks!
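For completeness, the TransformerLayer configuration described above would look roughly like this (hyperparameter values and the exact flag combination are illustrative assumptions on my part, not taken from the original training setup):

import torch
import transformer_engine.pytorch as te

hidden_size, num_heads, num_gqa_groups = 1024, 16, 1  # 16 query heads, 1 KV head (MQA)

layer = te.TransformerLayer(
    hidden_size=hidden_size,
    ffn_hidden_size=4 * hidden_size,
    num_attention_heads=num_heads,
    num_gqa_groups=num_gqa_groups,  # multi-/group-query attention
    fuse_qkv_params=True,           # store QKV as fused parameters
    # qkv_weight_interleaved defaults to True, the combination mentioned above
).cuda()

x = torch.randn(2048, 2, hidden_size, device="cuda")  # (seq, batch, hidden) input, TE's default layout
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = layer(x)
print(y.shape)  # torch.Size([2048, 2, 1024])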