Open · Marks101 opened 1 month ago
Hey Markus!

I think what you wanted to do is in line with the `thd` format. I've tweaked your script a little and it seems to work for both FlashAttention and FusedAttention. In this case, Transformer Engine is treating your batch as `[t=2048, h, d]`, with batch size `b=3` (inferred from `cu_seqlens`'s shape `[b+1]`). Your original script did run, for FlashAttention, but it wasn't running as intended if I understand your use case correctly. If you turn on `NVTE_DEBUG_LEVEL=2`, you'll see that it's treating the batch as `sbhd` format, i.e. all 2048 tokens were in one sequence. It's also applying a `causal` mask (because that's the default), and not using the `cu_seqlens` tensors (because we go through the `flash_attn_func` path, which doesn't take `cu_seqlens`).
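For reference, here's a quick illustration (plain PyTorch, nothing TE-specific, just for clarity) of how a `cu_seqlens` tensor of shape `[b+1]` encodes the packed batch:

```python
import torch

# The cu_seqlens tensor from the script below: shape [b+1] = [4], so b = 3 sequences
cu_seqlens = torch.tensor([0, 300, 1100, 2048], dtype=torch.int32)
seq_lens = torch.diff(cu_seqlens)   # tensor([300, 800, 948], dtype=torch.int32)
# Tokens of sequence i occupy rows cu_seqlens[i]:cu_seqlens[i+1] of the packed
# [t, h, d] tensors, e.g. the second sequence is q[300:1100] in the script below.
```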
I have a little blurb over here to explain the use cases of `sbhd + padding` and `thd + padding`. Hope that helps!
```python
import os
import torch

from transformer_engine.pytorch.attention import DotProductAttention, _attention_backends

# 3 sequences of lengths 300, 800 and 948, packed into t = 2048 tokens (thd format)
seqlen, batch_size, heads, kv_channels = 2048, 1, 16, 64
q, k, v = [torch.randn(seqlen * batch_size, heads, kv_channels, dtype=torch.float16, device="cuda", requires_grad=True) for _ in range(3)]
cu_seqlens_q = cu_seqlens_kv = torch.tensor([0, 300, 1100, 2048], device="cuda", dtype=torch.int32)

attention_kernel = DotProductAttention(heads, kv_channels)

# FlashAttention backend
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
output_flash = attention_kernel(q, k, v, qkv_format='thd', attn_mask_type='padding', cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)

# FusedAttention backend: flip the env vars and force backend re-selection
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
output_fused = attention_kernel(q, k, v, qkv_format='thd', attn_mask_type='padding', cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)

# The two backends should agree within fp16 tolerance
torch.testing.assert_close(output_fused, output_flash, atol=1e-2, rtol=1e-2)
```
Run:

```
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python test.py
```

Output:

```
/code/pr-thd-int64/TransformerEngine/transformer_engine/pytorch/attention.py:5162: UserWarning: window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=padding
  warnings.warn(
[INFO | DotProductAttention]: Running with FlashAttention backend (version 2.4.2)
[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)
```
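As a quick sanity check on the `thd` layout, you could also split the packed output back into the three sequences with the same offsets. A minimal sketch, reusing `output_fused` and `cu_seqlens_q` from the script above and assuming the packed token dimension comes first in the output:

```python
seq_lens = torch.diff(cu_seqlens_q).tolist()               # [300, 800, 948]
per_sequence = torch.split(output_fused, seq_lens, dim=0)  # one tensor per packed sequence
for i, out in enumerate(per_sequence):
    print(f"sequence {i}: {tuple(out.shape)}")
```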
Hey hey,

Oh, okay, then my use case was simply not correct 🙈 Thank you so much for your explanation and for extending the documentation. That was exactly the information I was looking for!
Hi team,

we are currently adapting our training environment to use the fused attention functions. In one of our training setups, we work with batch size one and concatenate multiple documents along the sequence dimension (`sbhd` format). We set `cu_seqlens_q` and `cu_seqlens_kv` so that these documents cannot attend to each other. This is actually not a `padding` use case, because we always fill up the whole sequence and there is no packing and unpacking with `pack_tensors()` and `unpack_tensors()` required. With the FlashAttention backend this worked perfectly fine and produced the results that we intended. With the fused attention functions we get device-side assertions for this input. Here is a small sample code:

Was the use case we have been working with ever intended? Or is there just some assertion missing that forbids using `cu_seqlens` without setting a `padding` mode?
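Roughly, the setup described above looks like the sketch below (a simplified illustration, not the actual sample code from our training environment; the shapes, the `causal` mask type and the exact kwargs are assumptions for the sake of the example):

```python
import torch
from transformer_engine.pytorch.attention import DotProductAttention

seqlen, batch_size, heads, kv_channels = 2048, 1, 16, 64
# sbhd layout [s, b, h, d]: one batch entry holding several documents
# concatenated along the sequence dimension to 2048 tokens
q, k, v = [torch.randn(seqlen, batch_size, heads, kv_channels,
                       dtype=torch.float16, device="cuda", requires_grad=True)
           for _ in range(3)]
# Document boundaries inside the single 2048-token sequence
cu_seqlens_q = cu_seqlens_kv = torch.tensor([0, 300, 1100, 2048], device="cuda", dtype=torch.int32)

attention_kernel = DotProductAttention(heads, kv_channels)
# sbhd + cu_seqlens without a padding mask type: ran with the FlashAttention
# backend, but hits device-side assertions with FusedAttention
out = attention_kernel(q, k, v, qkv_format="sbhd", attn_mask_type="causal",
                       cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)
```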