Dao-AILab / flash-attention

Fast and memory-efficient exact attention

Varlen flash attention: CUDA illegal memory access #1311

Open clessig opened 5 hours ago

clessig commented 5 hours ago

I get the following error when the number of chunks/batches becomes large:

File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/flash_attn-2.6.3-py3.12-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 198, in _flash_attn_varlen_backward ) = flash_attn_cuda.varlen_bwd( ^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: CUDA error: an illegal memory access was encountered

Is it possible that there is an implicit maximum number of chunks/batches that is not covered by checks (potentially with some memory space running out)?

tridao commented 5 hours ago

It's possible. We use 32-bit indexing, so when tensors get larger than 2 GB or 4 GB the indexing might be wrong. Can you help us reproduce the error, e.g. with a short script?
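For context, a minimal varlen forward/backward looks roughly like the sketch below (not code from this issue; the chunk count, sequence length, head sizes, and the causal flag are illustrative placeholders). Scaling n_chunks and seqlen up until the packed tensors approach the 2-4 GB range is where 32-bit indexing could start to overflow.

```python
import torch
from flash_attn import flash_attn_varlen_func

# Illustrative sizes only; increase n_chunks / seqlen until the packed tensors
# approach the 2-4 GB range to probe the 32-bit indexing limit.
n_chunks, seqlen, nheads, headdim = 1024, 128, 8, 64
total = n_chunks * seqlen

dtype, device = torch.float16, "cuda"
q = torch.randn(total, nheads, headdim, dtype=dtype, device=device, requires_grad=True)
k = torch.randn(total, nheads, headdim, dtype=dtype, device=device, requires_grad=True)
v = torch.randn(total, nheads, headdim, dtype=dtype, device=device, requires_grad=True)

# cu_seqlens holds the chunk boundaries as int32 prefix sums: [0, seqlen, 2*seqlen, ..., total].
cu_seqlens = torch.arange(0, total + seqlen, seqlen, dtype=torch.int32, device=device)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seqlen, max_seqlen_k=seqlen,
    causal=True,  # placeholder; match the real model's settings
)
# The reported crash happens in the backward pass (flash_attn_cuda.varlen_bwd).
out.sum().backward()
torch.cuda.synchronize()
```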

clessig commented 3 hours ago

I just tried to write a small repro case with just a single MHA-Varlen call, but couldn't reproduce it.

Is it possible that the error depends on the entire graph for my real-world network?

tridao commented 3 hours ago

If you can save the tensors (q, k, v, and gradient) that caused the IMA you can load them back up in a script.
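A sketch of that workflow, assuming the IMA is triggered inside a varlen MHA block (the hook placement, file name, and causal flag are placeholders, not details from this issue): register a gradient hook on the attention output so q, k, v, the cu_seqlens, and the incoming gradient are written to disk just before varlen_bwd runs, then replay them in a standalone script.

```python
import torch
from flash_attn import flash_attn_varlen_func

# --- In the failing training run: dump the inputs right before the varlen backward.
# The hook on `out` fires when its incoming gradient arrives, i.e. just before
# flash_attn_cuda.varlen_bwd executes, so the dump lands on disk before the IMA.
def dump_varlen_inputs(q, k, v, out, cu_q, cu_k, max_q, max_k, path="ima_repro.pt"):
    out.register_hook(
        lambda dout: torch.save(
            {
                "q": q.detach().cpu(), "k": k.detach().cpu(), "v": v.detach().cpu(),
                "dout": dout.detach().cpu(),
                "cu_seqlens_q": cu_q.cpu(), "cu_seqlens_k": cu_k.cpu(),
                "max_seqlen_q": max_q, "max_seqlen_k": max_k,
            },
            path,
        )
    )

# --- Standalone repro script: reload the dump and replay forward + backward.
d = torch.load("ima_repro.pt")
q, k, v = (d[name].cuda().requires_grad_() for name in ("q", "k", "v"))
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=d["cu_seqlens_q"].cuda(), cu_seqlens_k=d["cu_seqlens_k"].cuda(),
    max_seqlen_q=d["max_seqlen_q"], max_seqlen_k=d["max_seqlen_k"],
    causal=True,  # placeholder; match the real model's settings
)
out.backward(d["dout"].cuda())
torch.cuda.synchronize()
```

Replaying the exact saved tensors keeps the shapes, strides, and cu_seqlens identical to the failing step, which is what makes the crash reproducible outside the full model graph.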