Open jonathanasdf opened 9 months ago
getting same
Me too. Has anyone solved this problem?
same
It's because of a torch version change. The nvcr PyTorch 23.12 container should work with flash-attn v2.4.0.post1 now. If you're using torch nightly, we currently compile the CUDA wheel against torch-nightly 20231106, so if your nightly version is close to that date it should work.
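Since the prebuilt wheels are compiled against a specific torch/CUDA/ABI combination, one way to spot a mismatch before installing is to read the version tags baked into the release wheel filename (the `+cu…torch…cxx11abi…` convention used by the flash-attn GitHub releases). A minimal sketch; the helper name `parse_flash_attn_wheel` is mine, not part of any library:

```python
import re

def parse_flash_attn_wheel(filename):
    """Extract the CUDA version, torch version, and C++11 ABI flag from a
    flash-attn release wheel filename such as
    flash_attn-2.4.2+cu121torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
    Returns None if the filename doesn't carry the expected tags."""
    m = re.search(r"\+cu(\d+)torch([\d.]+)cxx11abi(TRUE|FALSE)", filename)
    if not m:
        return None
    return {
        "cuda": m.group(1),               # e.g. "121" for CUDA 12.1
        "torch": m.group(2),              # e.g. "2.1"
        "cxx11abi": m.group(3) == "TRUE", # must match how torch was built
    }

info = parse_flash_attn_wheel(
    "flash_attn-2.4.2+cu121torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
)
print(info)
```

If any of these tags disagree with your local `torch.__version__` / `torch.version.cuda`, you get exactly the `undefined symbol` errors reported in this thread.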
I installed flash-attn v2.4.2 successfully with Python 3.10 + CUDA 12.1 + torch 2.1.0 + transformers 4.36.2.
Here is my environment:
# Name Version Build Channel
_libgcc_mutex 0.1 main
accelerate 0.25.0 <pip>
aiohttp 3.9.1 <pip>
aiosignal 1.3.1 <pip>
annotated-types 0.6.0 <pip>
anyio 4.2.0 <pip>
argon2-cffi 23.1.0 <pip>
argon2-cffi-bindings 21.2.0 <pip>
async-timeout 4.0.3 <pip>
attrs 23.2.0 <pip>
auto-gptq 0.6.0 <pip>
backoff 2.2.1 <pip>
beautifulsoup4 4.12.2 <pip>
bitsandbytes 0.41.0 <pip>
blas 1.0 mkl
bzip2 1.0.8 h7b6447c_0
ca-certificates 2023.12.12 h06a4308_0
certifi 2023.11.17 <pip>
cffi 1.16.0 <pip>
charset-normalizer 2.1.1 <pip>
click 8.1.7 <pip>
coloredlogs 15.0.1 <pip>
cuda-cudart 12.1.105 0 nvidia
cuda-cupti 12.1.105 0 nvidia
cuda-libraries 12.1.0 0 nvidia
cuda-nvrtc 12.1.105 0 nvidia
cuda-nvtx 12.1.105 0 nvidia
cuda-opencl 12.3.101 0 nvidia
cuda-runtime 12.1.0 0 nvidia
dataclasses-json 0.6.3 <pip>
datasets 2.16.1 <pip>
dill 0.3.7 <pip>
distro 1.9.0 <pip>
dropout-layer-norm 0.1 <pip>
einops 0.7.0 <pip>
elastic-transport 8.11.0 <pip>
elasticsearch 7.10.1 <pip>
environs 9.5.0 <pip>
exceptiongroup 1.2.0 <pip>
exrex 0.11.0 <pip>
fastllm 0.2.1 <pip>
fastllm-pytools 0.0.1 <pip>
filelock 3.13.1 <pip>
filelock 3.13.1 py310h06a4308_0
flash-attn 2.4.2 <pip>
frozenlist 1.4.1 <pip>
fsspec 2023.10.0 <pip>
fuzzywuzzy 0.18.0 <pip>
gekko 1.0.6 <pip>
greenlet 3.0.3 <pip>
grpcio 1.58.0 <pip>
h11 0.14.0 <pip>
httpcore 1.0.2 <pip>
httpx 0.26.0 <pip>
huggingface-hub 0.20.1 <pip>
humanfriendly 10.0 <pip>
idna 3.6 <pip>
intel-openmp 2021.4.0 h06a4308_3561
jinja2 3.1.2 py310h06a4308_0
Jinja2 3.1.2 <pip>
jiojio 1.2.5 <pip>
jionlp 1.5.6 <pip>
joblib 1.3.2 <pip>
jsonpatch 1.33 <pip>
jsonpointer 2.4 <pip>
jsonschema 4.20.0 <pip>
jsonschema-specifications 2023.12.1 <pip>
langchain 0.0.354 <pip>
langchain-community 0.0.8 <pip>
langchain-core 0.1.5 <pip>
langsmith 0.0.77 <pip>
ld_impl_linux-64 2.38 h1181459_1
Levenshtein 0.23.0 <pip>
libcublas 12.1.0.26 0 nvidia
libcufft 11.0.2.4 0 nvidia
libcufile 1.8.1.2 0 nvidia
libcurand 10.3.4.101 0 nvidia
libcusolver 11.4.4.55 0 nvidia
libcusparse 12.0.2.55 0 nvidia
libffi 3.3 he6710b0_2
libgcc-ng 9.1.0 hdf63c60_0
libnpp 12.0.2.50 0 nvidia
libnvjitlink 12.1.105 0 nvidia
libnvjpeg 12.1.1.14 0 nvidia
libstdcxx-ng 9.1.0 hdf63c60_0
libuuid 1.0.3 h7f8727e_2
llvm-openmp 14.0.6 h9e868ea_0
markdown-it-py 3.0.0 <pip>
markupsafe 2.1.1 py310h7f8727e_0
MarkupSafe 2.1.1 <pip>
marshmallow 3.20.1 <pip>
mdurl 0.1.2 <pip>
minio 7.2.3 <pip>
mkl 2021.4.0 h06a4308_640
mkl-fft 1.3.1 <pip>
mkl-random 1.2.2 <pip>
mkl-service 2.4.0 py310h7f8727e_0
mkl-service 2.4.0 <pip>
mkl_fft 1.3.1 py310hd6ae3a3_0
mkl_random 1.2.2 py310h00e6091_0
mpmath 1.3.0 <pip>
mpmath 1.3.0 py310h06a4308_0
multidict 6.0.4 <pip>
multiprocess 0.70.15 <pip>
mypy-extensions 1.0.0 <pip>
ncurses 6.3 h7f8727e_2
networkx 3.1 <pip>
networkx 3.1 py310h06a4308_0
ninja 1.11.1.1 <pip>
nltk 3.8.1 <pip>
numpy 1.22.3 py310hfa59a62_0
numpy 1.26.3 <pip>
numpy-base 1.22.3 py310h9585f30_0
nvidia-cublas-cu12 12.1.3.1 <pip>
nvidia-cuda-cupti-cu12 12.1.105 <pip>
nvidia-cuda-nvrtc-cu12 12.1.105 <pip>
nvidia-cuda-runtime-cu12 12.1.105 <pip>
nvidia-cudnn-cu12 8.9.2.26 <pip>
nvidia-cufft-cu12 11.0.2.54 <pip>
nvidia-curand-cu12 10.3.2.106 <pip>
nvidia-cusolver-cu12 11.4.5.107 <pip>
nvidia-cusparse-cu12 12.1.0.106 <pip>
nvidia-nccl-cu12 2.18.1 <pip>
nvidia-nvjitlink-cu12 12.3.101 <pip>
nvidia-nvtx-cu12 12.1.105 <pip>
openai 1.6.1 <pip>
openssl 1.1.1w h7f8727e_0
optimum 1.16.1 <pip>
packaging 23.2 <pip>
pandas 2.1.4 <pip>
peft 0.7.1 <pip>
pillow 10.2.0 <pip>
pip 23.3.1 <pip>
pip 23.3.1 py310h06a4308_0
protobuf 4.25.1 <pip>
psutil 5.9.7 <pip>
pyarrow 14.0.2 <pip>
pyarrow-hotfix 0.6 <pip>
pycparser 2.21 <pip>
pycryptodome 3.19.1 <pip>
pydantic 2.5.3 <pip>
pydantic_core 2.14.6 <pip>
Pygments 2.17.2 <pip>
pymilvus 2.3.5 <pip>
PyMySQL 1.1.0 <pip>
python 3.10.4 h12debd9_0
python-dateutil 2.8.2 <pip>
python-dotenv 1.0.0 <pip>
python-Levenshtein 0.23.0 <pip>
pytorch 2.1.0 py3.10_cuda12.1_cudnn8.9.2_0 pytorch
pytorch-cuda 12.1 ha16c6d3_5 pytorch
pytorch-mutex 1.0 cuda pytorch
pytz 2023.3.post1 <pip>
pyyaml 6.0 py310h7f8727e_0
PyYAML 6.0 <pip>
rapidfuzz 3.6.1 <pip>
readline 8.1.2 h7f8727e_1
referencing 0.32.0 <pip>
regex 2023.12.25 <pip>
requests 2.28.1 <pip>
rich 13.7.0 <pip>
rouge 1.0.1 <pip>
rpds-py 0.16.2 <pip>
safetensors 0.4.1 <pip>
scikit-learn 1.3.2 <pip>
scipy 1.11.4 <pip>
sentence-transformers 2.2.2 <pip>
sentencepiece 0.1.99 <pip>
setuptools 68.2.2 py310h06a4308_0
setuptools 68.2.2 <pip>
six 1.16.0 pyhd3eb1b0_1
sniffio 1.3.0 <pip>
soupsieve 2.5 <pip>
SQLAlchemy 2.0.25 <pip>
sqlite 3.38.5 hc218d9a_0
sympy 1.12 <pip>
sympy 1.12 py310h06a4308_0
tabulate 0.9.0 <pip>
tenacity 8.2.3 <pip>
threadpoolctl 3.2.0 <pip>
tiktoken 0.5.2 <pip>
tk 8.6.12 h1ccaba5_0
tokenizers 0.15.0 <pip>
torch 2.1.0+cu121 <pip>
torchaudio 2.1.0 py310_cu121 pytorch
torchaudio 2.1.0 <pip>
torchtriton 2.1.0 py310 pytorch
torchvision 0.16.0+cu121 <pip>
tornado 6.4 <pip>
tqdm 4.66.1 <pip>
transformers 4.36.2 <pip>
transformers-stream-generator 0.0.4 <pip>
triton 2.1.0 <pip>
typing-inspect 0.9.0 <pip>
typing_extensions 4.7.1 py310h06a4308_0
typing_extensions 4.7.1 <pip>
tzdata 2023c h04d1e81_0
tzdata 2023.4 <pip>
ujson 5.9.0 <pip>
urllib3 1.26.18 <pip>
wheel 0.41.2 py310h06a4308_0
wheel 0.41.2 <pip>
xxhash 3.4.1 <pip>
xz 5.2.5 h7f8727e_1
yaml 0.2.5 h7b6447c_0
yarl 1.9.4 <pip>
zipfile36 0.1.3 <pip>
zlib 1.2.12 h7f8727e_2
Same problem on torch 2.1.0 + CUDA 12.1:
ImportError: /usr/local/lib/python3.8/dist-packages/flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi
Same here on torch 2.2.0 + CUDA 11.8: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import flash_attn
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn/__init__.py", line 3, in <module>
from flash_attn.flash_attn_interface import (
File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 8, in <module>
import flash_attn_2_cuda as flash_attn_cuda
ImportError: /home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
>>>
Started getting this out of nowhere, no idea why. Nothing mentioned here works. I tried CUDA 11.8 and 12.1, with a consistent install each time, and it never works:
export CUDA_HOME=/usr/local/cuda-12.1
export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu121"
pip install --upgrade pip
pip install flash-attn==2.4.2 --no-build-isolation --no-cache-dir
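When an install like the above still fails at import time, a useful first step is to print the versions the environment actually resolved, rather than the ones the install commands intended. A minimal sketch using only the standard library; the helper name `installed_versions` is mine:

```python
from importlib import metadata

def installed_versions(pkgs=("torch", "flash-attn")):
    """Return the installed version of each package, or None if absent.
    The two packages here are the ones whose binary ABI must match."""
    out = {}
    for pkg in pkgs:
        try:
            out[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            out[pkg] = None
    return out

print(installed_versions())
```

If the reported torch version differs from the one the flash-attn wheel was built against (e.g. torch 2.2.0 installed, wheel built for 2.1), the `undefined symbol` error is expected and the fix is to reinstall a matching pair.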
Same issue with the latest version of flash-attn:
(h2ogpt) jon@gpu:~$ python
Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import flash_attn
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn/__init__.py", line 3, in <module>
from flash_attn.flash_attn_interface import (
File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
import flash_attn_2_cuda as flash_attn_cuda
ImportError: /home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE
>>>
Ah, this seems to be an issue with torch 2.2.0 and flash_attn.
I have the same issue. Has anyone solved it?
This combination works for me: CUDA 12.1, torch 2.2.0+cu121, flash-attn 2.5.7.
I'm using CUDA 12.1, PyTorch nightly 2.2.0+cu121, and flash attention built from source with
pip install git+https://github.com/Dao-AILab/flash-attention.git@92dd570 --no-build-isolation
Any idea what I can do to debug this?
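One debugging step that makes these errors readable is demangling the missing symbol, which tells you exactly which torch C++ API the compiled extension expects (and therefore roughly which torch version it was built against). A sketch that shells out to `c++filt` from binutils if it is available; the helper name `demangle` is mine:

```python
import shutil
import subprocess

# The undefined symbol from the traceback above.
SYM = "_ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE"

def demangle(symbol):
    """Demangle an Itanium-ABI C++ symbol with c++filt, if installed."""
    if shutil.which("c++filt") is None:
        return None
    out = subprocess.run(["c++filt", symbol], capture_output=True, text=True)
    return out.stdout.strip()

print(demangle(SYM) or "c++filt not found; try an online demangler instead")
```

For the symbol above this yields a call into `at::_ops::_pad_enum`, i.e. the extension was compiled against a torch that exports that operator while the installed torch does not, which points back at a torch/flash-attn version mismatch rather than a CUDA problem.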