Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Incorrect "RuntimeError: FlashAttention only support fp16 and bf16 data type" #915

Open jlamypoirier opened 6 months ago

jlamypoirier commented 6 months ago

flash-attn 2.5.7 always complains about the input data type, even when it is clearly a supported one. I'm using the base image nvcr.io/nvidia/pytorch:24.03-py3.

>>> import torch, flash_attn
>>> from flash_attn.flash_attn_interface import flash_attn_func
>>> x=torch.randn(1, 4096, 8, 128, dtype=torch.bfloat16, device="cuda")
>>> flash_attn.__version__
'2.5.7'
>>> flash_attn_func(x,x,x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/user/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 831, in flash_attn_func
    return FlashAttnFunc.apply(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 572, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/user/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 511, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  File "/home/user/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
Exception raised from mha_fwd at /home/runner/work/flash-attention/flash-attention/csrc/flash_attn/flash_api.cpp:340 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7fffec083d89 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x6a (0x7fffec0335ac in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: mha_fwd(at::Tensor&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, bool, std::optional<at::Generator>) + 0x18e9 (0x7ffea50183d9 in /home/user/.local/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x1391c9 (0x7ffea50341c9 in /home/user/.local/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x135819 (0x7ffea5030819 in /home/user/.local/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
<omitting python frames>
frame #11: THPFunction_apply(_object*, _object*) + 0xf59 (0x7fffeb26c0c9 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #29: <unknown function> + 0x29d90 (0x7ffff7a00d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #30: __libc_start_main + 0x80 (0x7ffff7a00e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

tridao commented 6 months ago

Thanks for the report, I can reproduce it. Investigating now. It might be because of the way torch (in C++) handles dtypes.

tridao commented 6 months ago

Hmm, compiling from scratch seems to work fine, so something is wrong with the wheel we built.

tridao commented 6 months ago

I'm guessing this is because 24.03 uses CUDA 12.4 and the wheels built with nvcc 12.2 are somehow not compatible.
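
A minimal sketch for checking this guess on your own machine, assuming the mismatch is between the CUDA version a wheel was built with and the one PyTorch reports. The helpers `major_minor`, `same_cuda_minor` and `runtime_cuda_version` are hypothetical names, not part of flash-attn, and the "12.2" value is simply the nvcc version mentioned above. A differing minor version is only a hint, not proof of incompatibility.

```python
import re


def major_minor(version):
    """Return (major, minor) parsed from a version string such as '12.4' or '12.2.140'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def same_cuda_minor(wheel_cuda, runtime_cuda):
    """True when both version strings share the same major.minor CUDA version."""
    return major_minor(wheel_cuda) == major_minor(runtime_cuda)


def runtime_cuda_version():
    """CUDA version PyTorch was built against, or None if torch is missing or CPU-only."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda


if __name__ == "__main__":
    runtime = runtime_cuda_version()
    print("CUDA version reported by torch:", runtime)
    if runtime is not None:
        # "12.2" is the nvcc version the thread says the wheels were built with.
        print("same major.minor as a 12.2 wheel:", same_cuda_minor("12.2", runtime))
```

If the versions differ, building from source against the local toolkit is the route the maintainer reported as working.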

FremyCompany commented 4 months ago

What is the recommended fix then? Rebuild flash attention from source?

FremyCompany commented 4 months ago

I just did a fresh install using python setup.py install, but I still get the same error message. Either I did something wrong, or there is still a problem somewhere.

  File "/root/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py", line 624, in forward
    attn_function_output = self._apply_dense_attention(
  File "/root/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py", line 441, in _apply_dense_attention
    attn_output_unpad = flash_attn_varlen_kvpacked_func(
  File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.5.9.post1-py3.10-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 978, in flash_attn_varlen_kvpacked_func
    return FlashAttnVarlenKVPackedFunc.apply(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 572, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.5.9.post1-py3.10-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 432, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.5.9.post1-py3.10-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 86, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type

Do you think I can comment out the code that does the check and reinstall, or is that just a sign that things will break down the line?

FremyCompany commented 4 months ago

Ok, it sounds like the issue was with the Phi3-small code, not with the library. The assert was triggering appropriately. Sorry for the noise.
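
When the check fires for a good reason, it helps to see which input has the wrong dtype before the call reaches the kernel. Below is a rough sketch of such a pre-check. `unsupported_dtypes` is a hypothetical helper, not a flash-attn function. It only assumes each input has a `.dtype` attribute whose string form is, for example, `torch.float32`, so it works on real tensors and on the stand-in objects used here.

```python
from types import SimpleNamespace

# String forms of the two dtypes named in the error message.
SUPPORTED = {"torch.float16", "torch.bfloat16"}


def unsupported_dtypes(**tensors):
    """Return {name: dtype string} for every input that is not fp16 or bf16."""
    return {
        name: str(t.dtype)
        for name, t in tensors.items()
        if str(t.dtype) not in SUPPORTED
    }


if __name__ == "__main__":
    # Stand-ins for real tensors so the sketch runs without a GPU or torch.
    q = SimpleNamespace(dtype="torch.bfloat16")
    kv = SimpleNamespace(dtype="torch.float32")
    print(unsupported_dtypes(q=q, kv=kv))
```

With real tensors the call would look like `unsupported_dtypes(q=q, k=k, v=v)` placed just before the flash attention call. An empty result means every input is fp16 or bf16.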

mainrs commented 3 months ago

> Ok, it sounds like the issue was with the Phi3-small code, not with the library. The assert was triggering appropriately. Sorry for the noise.

How did you solve the issue? I am also trying to fine-tune phi3. But I receive the same error message and I don't know what to do with it tbh.

amulil commented 3 months ago

I'm hitting the same problem. @tridao, do you have a recommended way to solve it?

FremyCompany commented 1 month ago

Sorry, I forgot the details, but I think you had to set use_reentrant to a value other than the default.
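
A sketch of what that might look like, assuming the setting meant here is the `use_reentrant` flag of PyTorch gradient checkpointing as exposed by Hugging Face transformers. This is an interpretation of a comment written from memory, not a confirmed fix. `checkpointing_kwargs` is a hypothetical helper that only builds the arguments, and the `False` value is an assumption about which non-default value was intended.

```python
def checkpointing_kwargs(use_reentrant=False):
    """Build keyword arguments that enable gradient checkpointing with an
    explicit use_reentrant value (hypothetical helper for illustration)."""
    return {
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": use_reentrant},
    }


if __name__ == "__main__":
    print(checkpointing_kwargs())

# Intended use, assuming Hugging Face transformers is installed:
#   from transformers import TrainingArguments
#   args = TrainingArguments(output_dir="out", **checkpointing_kwargs())
# or directly on a model:
#   model.gradient_checkpointing_enable(
#       gradient_checkpointing_kwargs={"use_reentrant": False}
#   )
```

Whether this removes the dtype error for Phi-3-small depends on the model code and library versions in use, so treat it as something to try rather than a known solution.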