Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Issue with installing flash attention `import flash_attn_2_cuda as flash_attn_cuda` #1348


hahmad2008 commented 2 days ago

Gemma2 needs torch>=2.4.0, as mentioned, because when I run it I get this error:

  File "/usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py", line 1656, in __init__
    torch._dynamo.mark_static_address(new_layer_key_cache)
AttributeError: module 'torch._dynamo' has no attribute 'mark_static_address'
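
A quick way to confirm the mismatch (a minimal sketch, run in the same environment) is to check whether the installed torch actually exposes the attribute that transformers is calling:

import torch
import torch._dynamo

# transformers' cache_utils calls torch._dynamo.mark_static_address,
# which older torch builds such as 2.0.1 do not provide.
print(torch.__version__)
print(hasattr(torch._dynamo, "mark_static_address"))  # False on the 2.0.1 build here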

So this needs torch>=2.4.0, but the currently installed versions are the following:

>>> import torch;torch.__version__
'2.0.1+cu117'
>>> import flash_attn;flash_attn.__version__
'2.5.6'

The problem is that I then tried to install torch '2.4.0+cu118', while my CUDA toolkit is:

root@0d6c1aeee409:/space/LongLM# nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

I got this error:

>>> import flash_attn;flash_attn.__version__
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE
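
For reference, that mangled symbol demangles to `at::_ops::zeros::call(c10::ArrayRef<c10::SymInt>, c10::optional<c10::ScalarType>, ...)`, a libtorch function, so the error means the prebuilt `flash_attn_2_cuda` extension was compiled against a different torch version/ABI than the torch that is installed; such a wheel generally has to match the torch major.minor and CUDA series it was built for. A minimal sketch to see the demangled name (assuming binutils' `c++filt` is available):

import subprocess

# Demangle the symbol from the ImportError to see which libtorch function
# the flash-attn extension expects; a missing symbol like this indicates
# the extension was built against a different torch ABI.
sym = "_ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE"
print(subprocess.run(["c++filt", sym], capture_output=True, text=True).stdout)
# -> at::_ops::zeros::call(c10::ArrayRef<c10::SymInt>, c10::optional<c10::ScalarType>, ...)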

So I uninstalled and reinstalled flash-attention as follows:

pip uninstall flash-attn
pip install --no-build-isolation flash-attn==2.5.6 -U --force-reinstall

However, this uninstalled the current torch and installed torch '2.5.1+cu124', and I still hit the same issue:

 import flash_attn_2_cuda as flash_attn_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE
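
Since the failure is an ABI mismatch between the flash-attn binary and torch, a useful check after each reinstall (a sketch, not something I ran above) is to list the installed torch build, the flash-attn distribution, and the extension file that would be loaded, without triggering the failing import:

import importlib.metadata
import importlib.util

import torch

# Report the torch build and flash-attn distribution that are actually
# installed, plus the path of the compiled extension, without importing it
# (importlib.util.find_spec only locates the .so, it does not load it).
print("torch:", torch.__version__, "| built with CUDA:", torch.version.cuda)
print("flash-attn:", importlib.metadata.version("flash-attn"))
spec = importlib.util.find_spec("flash_attn_2_cuda")
print("extension:", spec.origin if spec else "not found")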

So I can't install it!

tridao commented 2 days ago

You should