Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

IMA (illegal memory access) with split-k kernel #1076

Open drisspg opened 1 month ago

drisspg commented 1 month ago

Summary

For more details, see this PyTorch issue: https://github.com/pytorch/pytorch/issues/131257

I was able to reproduce this on commit 74b0761ff7efc7b90d4e5aeb529c1b2a09a7458c with the following script:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

from flash_attn import flash_attn_func

def test_flash_attention():
    device = "cuda"
    dtype = torch.bfloat16
    batch_size = 1
    n_heads = 1
    seq_len_q = 1    # a single query token
    seq_len_k = 257  # key/value length used in this repro
    head_dim = 32
    # is_causal, dropout_p and scale are only used by the commented-out SDPA call
    is_causal = False
    dropout_p = 0.0
    scale = 1 / head_dim

    # Create input tensors in SDPA layout: (batch, heads, seq_len, head_dim)
    query = torch.rand(
        batch_size,
        n_heads,
        seq_len_q,
        head_dim,
        device=device,
        dtype=dtype,
        requires_grad=True,
    )
    key = torch.rand(
        batch_size,
        n_heads,
        seq_len_k,
        head_dim,
        device=device,
        dtype=dtype,
        requires_grad=True,
    )
    value = torch.rand(
        batch_size,
        n_heads,
        seq_len_k,
        head_dim,
        device=device,
        dtype=dtype,
        requires_grad=True,
    )

    # flash_attn_func expects (batch, seq_len, heads, head_dim), hence the transposes
    out_flash = flash_attn_func(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
    )
    # Back to SDPA layout for comparison
    out_flash = out_flash.transpose(1, 2)
    # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    #     out = F.scaled_dot_product_attention(
    #         query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
    #     )

    print("Output flash shape: ", out_flash.shape)
    # print("Output shape:", out.shape)
    # print("Output sum:", out.sum().item())

if __name__ == "__main__":
    test_flash_attention()
```
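Save the script as `ima.py` and run it under compute-sanitizer's memcheck tool (the report is written to `ima.txt`):

```shell
TORCH_DISABLE_ADDR2LINE=1 PYTORCH_NO_CUDA_MEMORY_CACHING=1 compute-sanitizer --tool memcheck --log-file ima.txt python ima.py
```
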
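For readers unfamiliar with the split-KV ("split-k") path named in the title, here is a minimal pure-Python sketch of the general technique. It is not flash-attn's CUDA code. The keys are partitioned into chunks; each chunk produces a normalized partial output plus its log-sum-exp, and a combine step rescales and sums the partials. The chunk size `block_n = 128` is an assumed illustrative value, not flash-attn's actual tile size, and the helper names are hypothetical. The softmax scale `1/sqrt(head_dim)` matches `flash_attn_func`'s default, which the repro above uses. With `seq_len_k = 257`, the last chunk holds a single key. Ragged tails like this are where per-chunk bounds handling matters, though the thread does not confirm that this is the cause.

```python
import math
import random

def attention_reference(q, keys, values, scale):
    """Plain softmax attention for a single query vector."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def attention_split_kv(q, keys, values, scale, block_n):
    """Hypothetical split-KV attention: each chunk of `block_n` keys yields
    (normalized partial output, log-sum-exp); a combine pass merges them."""
    partials = []
    for start in range(0, len(keys), block_n):
        end = min(start + block_n, len(keys))  # clamp the ragged last chunk
        k_chunk, v_chunk = keys[start:end], values[start:end]
        chunk_out = attention_reference(q, k_chunk, v_chunk, scale)
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_chunk]
        m = max(scores)
        lse = m + math.log(sum(math.exp(s - m) for s in scores))
        partials.append((chunk_out, lse))

    # Combine: global log-sum-exp, then weight each chunk by exp(lse_i - lse_total)
    max_lse = max(lse for _, lse in partials)
    lse_total = max_lse + math.log(sum(math.exp(lse - max_lse) for _, lse in partials))
    dim = len(values[0])
    out = [0.0] * dim
    for chunk_out, lse in partials:
        w = math.exp(lse - lse_total)
        for d in range(dim):
            out[d] += w * chunk_out[d]
    return out, len(partials)

random.seed(0)
head_dim, seq_len_k, block_n = 32, 257, 128  # block_n is an assumed chunk size
scale = 1 / math.sqrt(head_dim)
q = [random.random() for _ in range(head_dim)]
keys = [[random.random() for _ in range(head_dim)] for _ in range(seq_len_k)]
values = [[random.random() for _ in range(head_dim)] for _ in range(seq_len_k)]

ref = attention_reference(q, keys, values, scale)
split, n_splits = attention_split_kv(q, keys, values, scale, block_n)
print(n_splits)  # 3: chunks of 128, 128 and 1 keys
print(max(abs(a - b) for a, b in zip(ref, split)))  # ~0, up to float rounding
```

The combine step is exact because each partial output is already normalized by its own chunk's softmax denominator. Multiplying by `exp(lse_i - lse_total)` converts it to the global normalization.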
tridao commented 1 month ago

Thanks for the bug report. I can reproduce the error now.

drisspg commented 1 month ago

Also, here is my current best guess at what might be going on: https://github.com/pytorch/pytorch/pull/131277