imoneoi opened this issue 1 week ago
Yes I think that should work. You should test that still
Thanks! I have tested the kernel and it does work. However, the output and gradient entries for the padding tokens may be uninitialized, which can propagate NaN/Inf through the forward and backward passes. Can we include a fix to simply zero these elements?
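Until the kernel does this, a caller-side workaround seems possible: mask out everything past `cu_seqlens[-1]` out-of-place (an in-place write to the output could invalidate the tensors saved for backward). This is only a sketch, and the helper name is made up:

```python
import torch
from flash_attn import flash_attn_varlen_func


def varlen_attn_zero_padding(q, k, v, cu_seqlens, max_seqlen):
    # Hypothetical helper: run varlen attention over the real tokens only, then
    # overwrite the (possibly uninitialized) padded output rows with zeros.
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    )
    # Rows >= cu_seqlens[-1] are never written by the kernel, so select them
    # out-of-place rather than multiplying by a mask (NaN * 0 is still NaN).
    valid = torch.arange(out.shape[0], device=out.device) < cu_seqlens[-1]
    return torch.where(valid[:, None, None], out, torch.zeros_like(out))
```

This only cleans up the forward output; the padding rows of `dq`/`dk`/`dv` would still need the same treatment, which is why zeroing inside the kernel would be cleaner.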
BTW, here is the code used for testing:
```python
from typing import Any

import torch
from tqdm import tqdm

from flash_attn import flash_attn_varlen_func


def test_flash_attn_padding(
    seed: int = 0,
    test_rounds: int = 10,
    num_heads: int = 8,
    head_size: int = 64,
    seq_len: int = 160,
    batch_size: int = 131_072,
    dtype: Any = torch.bfloat16
):
    torch.manual_seed(seed)
    torch.set_default_device("cuda")

    # Construct test data
    q = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    k = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    v = torch.randn((batch_size, num_heads, head_size), dtype=dtype)

    seqlens = torch.cat([
        torch.full((batch_size // seq_len, ), seq_len, dtype=torch.int32),
        torch.full((1, ), batch_size % seq_len, dtype=torch.int32)
    ])
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(-1, dtype=seqlens.dtype), (1, 0))
    max_seqlen = seqlens.max()

    # Multiple rounds so that torch.empty() might be filled with random values
    for _ in tqdm(range(test_rounds)):
        # Fwd
        gt_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)

        assert torch.allclose(nopad_out[:cu_seqlens[-2]], gt_out[:cu_seqlens[-2]])
        # assert torch.allclose(nopad_out[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be uninitialized here

        # Bwd
        # ground truth
        dgrad = torch.randn((batch_size, num_heads, head_size), dtype=dtype)

        q.requires_grad_()
        k.requires_grad_()
        v.requires_grad_()

        gt_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        (dgrad * gt_out).sum().backward()

        gt_dq = q.grad
        gt_dk = k.grad
        gt_dv = v.grad
        q.grad = None
        k.grad = None
        v.grad = None

        # unpadded
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        (dgrad * nopad_out).sum().backward()

        assert torch.allclose(q.grad[:cu_seqlens[-2]], gt_dq[:cu_seqlens[-2]])
        # assert torch.allclose(q.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be uninitialized here
        assert torch.allclose(k.grad[:cu_seqlens[-2]], gt_dk[:cu_seqlens[-2]])
        # assert torch.allclose(k.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be uninitialized here
        assert torch.allclose(v.grad[:cu_seqlens[-2]], gt_dv[:cu_seqlens[-2]])
        # assert torch.allclose(v.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be uninitialized here

        q.grad = None
        k.grad = None
        v.grad = None


if __name__ == "__main__":
    test_flash_attn_padding()
```
To work with `torch.compile`, which is more efficient on static shapes, I pad some tokens at the end to make the shape of `q`, `k`, `v` static, e.g. `[N, D]`. Can I set the last element of `cu_seqlens` in the varlen API to be less than `N` to avoid computing the padding? Also, is the backward pass accurate in this case?
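For concreteness, the intended usage looks roughly like this (sizes and sequence lengths are purely illustrative):

```python
import torch
from flash_attn import flash_attn_varlen_func

# Fixed token budget so torch.compile always sees a static [N, ...] shape.
N = 4096
seqlens = [1000, 1500, 1200]  # real sequences; sum = 3700 < N, the rest is padding
cu_seqlens = torch.tensor([0, 1000, 2500, 3700], dtype=torch.int32, device="cuda")

q = torch.randn(N, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# cu_seqlens[-1] == 3700 < N: the question is whether the kernel may simply skip
# the trailing N - 3700 padding rows in both the forward and backward passes.
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
)
```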