NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] fused attention and cu_seqlens #1259

Open Marks101 opened 1 month ago

Marks101 commented 1 month ago

Hi team,

We are currently adapting our training environment to use the fused attention functions. In one of our training setups, we work with batch size one and concatenate multiple documents along the sequence dimension (sbhd format). We set cu_seqlens_q and cu_seqlens_kv so that these documents cannot attend to each other (the intended block-diagonal pattern is sketched after the script below). This is not actually a padding use case, because we always fill up the whole sequence, so no packing and unpacking with pack_tensors() and unpack_tensors() is required. With the flash attention backend this worked perfectly fine and produced the results we intended. With the fused attention functions we get device-side assertions for this input. Here is a small code sample:

import os
import torch

from transformer_engine.pytorch.attention import DotProductAttention, _attention_backends

seqlen, batch_size, heads, kv_channels = 2048, 1, 16, 64

q, k, v = [torch.randn(seqlen, batch_size, heads, kv_channels, dtype=torch.float16, device="cuda", requires_grad=True) for _ in range(3)]

cu_seqlens_q = cu_seqlens_kv = torch.tensor([0, 300, 1100, 2048], device="cuda", dtype=torch.int32)

attention_kernel = DotProductAttention(heads, kv_channels)

# select the FlashAttention backend
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
output_flash = attention_kernel(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)

# switch to the FusedAttention backend and force backend re-selection
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
output_fused = attention_kernel(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)  # fails with device-side assertions

torch.testing.assert_close(output_fused, output_flash, atol=1e-2, rtol=1e-2)
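
To make the intent above concrete, here is a small sketch in plain PyTorch (illustration only, not TE API) of the block-diagonal mask that these cu_seqlens boundaries are meant to express, i.e. tokens may only attend within their own document:

import torch

# Illustration only: build the block-diagonal "allowed to attend" mask implied
# by the cu_seqlens boundaries above (True where attention is allowed).
cu_seqlens = torch.tensor([0, 300, 1100, 2048])
doc_id = torch.bucketize(torch.arange(cu_seqlens[-1].item()), cu_seqlens[1:], right=True)
allowed = doc_id[:, None] == doc_id[None, :]  # shape [2048, 2048]
print(allowed.sum().item())                   # 300**2 + 800**2 + 948**2 = 1628704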

Was the use case we have been working with ever intended? Or is there just an assertion missing that forbids using cu_seqlens without setting a padding mode?

cyanguwa commented 4 weeks ago

Hey Markus!

I think what you wanted to do is in line with the thd format. I've tweaked your script a little and it seems to work for both FlashAttention and FusedAttention. In this case, Transformer Engine treats your batch as [t=2048, h, d] with batch size b=3 (inferred from cu_seqlens's shape [b+1]). Your original script did run with FlashAttention, but it wasn't running as intended if I understand your use case correctly. If you turn on NVTE_DEBUG_LEVEL=2, you'll see that it treats the batch as sbhd format, i.e. all 2048 tokens are in one sequence. It also applies a causal mask (because that's the default) and does not use the cu_seqlens tensors (because we go through the flash_attn_func path, which doesn't take cu_seqlens).
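
To make the shape inference concrete, here is a tiny sketch (illustration only, plain PyTorch rather than TE code) of how cu_seqlens maps to per-document lengths and the inferred batch size:

import torch

# cu_seqlens has shape [b + 1]; consecutive differences give the per-document
# lengths, and b is inferred as cu_seqlens.numel() - 1.
cu_seqlens = torch.tensor([0, 300, 1100, 2048], dtype=torch.int32)
print(torch.diff(cu_seqlens).tolist())  # [300, 800, 948]
print(cu_seqlens.numel() - 1)           # b = 3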

I have a little blurb over here to explain the use cases of sbhd + padding and thd + padding. Hope that helps!

import os
import torch

from transformer_engine.pytorch.attention import DotProductAttention, _attention_backends

seqlen, batch_size, heads, kv_channels = 2048, 1, 16, 64

# thd format: q, k, v are [t, h, d] with t = 2048 tokens in total
q, k, v = [torch.randn(seqlen * batch_size, heads, kv_channels, dtype=torch.float16, device="cuda", requires_grad=True) for _ in range(3)]

cu_seqlens_q = cu_seqlens_kv = torch.tensor([0, 300, 1100, 2048], device="cuda", dtype=torch.int32)

attention_kernel = DotProductAttention(heads, kv_channels)

# run with the FlashAttention backend
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
output_flash = attention_kernel(q, k, v, qkv_format='thd', attn_mask_type='padding', cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)

# switch to the FusedAttention backend and force backend re-selection
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
output_fused = attention_kernel(q, k, v, qkv_format='thd', attn_mask_type='padding', cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)

torch.testing.assert_close(output_fused, output_flash, atol=1e-2, rtol=1e-2)

Run:

NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python test.py 
/code/pr-thd-int64/TransformerEngine/transformer_engine/pytorch/attention.py:5162: UserWarning: window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=padding
  warnings.warn(
[INFO     | DotProductAttention]: Running with FlashAttention backend (version 2.4.2)
[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)

Marks101 commented 3 weeks ago

Hey Hey,

Oh, okay, then my use case was simply not correct 🙈 Thank you so much for your explanation and for extending the documentation. That was exactly the information I was looking for!