drisspg opened 1 month ago
For more details see this PyTorch issue: https://github.com/pytorch/pytorch/issues/131257
I was able to reproduce this on commit 74b0761ff7efc7b90d4e5aeb529c1b2a09a7458c with the following script:
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from flash_attn import flash_attn_func


def test_flash_attention():
    device = "cuda"
    dtype = torch.bfloat16
    batch_size = 1
    n_heads = 1
    seq_len_q = 1
    seq_len_k = 257
    head_dim = 32
    is_causal = False
    dropout_p = 0.0
    scale = 1 / head_dim

    # Create input tensors
    query = torch.rand(
        batch_size, n_heads, seq_len_q, head_dim,
        device=device, dtype=dtype, requires_grad=True,
    )
    key = torch.rand(
        batch_size, n_heads, seq_len_k, head_dim,
        device=device, dtype=dtype, requires_grad=True,
    )
    value = torch.rand(
        batch_size, n_heads, seq_len_k, head_dim,
        device=device, dtype=dtype, requires_grad=True,
    )

    out_flash = flash_attn_func(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
    )
    out_flash = out_flash.transpose(1, 2)

    # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    #     out = F.scaled_dot_product_attention(
    #         query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
    #     )

    print("Output flash shape: ", out_flash.shape)
    # print("Output shape:", out.shape)
    # print("Output sum:", out.sum().item())


if __name__ == "__main__":
    test_flash_attention()
```
Run under compute-sanitizer with:

```bash
TORCH_DISABLE_ADDR2LINE=1 PYTORCH_NO_CUDA_MEMORY_CACHING=1 compute-sanitizer --tool memcheck --log-file ima.txt python ima.py
```
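If it helps with bisection, a small sweep around the failing length (seq_len_q = 1, seq_len_k = 257) can be run under the same compute-sanitizer command. This is only a sketch: the lengths chosen below are assumptions for narrowing things down, not confirmed failure points, and any out-of-bounds accesses would show up in the memcheck log rather than as Python errors.

```python
import torch
from flash_attn import flash_attn_func


def sweep_seq_len_k(lengths=(255, 256, 257, 258, 512)):
    """Run flash_attn_func with seq_len_q=1 and a few seq_len_k values.

    The specific lengths are guesses around the reported failing case (257);
    check ima.txt afterwards to see which ones compute-sanitizer flags.
    """
    device, dtype = "cuda", torch.bfloat16
    for seq_len_k in lengths:
        # flash_attn_func expects (batch, seqlen, nheads, head_dim) layout
        q = torch.rand(1, 1, 1, 32, device=device, dtype=dtype)
        k = torch.rand(1, seq_len_k, 1, 32, device=device, dtype=dtype)
        v = torch.rand(1, seq_len_k, 1, 32, device=device, dtype=dtype)
        out = flash_attn_func(q, k, v)
        torch.cuda.synchronize()
        print(f"seq_len_k={seq_len_k}: output shape {tuple(out.shape)}")


if __name__ == "__main__":
    sweep_seq_len_k()
```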
Thanks for the bug report. I can reproduce the error now.
Also, here is my current best guess at what might be going on: https://github.com/pytorch/pytorch/pull/131277
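In the meantime, one possible workaround, assuming the illegal memory access is specific to the flash kernel and not to SDPA as a whole, is to pin this shape to a different SDPA backend (e.g. the memory-efficient one). A minimal sketch using the same shapes as the repro:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

device, dtype = "cuda", torch.bfloat16
query = torch.rand(1, 1, 1, 32, device=device, dtype=dtype)    # (B, H, S_q, D)
key = torch.rand(1, 1, 257, 32, device=device, dtype=dtype)    # (B, H, S_k, D)
value = torch.rand(1, 1, 257, 32, device=device, dtype=dtype)

# Force the memory-efficient backend instead of flash attention for this shape.
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(query, key, value, scale=1 / 32)

print("Output shape:", out.shape)
```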