Dao-AILab / flash-attention

Fast and memory-efficient exact attention

Avoid padding computation with `cu_seqlens` #1228

Open · imoneoi opened this issue 1 week ago

imoneoi commented 1 week ago

To work with torch.compile, which is more efficient with static shapes, I pad some tokens at the end so that q, k, v have a static shape, e.g. [N, D].

Can I set the last element of cu_seqlens in the varlen API to a value less than N so that the padding is not computed? Also, is the backward pass still accurate in this case?
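
For concreteness, here is a minimal sketch of what I mean (the shapes and sequence lengths below are made up for illustration): q, k, v are padded to a static total length N, but cu_seqlens stops before N so the kernel can skip the tail.

import torch
from flash_attn import flash_attn_varlen_func

# Illustrative sizes: N is the static (padded) total token count
N, num_heads, head_size = 4096, 8, 64
q = torch.randn(N, num_heads, head_size, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Real sequences cover only the first 4000 tokens; the last 96 rows are padding.
# The last element of cu_seqlens (4000) is deliberately smaller than N (4096).
cu_seqlens = torch.tensor([0, 1500, 3000, 4000], device="cuda", dtype=torch.int32)
max_seqlen = 1500

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
)
# out has the same (N, num_heads, head_size) shape as q; rows at index >= cu_seqlens[-1]
# are never touched by the kernel (see the follow-up below).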

tridao commented 1 week ago

Yes, I think that should work. You should still test it though.

imoneoi commented 6 days ago

Thanks! I have tested the kernel and it does work. However, the padding elements may be uninitialized, which results in NaN/inf in the forward and backward passes. Can we include a fix that simply zeroes these elements?
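
In the meantime, a workaround sketch (not a kernel-level fix; it reuses q, k, v, cu_seqlens, max_seqlen and N from the snippet above, and dout is just an illustrative upstream gradient) is to mask out the rows the kernel never writes and to zero the matching gradient rows after backward:

# Workaround sketch: zero the rows the varlen kernel never writes so that stale
# torch.empty() contents cannot leak NaN/inf into later ops.
for t in (q, k, v):
    t.requires_grad_()
dout = torch.randn_like(q)                  # illustrative upstream gradient

valid = int(cu_seqlens[-1])                 # number of real (non-padding) tokens
row_is_real = (torch.arange(N, device=q.device) < valid)[:, None, None]

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
)
# Use torch.where rather than an in-place write: the attention output is typically
# saved for backward, and an in-place modification would trip autograd's version check.
out = torch.where(row_is_real, out, torch.zeros((), dtype=out.dtype, device=out.device))

(dout * out).sum().backward()
# dq/dk/dv for the padding rows are likewise uninitialized; zero them after backward.
for t in (q, k, v):
    t.grad[valid:] = 0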

imoneoi commented 6 days ago

BTW, here is the code used for testing:

import torch

from tqdm import tqdm
from flash_attn import flash_attn_varlen_func

def test_flash_attn_padding(
    seed: int = 0,
    test_rounds: int = 10,
    num_heads: int = 8,
    head_size: int = 64,
    seq_len: int = 160,
    batch_size: int = 131_072,
    dtype: torch.dtype = torch.bfloat16,
):
    torch.manual_seed(seed)
    torch.set_default_device("cuda")

    # Construct test data: q, k, v hold batch_size tokens in total (the static, padded length)
    q = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    k = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    v = torch.randn((batch_size, num_heads, head_size), dtype=dtype)

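    # Sequences of length seq_len, plus one short trailing segment that plays the
    # role of the padding at the end (batch_size is not a multiple of seq_len)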
    seqlens = torch.cat([
        torch.full((batch_size // seq_len, ), seq_len, dtype=torch.int32),
        torch.full((1, ), batch_size % seq_len, dtype=torch.int32)
    ])

    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(-1, dtype=seqlens.dtype), (1, 0))
    max_seqlen = int(seqlens.max())

    # Run multiple rounds so that the kernel's torch.empty() buffers have a chance to
    # contain stale (possibly NaN/inf) values
    for _ in tqdm(range(test_rounds)):
        # Fwd
        gt_out    = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens,      cu_seqlens_k=cu_seqlens,      max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)

        assert torch.allclose(nopad_out[:cu_seqlens[-2]], gt_out[:cu_seqlens[-2]])
        # assert torch.allclose(nopad_out[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be empty here

        # Bwd
        # ground truth
        dgrad = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
        q.requires_grad_()
        k.requires_grad_()
        v.requires_grad_()

        gt_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens,      cu_seqlens_k=cu_seqlens,      max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        (dgrad * gt_out).sum().backward()

        gt_dq = q.grad
        gt_dk = k.grad
        gt_dv = v.grad
        q.grad = None
        k.grad = None
        v.grad = None

        # unpadded
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen) 
        (dgrad * nopad_out).sum().backward()

        assert torch.allclose(q.grad[:cu_seqlens[-2]], gt_dq[:cu_seqlens[-2]])
        # assert torch.allclose(q.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be empty here

        assert torch.allclose(k.grad[:cu_seqlens[-2]], gt_dk[:cu_seqlens[-2]])
        # assert torch.allclose(k.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be empty here

        assert torch.allclose(v.grad[:cu_seqlens[-2]], gt_dv[:cu_seqlens[-2]])
        # assert torch.allclose(v.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be empty here

        q.grad = None
        k.grad = None
        v.grad = None

if __name__ == "__main__":
    test_flash_attn_padding()