Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Undefined symbol #723

Open jonathanasdf opened 9 months ago

jonathanasdf commented 9 months ago
RuntimeError: Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
/opt/venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE

I'm using cuda 12.1, pytorch nightly 2.2.0+cu121, and flash-attention built from source (pip install git+https://github.com/Dao-AILab/flash-attention.git@92dd570 --no-build-isolation). Any idea what I can do to debug this?

kandeldeepak46 commented 9 months ago

Getting the same error.

sleeper1023 commented 9 months ago

Me too. Has anyone solved this problem?

tcapelle commented 9 months ago

Same here.

tridao commented 9 months ago

It's because of a torch version change. nvcr pytorch 23.12 should work with flash-attn v2.4.0.post1 now. If you're using torch nightly: we currently use torch-nightly 20231106 to compile the CUDA wheels, so if your torch-nightly version is close to that it should work.
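The version-matching rule described above can be sketched as a small check. The mapping of flash-attn releases to the torch versions their wheels were built against is illustrative (only the v2.4.0.post1 pairing is implied by this thread), not an official table:

```python
# Minimal sketch of the check implied above: a prebuilt flash-attn wheel
# only loads if the installed torch matches the major.minor version its
# C++ extension was compiled against; otherwise the extension fails at
# import time with an "undefined symbol" ImportError.
# NOTE: the mapping below is an illustrative assumption, not an official table.
WHEEL_BUILT_AGAINST = {
    "2.4.0.post1": "2.1",  # per the maintainer's comment above
    "2.4.2": "2.1",        # assumption for illustration
}

def torch_matches(installed_torch: str, flash_attn_version: str) -> bool:
    """Compare only major.minor: C++ ABI symbols change between minor versions."""
    built_for = WHEEL_BUILT_AGAINST.get(flash_attn_version)
    if built_for is None:
        return False  # unknown build: assume incompatible
    # Strip local version suffixes like "+cu121" before comparing.
    base = installed_torch.split("+")[0]
    major_minor = ".".join(base.split(".")[:2])
    return major_minor == built_for
```

For example, torch 2.2.0+cu121 would fail this check against a wheel built for torch 2.1, which is consistent with the errors reported later in this thread.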

tianyunzqs commented 9 months ago

I installed flash-attn v2.4.2 successfully with python 3.10 + cuda 12.1 + torch 2.1.0 + transformers 4.36.2.
Here is my environment:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
accelerate                0.25.0                    <pip>
aiohttp                   3.9.1                     <pip>
aiosignal                 1.3.1                     <pip>
annotated-types           0.6.0                     <pip>
anyio                     4.2.0                     <pip>
argon2-cffi               23.1.0                    <pip>
argon2-cffi-bindings      21.2.0                    <pip>
async-timeout             4.0.3                     <pip>
attrs                     23.2.0                    <pip>
auto-gptq                 0.6.0                     <pip>
backoff                   2.2.1                     <pip>
beautifulsoup4            4.12.2                    <pip>
bitsandbytes              0.41.0                    <pip>
blas                      1.0                         mkl  
bzip2                     1.0.8                h7b6447c_0  
ca-certificates           2023.12.12           h06a4308_0  
certifi                   2023.11.17                <pip>
cffi                      1.16.0                    <pip>
charset-normalizer        2.1.1                     <pip>
click                     8.1.7                     <pip>
coloredlogs               15.0.1                    <pip>
cuda-cudart               12.1.105                      0    nvidia
cuda-cupti                12.1.105                      0    nvidia
cuda-libraries            12.1.0                        0    nvidia
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvtx                 12.1.105                      0    nvidia
cuda-opencl               12.3.101                      0    nvidia
cuda-runtime              12.1.0                        0    nvidia
dataclasses-json          0.6.3                     <pip>
datasets                  2.16.1                    <pip>
dill                      0.3.7                     <pip>
distro                    1.9.0                     <pip>
dropout-layer-norm        0.1                       <pip>
einops                    0.7.0                     <pip>
elastic-transport         8.11.0                    <pip>
elasticsearch             7.10.1                    <pip>
environs                  9.5.0                     <pip>
exceptiongroup            1.2.0                     <pip>
exrex                     0.11.0                    <pip>
fastllm                   0.2.1                     <pip>
fastllm-pytools           0.0.1                     <pip>
filelock                  3.13.1                    <pip>
filelock                  3.13.1          py310h06a4308_0  
flash-attn                2.4.2                     <pip>
frozenlist                1.4.1                     <pip>
fsspec                    2023.10.0                 <pip>
fuzzywuzzy                0.18.0                    <pip>
gekko                     1.0.6                     <pip>
greenlet                  3.0.3                     <pip>
grpcio                    1.58.0                    <pip>
h11                       0.14.0                    <pip>
httpcore                  1.0.2                     <pip>
httpx                     0.26.0                    <pip>
huggingface-hub           0.20.1                    <pip>
humanfriendly             10.0                      <pip>
idna                      3.6                       <pip>
intel-openmp              2021.4.0          h06a4308_3561  
jinja2                    3.1.2           py310h06a4308_0  
Jinja2                    3.1.2                     <pip>
jiojio                    1.2.5                     <pip>
jionlp                    1.5.6                     <pip>
joblib                    1.3.2                     <pip>
jsonpatch                 1.33                      <pip>
jsonpointer               2.4                       <pip>
jsonschema                4.20.0                    <pip>
jsonschema-specifications 2023.12.1                 <pip>
langchain                 0.0.354                   <pip>
langchain-community       0.0.8                     <pip>
langchain-core            0.1.5                     <pip>
langsmith                 0.0.77                    <pip>
ld_impl_linux-64          2.38                 h1181459_1  
Levenshtein               0.23.0                    <pip>
libcublas                 12.1.0.26                     0    nvidia
libcufft                  11.0.2.4                      0    nvidia
libcufile                 1.8.1.2                       0    nvidia
libcurand                 10.3.4.101                    0    nvidia
libcusolver               11.4.4.55                     0    nvidia
libcusparse               12.0.2.55                     0    nvidia
libffi                    3.3                  he6710b0_2  
libgcc-ng                 9.1.0                hdf63c60_0  
libnpp                    12.0.2.50                     0    nvidia
libnvjitlink              12.1.105                      0    nvidia
libnvjpeg                 12.1.1.14                     0    nvidia
libstdcxx-ng              9.1.0                hdf63c60_0  
libuuid                   1.0.3                h7f8727e_2  
llvm-openmp               14.0.6               h9e868ea_0  
markdown-it-py            3.0.0                     <pip>
markupsafe                2.1.1           py310h7f8727e_0  
MarkupSafe                2.1.1                     <pip>
marshmallow               3.20.1                    <pip>
mdurl                     0.1.2                     <pip>
minio                     7.2.3                     <pip>
mkl                       2021.4.0           h06a4308_640  
mkl-fft                   1.3.1                     <pip>
mkl-random                1.2.2                     <pip>
mkl-service               2.4.0           py310h7f8727e_0  
mkl-service               2.4.0                     <pip>
mkl_fft                   1.3.1           py310hd6ae3a3_0  
mkl_random                1.2.2           py310h00e6091_0  
mpmath                    1.3.0                     <pip>
mpmath                    1.3.0           py310h06a4308_0  
multidict                 6.0.4                     <pip>
multiprocess              0.70.15                   <pip>
mypy-extensions           1.0.0                     <pip>
ncurses                   6.3                  h7f8727e_2  
networkx                  3.1                       <pip>
networkx                  3.1             py310h06a4308_0  
ninja                     1.11.1.1                  <pip>
nltk                      3.8.1                     <pip>
numpy                     1.22.3          py310hfa59a62_0  
numpy                     1.26.3                    <pip>
numpy-base                1.22.3          py310h9585f30_0  
nvidia-cublas-cu12        12.1.3.1                  <pip>
nvidia-cuda-cupti-cu12    12.1.105                  <pip>
nvidia-cuda-nvrtc-cu12    12.1.105                  <pip>
nvidia-cuda-runtime-cu12  12.1.105                  <pip>
nvidia-cudnn-cu12         8.9.2.26                  <pip>
nvidia-cufft-cu12         11.0.2.54                 <pip>
nvidia-curand-cu12        10.3.2.106                <pip>
nvidia-cusolver-cu12      11.4.5.107                <pip>
nvidia-cusparse-cu12      12.1.0.106                <pip>
nvidia-nccl-cu12          2.18.1                    <pip>
nvidia-nvjitlink-cu12     12.3.101                  <pip>
nvidia-nvtx-cu12          12.1.105                  <pip>
openai                    1.6.1                     <pip>
openssl                   1.1.1w               h7f8727e_0  
optimum                   1.16.1                    <pip>
packaging                 23.2                      <pip>
pandas                    2.1.4                     <pip>
peft                      0.7.1                     <pip>
pillow                    10.2.0                    <pip>
pip                       23.3.1                    <pip>
pip                       23.3.1          py310h06a4308_0  
protobuf                  4.25.1                    <pip>
psutil                    5.9.7                     <pip>
pyarrow                   14.0.2                    <pip>
pyarrow-hotfix            0.6                       <pip>
pycparser                 2.21                      <pip>
pycryptodome              3.19.1                    <pip>
pydantic                  2.5.3                     <pip>
pydantic_core             2.14.6                    <pip>
Pygments                  2.17.2                    <pip>
pymilvus                  2.3.5                     <pip>
PyMySQL                   1.1.0                     <pip>
python                    3.10.4               h12debd9_0  
python-dateutil           2.8.2                     <pip>
python-dotenv             1.0.0                     <pip>
python-Levenshtein        0.23.0                    <pip>
pytorch                   2.1.0           py3.10_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2023.3.post1              <pip>
pyyaml                    6.0             py310h7f8727e_0  
PyYAML                    6.0                       <pip>
rapidfuzz                 3.6.1                     <pip>
readline                  8.1.2                h7f8727e_1  
referencing               0.32.0                    <pip>
regex                     2023.12.25                <pip>
requests                  2.28.1                    <pip>
rich                      13.7.0                    <pip>
rouge                     1.0.1                     <pip>
rpds-py                   0.16.2                    <pip>
safetensors               0.4.1                     <pip>
scikit-learn              1.3.2                     <pip>
scipy                     1.11.4                    <pip>
sentence-transformers     2.2.2                     <pip>
sentencepiece             0.1.99                    <pip>
setuptools                68.2.2          py310h06a4308_0  
setuptools                68.2.2                    <pip>
six                       1.16.0             pyhd3eb1b0_1  
sniffio                   1.3.0                     <pip>
soupsieve                 2.5                       <pip>
SQLAlchemy                2.0.25                    <pip>
sqlite                    3.38.5               hc218d9a_0  
sympy                     1.12                      <pip>
sympy                     1.12            py310h06a4308_0  
tabulate                  0.9.0                     <pip>
tenacity                  8.2.3                     <pip>
threadpoolctl             3.2.0                     <pip>
tiktoken                  0.5.2                     <pip>
tk                        8.6.12               h1ccaba5_0  
tokenizers                0.15.0                    <pip>
torch                     2.1.0+cu121               <pip>
torchaudio                2.1.0               py310_cu121    pytorch
torchaudio                2.1.0                     <pip>
torchtriton               2.1.0                     py310    pytorch
torchvision               0.16.0+cu121              <pip>
tornado                   6.4                       <pip>
tqdm                      4.66.1                    <pip>
transformers              4.36.2                    <pip>
transformers-stream-generator 0.0.4                     <pip>
triton                    2.1.0                     <pip>
typing-inspect            0.9.0                     <pip>
typing_extensions         4.7.1           py310h06a4308_0  
typing_extensions         4.7.1                     <pip>
tzdata                    2023c                h04d1e81_0  
tzdata                    2023.4                    <pip>
ujson                     5.9.0                     <pip>
urllib3                   1.26.18                   <pip>
wheel                     0.41.2          py310h06a4308_0  
wheel                     0.41.2                    <pip>
xxhash                    3.4.1                     <pip>
xz                        5.2.5                h7f8727e_1  
yaml                      0.2.5                h7b6447c_0  
yarl                      1.9.4                     <pip>
zipfile36                 0.1.3                     <pip>
zlib                      1.2.12               h7f8727e_2  

Eikor commented 8 months ago

Same problem on torch 2.1.0 + cuda 12.1: ImportError: /usr/local/lib/python3.8/dist-packages/flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

ArlanCooper commented 8 months ago

same, torch2.2.0+cuda11.8: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE

pseudotensor commented 8 months ago
Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import flash_attn
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 8, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
>>> 

I started getting this and have no idea why. Nothing mentioned above works. I tried cuda 11.8 and 12.1, each with a consistent install; it never works:

    export CUDA_HOME=/usr/local/cuda-12.1
    export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu121"
    pip install --upgrade pip
    pip install flash-attn==2.4.2 --no-build-isolation --no-cache-dir

pseudotensor commented 8 months ago

Same issue with latest version of flash-attn:

(h2ogpt) jon@gpu:~$ python
Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import flash_attn
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE
>>> 
pseudotensor commented 8 months ago

Ah, it seems to be an incompatibility between torch 2.2.0 and flash_attn.
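Until the mismatch is resolved, one way to keep an application importable is to guard the flash_attn import and fall back elsewhere. A minimal sketch (the HAS_FLASH_ATTN flag name is my own):

```python
# Hedged sketch: guard the flash-attn import so an ABI mismatch
# (the "undefined symbol" ImportError above) degrades gracefully
# instead of crashing the application at startup.
try:
    import flash_attn  # raises ImportError on a torch/wheel mismatch
    HAS_FLASH_ATTN = True
except ImportError:
    # Fall back to a non-flash attention path elsewhere in the code,
    # e.g. torch's built-in scaled_dot_product_attention.
    HAS_FLASH_ATTN = False

print("flash-attn available:", HAS_FLASH_ATTN)
```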

caoxu915683474 commented 6 months ago

I have the same issue. Has anyone solved it?

robinsonmd commented 5 months ago

This combination works for me: cuda 12.1, torch 2.2.0+cu121, flash-attn 2.5.7.