Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Varlen flash attention: CUDA illegal memory access #1311

[Open] clessig opened this issue 1 month ago

clessig commented 1 month ago

I get the following error when the number/length of my chunks (batches) becomes large:

  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/flash_attn-2.6.3-py3.12-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 198, in _flash_attn_varlen_backward
    ) = flash_attn_cuda.varlen_bwd(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered

Is it possible that there is an implicit maximum on the number of chunks/batches that is not caught by any check (perhaps because some memory space runs out)?

tridao commented 1 month ago

It's possible. We use 32-bit indexing, so when tensors get larger than 2GB or 4GB the indexing might be wrong. Can you help us reproduce the error, e.g. with a short script?
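To make the 32-bit limit concrete, here is a minimal, self-contained sketch (plain Python, not FlashAttention code) that computes a row-major element offset either with unbounded integers or wrapped to a signed 32-bit integer at every step. The layout (sequence length 1024, 16 heads, head dim 128) is a hypothetical example: once the batch index reaches 1024, the offset exceeds 2^31 - 1 and wraps to a negative value, i.e. an address before the start of the buffer.

```python
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def to_int32(value: int) -> int:
    """Wrap an arbitrary Python int to a signed 32-bit value (two's complement)."""
    return (value - INT32_MIN) % 2**32 + INT32_MIN


def element_offset(index, strides, bits=64):
    """Row-major element offset of `index`; with bits=32 every step wraps to int32."""
    offset = 0
    for i, stride in zip(index, strides):
        if bits == 32:
            offset = to_int32(offset + to_int32(i * stride))
        else:
            offset += i * stride
    return offset


# Hypothetical contiguous layout (batch, seqlen=1024, nheads=16, headdim=128).
seqlen, nheads, headdim = 1024, 16, 128
strides = (seqlen * nheads * headdim, nheads * headdim, headdim, 1)

print(element_offset((1023, 0, 0, 0), strides, bits=32))  # 2145386496 (still fits)
print(element_offset((1024, 0, 0, 0), strides, bits=64))  # 2147483648
print(element_offset((1024, 0, 0, 0), strides, bits=32))  # -2147483648 (wrapped)
```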

clessig commented 1 month ago

I just tried to write a small repro case with just a single varlen MHA layer, but couldn't reproduce it.

Is it possible that the error depends on the entire computation graph of my real-world network?

tridao commented 1 month ago

If you can save the tensors (q, k, v, and the gradient) that trigger the IMA (illegal memory access), you can load them back in a standalone script.
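A minimal sketch of that approach, assuming PyTorch is available: `torch.save` on a dict keeps each tensor's dtype (including bfloat16, which NumPy cannot represent), so no manual casts are needed, and plain values such as `softmax_scale` or `window_size` can go in the same file. The helper names `dump_tensors`/`load_tensors` and the file name are hypothetical.

```python
import torch


def dump_tensors(path, **values):
    """Save tensors (as detached CPU copies, dtype preserved) and plain values to `path`."""
    snapshot = {
        name: v.detach().cpu() if isinstance(v, torch.Tensor) else v
        for name, v in values.items()
    }
    torch.save(snapshot, path)


def load_tensors(path, device="cpu"):
    """Load a snapshot written by dump_tensors, moving tensors to `device`."""
    snapshot = torch.load(path, map_location="cpu")
    return {
        name: v.to(device) if isinstance(v, torch.Tensor) else v
        for name, v in snapshot.items()
    }


# Hypothetical usage, placed right before the failing call:
#   dump_tensors("ima_snapshot.pt", dout=dout, q=q, k=k, v=v, out=out,
#                softmax_lse=softmax_lse, cu_seqlens_q=cu_seqlens_q,
#                cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q,
#                max_seqlen_k=max_seqlen_k, softmax_scale=softmax_scale)
# and later, in a standalone script:
#   inputs = load_tensors("ima_snapshot.pt", device="cuda")
```

Saving detached CPU copies keeps the snapshot independent of the autograd graph and of the GPU the error occurred on.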

clessig commented 1 month ago

I captured the state in flash_attn/flash_attn_interface.py (line 195), just before the call to flash_attn_cuda.varlen_bwd(), which fails with:

terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: an illegal memory access was encountered

(I tried to trace the error through the C++/CUDA code but didn't get to the bottom of it.) I captured the state as follows (with some implicit casts to fp32, which should be obvious and irrelevant):

import numpy as np

# Dump the inputs of flash_attn_cuda.varlen_bwd() to .npy files.
# NumPy has no bfloat16 dtype, so bf16 tensors need a .float() first.
np.save('dout.npy', dout.detach().cpu().numpy())
np.save('q.npy', q.detach().cpu().numpy())
np.save('k.npy', k.detach().cpu().numpy())
np.save('v.npy', v.detach().cpu().numpy())
np.save('out.npy', out.detach().cpu().numpy())
np.save('softmax_lse.npy', softmax_lse.detach().cpu().numpy())
np.save('dq.npy', dq.detach().cpu().numpy())
np.save('dk.npy', dk.detach().cpu().numpy())
np.save('dv.npy', dv.detach().cpu().numpy())
np.save('cu_seqlens_q.npy', cu_seqlens_q.detach().cpu().numpy())
np.save('cu_seqlens_k.npy', cu_seqlens_k.detach().cpu().numpy())
# max_seqlen_* may be Python ints or 1-element tensors; int() handles both.
np.save('max_seqlen_q.npy', np.asarray(int(max_seqlen_q)))
np.save('max_seqlen_k.npy', np.asarray(int(max_seqlen_k)))

The remaining parameters passed to flash_attn_cuda.varlen_bwd are:

alibi_slopes = None
dropout_p = 0.0
softmax_scale = 0.08838834764831845
causal = False
window_size = (-1, -1)
softcap = 0.0
deterministic = False
rng_state = tensor([0, 0], device='cuda:0')

The data is here: http://graphics.cs.uni-magdeburg.de/misc/flash_attn_repro.zip (≈ 3.12 GB).

Let me know if anything else is needed!

tridao commented 1 month ago

When you load them back and run the forward and backward passes, do you get the IMA?

clessig commented 1 month ago

I haven't gotten to this yet. But is there any need to run the forward pass? One could feed the saved values directly into flash_attn_cuda.varlen_bwd again, couldn't one?

tridao commented 4 weeks ago

Yeah, however you run it: as long as there's a script that loads the tensors and reproduces the IMA, that would be most helpful.

MARD1NO commented 4 weeks ago

> It's possible. We use 32-bit indexing, so when tensors get larger than 2GB or 4GB the indexing might be wrong. Can you help us reproduce the error, e.g. with a short script?

Could we add an IndexType dispatch (32-bit vs. 64-bit indexing) to FlashAttention-2? I also ran into this problem, and it took me a long time to track down the bug. Hopper GPUs have more memory, so they can allocate a large KV cache, which makes this bug easy to trigger.
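A minimal sketch of what such an IndexType dispatch would decide on the host side, assuming PyTorch: compute the largest element offset each tensor can reach from its shape, strides and storage offset, and use 32-bit index math only when all of them fit (similar in spirit to the 32-bit-index-math check PyTorch uses for its own kernels). `choose_index_dtype` is a hypothetical helper, not part of flash-attn; in C++/CUDA the result would select which template instantiation to launch.

```python
import torch

INT32_MAX = torch.iinfo(torch.int32).max


def max_element_offset(t: torch.Tensor) -> int:
    """Largest storage offset (in elements) that indexing `t` can reach."""
    if t.numel() == 0:
        return t.storage_offset()
    return t.storage_offset() + sum(
        (size - 1) * stride for size, stride in zip(t.shape, t.stride())
    )


def choose_index_dtype(*tensors: torch.Tensor) -> torch.dtype:
    """Pick int32 indexing if every offset fits in a signed 32-bit int, else int64."""
    if all(max_element_offset(t) <= INT32_MAX for t in tensors):
        return torch.int32
    return torch.int64


# Meta tensors carry shape/stride metadata without allocating any memory.
small = torch.empty(4096, 16, 128, device="meta")
huge = torch.empty(2**16, 2**15 + 1, device="meta")
print(choose_index_dtype(small))        # torch.int32
print(choose_index_dtype(small, huge))  # torch.int64
```

Checking offsets rather than `numel()` also covers non-contiguous views, whose strides can reach further into storage than their element count suggests.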

clessig commented 4 weeks ago

@tridao : here's the repro case


import numpy as np
import torch

from flash_attn.flash_attn_interface import _flash_attn_varlen_backward

dout = torch.from_numpy( np.load('dout.npy')).to(torch.float16).to('cuda')
q = torch.from_numpy( np.load('q.npy')).to(torch.float16).to('cuda')
k = torch.from_numpy( np.load('k.npy')).to(torch.float16).to('cuda')
v = torch.from_numpy( np.load('v.npy')).to(torch.float16).to('cuda')
out = torch.from_numpy( np.load('out.npy')).to(torch.float16).to('cuda')
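# softmax_lse is produced in fp32 by the forward pass and read as float by the
# kernel, so it is kept in fp32 here rather than cast to fp16.
softmax_lse = torch.from_numpy( np.load('softmax_lse.npy')).to(torch.float32).to('cuda')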
dq = torch.from_numpy( np.load('dq.npy')).to(torch.float16).to('cuda')
dk = torch.from_numpy( np.load('dk.npy')).to(torch.float16).to('cuda')
dv = torch.from_numpy( np.load('dv.npy')).to(torch.float16).to('cuda')
cu_seqlens_q = torch.from_numpy( np.load('cu_seqlens_q.npy')).to(torch.int32).to('cuda')
cu_seqlens_k = torch.from_numpy( np.load('cu_seqlens_k.npy')).to(torch.int32).to('cuda')
# max_seqlen_q/max_seqlen_k are plain Python ints in the flash_attn interface.
max_seqlen_q = int(np.load('max_seqlen_q.npy').item())
max_seqlen_k = int(np.load('max_seqlen_k.npy').item())

alibi_slopes = None
dropout_p = 0.0
softmax_scale = 0.08838834764831845
causal = False
window_size = (-1, -1)
softcap = 0.0
deterministic = False
rng_state = torch.tensor([0, 0], device='cuda')

ret = _flash_attn_varlen_backward( dout,
                                    q,
                                    k,
                                    v,
                                    out,
                                    softmax_lse,
                                    dq,
                                    dk,
                                    dv,
                                    cu_seqlens_q,
                                    cu_seqlens_k,
                                    max_seqlen_q,
                                    max_seqlen_k,
                                    dropout_p,
                                    softmax_scale,
                                    causal,
                                    window_size,
                                    softcap,
                                    alibi_slopes,
                                    deterministic,
                                    rng_state
                                )

Output:

Traceback (most recent call last):
  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/flash_attn_repro.py", line 30, in <module>
    ret = _flash_attn_varlen_backward( dout,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/flash_attn-2.6.3-py3.12-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 234, in _flash_attn_varlen_backward
    window_size[0],
    ~~~~~~~~~~~^^^
TypeError: 'bool' object is not subscriptable
(pyenv312) ecm327663@as01r4b06:ai-obs-experimental-transformer$ python flash_attn_repro.py
WRITTEN
(pyenv312) ecm327663@as01r4b06:ai-obs-experimental-transformer$ python flash_attn_repro.py
Traceback (most recent call last):
  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/flash_attn_repro.py", line 30, in <module>
    ret = _flash_attn_varlen_backward( dout,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/flash_attn-2.6.3-py3.12-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 215, in _flash_attn_varlen_backward
    ) = flash_attn_cuda.varlen_bwd(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

(pyenv312) ecm327663@as01r4b06:ai-obs-experimental-transformer$ CUDA_LAUNCH_BLOCKING=1 python flash_attn_repro.py
Traceback (most recent call last):
  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/flash_attn_repro.py", line 30, in <module>
    ret = _flash_attn_varlen_backward( dout,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/flash_attn-2.6.3-py3.12-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 215, in _flash_attn_varlen_backward
    ) = flash_attn_cuda.varlen_bwd(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The (slightly updated) data can be downloaded here: http://graphics.cs.uni-magdeburg.de/misc/flash_attn_repro.zip.

Let me know if you need anything else.
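Before digging into the kernel, it can be worth sanity-checking the varlen metadata that a repro script like this one loads. The NumPy-only sketch below encodes what the varlen interface appears to expect: `cu_seqlens` starts at 0, never decreases, ends at the number of packed tokens, and `max_seqlen` is at least the longest sequence. These rules are assumptions drawn from the meaning of `cu_seqlens`/`max_seqlen`, and `check_varlen_inputs` is a hypothetical helper.

```python
import numpy as np


def check_varlen_inputs(cu_seqlens, total_tokens, max_seqlen, name="q"):
    """Return a list of problems found in one (cu_seqlens, max_seqlen) pair.

    cu_seqlens: 1-D integer array of cumulative sequence lengths (batch + 1 entries).
    total_tokens: number of rows in the packed tensor (its first dimension).
    """
    cu = np.asarray(cu_seqlens)
    if cu.ndim != 1 or cu.size < 2:
        return [f"cu_seqlens_{name} must be 1-D with at least 2 entries"]

    problems = []
    if cu[0] != 0:
        problems.append(f"cu_seqlens_{name}[0] is {cu[0]}, expected 0")
    lengths = np.diff(cu)
    if (lengths < 0).any():
        problems.append(f"cu_seqlens_{name} is not non-decreasing")
    if cu[-1] != total_tokens:
        problems.append(f"cu_seqlens_{name}[-1] is {cu[-1]}, expected {total_tokens}")
    if lengths.max() > max_seqlen:
        problems.append(
            f"max_seqlen_{name} is {max_seqlen}, but the longest sequence is {lengths.max()}"
        )
    return problems


print(check_varlen_inputs([0, 3, 7, 12], total_tokens=12, max_seqlen=5))  # []
```

In the script above this could be called as `check_varlen_inputs(cu_seqlens_q.cpu().numpy(), q.shape[0], max_seqlen_q, "q")`, and likewise for `k`, before invoking the backward.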

zxgx commented 3 weeks ago

Hi @tridao, I ran into a similar issue, and it can be reproduced without loading any specific data, using the following script:

import torch

from flash_attn import flash_attn_func

shape = [28800, 15, 16, 72]
device = torch.device("cuda:0")
dtype = torch.bfloat16
q = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
grad = torch.randn(shape, device=device, dtype=dtype)

print(f"{torch.cuda.memory_allocated()/1024**3}, {torch.cuda.max_memory_allocated()/1024**3}")

o = flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
)

o.backward(grad)
torch.cuda.synchronize()
print(f"{torch.cuda.memory_allocated()/1024**3}, {torch.cuda.max_memory_allocated()/1024**3}")

Output:

3.7109375, 3.7109375
Traceback (most recent call last):
  File "run_kernel.py", line 23, in <module>
    o.backward(grad)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 572, in backward
    _flash_attn_backward(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 144, in _flash_attn_backward
    ) = flash_attn_cuda.bwd(
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

environment:

I suspect it's indeed caused by the large tensor size, because with a batch size below 28800 there is no error. Could you tell me how to fix it, or do you plan to fix it in flash-attn? Many thanks.
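A quick size calculation for the shape in this script helps locate the guess. Each input tensor has fewer than 2^31 elements and fewer than 2^31 bytes, so if 32-bit indexing is the culprit, the overflow would presumably have to happen in intermediate buffers or index arithmetic inside the backward pass rather than in q, k, v themselves. This is an inference from the arithmetic only, not a diagnosis; the total also lines up with the ~3.71 GiB printed before the forward call.

```python
INT32_MAX = 2**31 - 1

batch, seqlen, nheads, headdim = 28800, 15, 16, 72  # shape used in the script above
bytes_per_elem = 2  # bfloat16

numel = batch * seqlen * nheads * headdim  # elements in each of q, k, v, grad
nbytes = numel * bytes_per_elem

print(numel, numel <= INT32_MAX)    # 497664000 True
print(nbytes, nbytes <= INT32_MAX)  # 995328000 True
print(round(4 * nbytes / 1024**3, 2))  # 3.71 (GiB for q, k, v and grad together)

# Largest batch for which a single input tensor stays within INT32_MAX elements.
max_batch_int32 = INT32_MAX // (seqlen * nheads * headdim)
print(max_batch_int32)  # 124275
```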

clessig commented 1 week ago

I wanted to dig a bit deeper into the issue myself (since I have quite some experience with flash attention). With cuda-gdb I get to:

CUDA Exception: Warp Illegal Address

Thread 137 "pt_autograd_0" received signal CUDA_EXCEPTION_14, Warp Illegal Address.
[Switching focus to CUDA kernel 0, grid 2749, block (0,8156,0), thread (96,0,0), device 0, sm 0, warp 10, lane 0]
0x00007ff9ac5f8fa0 in void flash_bwd_dot_do_o_kernel<true, Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 2, false, false, cutlass::half_t, Flash_kernel_traits<128, 64, 128, 8, cutlass::half_t> > >(Flash_bwd_params)
   <<<(1,12288,16),(256,1,1)>>> ()

However, all the templates seem to have been inlined beyond this point, so I cannot easily pin down where exactly the problem is. I therefore tried to compile flash attention in debug mode, i.e. with "-g -G", but the build fails. The relevant part of the output:

Compiling objects...
Using envvar MAX_JOBS (1) as the number of workers...

[1/84] /apps/ACC/CUDA/12.3/bin/nvcc --generate-dependencies-with-compile --dependency-output /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/build/temp.linux-x86_64-cpython-312/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.o.d -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/flash_attn -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/flash_attn/src -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/cutlass/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include/TH -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include/THC -I/apps/ACC/CUDA/12.3/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/include -I/apps/ACC/PYTHON/3.12.1/INTEL/include/python3.12 -c -c /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu -o /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/build/temp.linux-x86_64-cpython-312/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -g -G -std=c++17 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda 
--use_fast_math -gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90 --threads 4 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_attn_2_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -ccbin gcc
FAILED: /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/build/temp.linux-x86_64-cpython-312/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.o 
/apps/ACC/CUDA/12.3/bin/nvcc --generate-dependencies-with-compile --dependency-output /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/build/temp.linux-x86_64-cpython-312/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.o.d -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/flash_attn -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/flash_attn/src -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/cutlass/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include/TH -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include/THC -I/apps/ACC/CUDA/12.3/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/include -I/apps/ACC/PYTHON/3.12.1/INTEL/include/python3.12 -c -c /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu -o /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/build/temp.linux-x86_64-cpython-312/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -g -G -std=c++17 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math 
-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90 --threads 4 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_attn_2_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -ccbin gcc
Warning: Function too large, generated debug information may not be accurate.

Warning: Function too large, generated debug information may not be accurate.

ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2104, in _run_ninja_build
    subprocess.run(
  File "/apps/ACC/PYTHON/3.12.1/INTEL/lib/python3.12/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v', '-j', '1']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/setup.py", line 492, in <module>
    setup(
  File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/setuptools/__init__.py", line 117, in setup
    return distutils.core.setup(**attrs)

Running the same nvcc command directly in the shell works:

%>/apps/ACC/CUDA/12.3/bin/nvcc --generate-dependencies-with-compile --dependency-output /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/build/temp.linux-x86_64-cpython-312/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.o.d -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/flash_attn -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/flash_attn/src -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/cutlass/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include/TH -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/lib/python3.12/site-packages/torch/include/THC -I/apps/ACC/CUDA/12.3/include -I/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312_debug/include -I/apps/ACC/PYTHON/3.12.1/INTEL/include/python3.12 -c -c /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu -o /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/scratch/flash-attention/build/temp.linux-x86_64-cpython-312/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -g -G -std=c++17 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda 
--use_fast_math -gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90 --threads 4 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_attn_2_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -ccbin gcc
Warning: Function too large, generated debug information may not be accurate.
Warning: Function too large, generated debug information may not be accurate.
%>

I am wondering whether distutils treats the warning as an error. Is there a known way to compile flash attention in debug mode?
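One way to check whether the warning itself could be the problem: ninja, and `subprocess.run(..., check=True)` as used in `torch.utils.cpp_extension`, only treat a command as failed when it exits with a non-zero status; text written to stderr does not fail the build on its own. The stdlib-only sketch below demonstrates this with stand-in commands (Python one-liners, not nvcc); the stand-ins are hypothetical and only model the exit-status behaviour.

```python
import subprocess
import sys


def run(cmd):
    """Run cmd and return (exit status, captured stderr)."""
    result = subprocess.run(cmd, capture_output=True, text=True)
    return result.returncode, result.stderr


# Stand-in for nvcc: prints a warning to stderr but exits with status 0.
warn_only = [sys.executable, "-c",
             "import sys; sys.stderr.write('Warning: Function too large\\n')"]
# Stand-in for a real failure: exits with status 1.
fail = [sys.executable, "-c", "import sys; sys.exit(1)"]

status, err = run(warn_only)
print(status, err.strip())  # 0 Warning: Function too large

status, _ = run(fail)
print(status)  # 1

# check=True raises CalledProcessError only for the non-zero exit status.
subprocess.run(warn_only, check=True, capture_output=True)  # does not raise
try:
    subprocess.run(fail, check=True)
except subprocess.CalledProcessError as e:
    print("failed with", e.returncode)  # failed with 1
```

Under this behaviour, if the identical nvcc command exits with 0 in the shell but ninja reports a failure, the difference would have to come from something other than the warning text.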