axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

import flash_attn_2_cuda as flash_attn_cuda fails #1588

Open Dyke-F opened 3 months ago

Dyke-F commented 3 months ago

Please check that this issue hasn't been reported before.

Expected Behavior

It should train a model.

I basically tried every installation setup: conda, pip, different versions of torch, and so on.

I also tried all of the solutions that have already been reported.

Current behaviour

ImportError: /home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

Full traceback (each of the three worker ranks prints the same trace; shown once):

  File "/beegfs/.global1/ws/dyfe751f-MEDICALLLMTRAIN/axolotl/src/axolotl/core/trainer_builder.py", line 40, in <module>
    from axolotl.utils.callbacks import (
  File "/beegfs/.global1/ws/dyfe751f-MEDICALLLMTRAIN/axolotl/src/axolotl/utils/callbacks/__init__.py", line 18, in <module>
    from optimum.bettertransformer import BetterTransformer
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/optimum/bettertransformer/__init__.py", line 14, in <module>
    from .models import BetterTransformerManager
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/optimum/bettertransformer/models/__init__.py", line 17, in <module>
    from .decoder_models import (
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/optimum/bettertransformer/models/decoder_models.py", line 18, in <module>
    from transformers.models.bart.modeling_bart import BartAttention
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 58, in <module>
    from flash_attn import flash_attn_func, flash_attn_varlen_func
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

E0503 20:53:54.127000 140556082500608 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 194756) of binary: /home/dyfe751f/.conda/envs/med_llm_venv/bin/python

Traceback (most recent call last):
  File "/home/dyfe751f/.conda/envs/med_llm_venv/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/accelerate/commands/launch.py", line 1048, in launch_command
    multi_gpu_launcher(args)
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/accelerate/commands/launch.py", line 702, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
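
The mangled symbol _ZN3c104cuda9SetDeviceEi demangles to c10::cuda::SetDevice(int), which lives in PyTorch's c10 CUDA library, so an undefined-symbol error when the extension loads typically means the flash-attn binary was compiled against a different PyTorch build than the one installed in the environment (an ABI mismatch between the wheel and torch). A minimal diagnostic sketch to print the versions that have to agree; the script name and the exact prints are illustrative, not from the original report:

# check_flash_attn.py (illustrative): print the versions that must
# agree for the prebuilt flash-attn extension to load.
import torch

print("torch:", torch.__version__)            # e.g. 2.1.2+cu118
print("torch built with CUDA:", torch.version.cuda)
if torch.cuda.is_available():
    print("GPU capability:", torch.cuda.get_device_capability(0))

try:
    import flash_attn
    print("flash-attn:", flash_attn.__version__)
except ImportError as exc:
    # An undefined-symbol failure here means the compiled extension
    # does not match the installed torch ABI; rebuild or reinstall.
    print("flash-attn failed to import:", exc)

If torch was upgraded or swapped (pip vs. conda) after flash-attn was installed, the prebuilt extension keeps referencing symbols from the old build, which matches the error above.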

Steps to reproduce

Install as described in the README, then try all of the suggested fixes.

Config yaml

No changes to the example config; the failing command is: accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml

Possible solution

No response

Which Operating Systems are you using?

Linux

Python Version

3.11

axolotl branch-commit

main

Acknowledgements

webbigdata-jp commented 3 months ago

First, uninstall flash-attention:

pip uninstall flash-attn

Then install flash-attention from source:

git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
python setup.py install

That works for me.

https://github.com/Dao-AILab/flash-attention/issues/931
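
If building from source is slow, flash-attention's README also documents installing with pip install flash-attn --no-build-isolation, so the build compiles against the torch already in the environment. Either way, a small self-attention call can confirm that the rebuilt extension loads and runs; a minimal smoke-test sketch, assuming a CUDA GPU with fp16 support:

# smoke_test.py (illustrative): verify flash_attn_2_cuda loads and runs.
import torch
from flash_attn import flash_attn_func

# flash-attn expects (batch, seqlen, nheads, headdim) tensors in
# fp16 or bf16 on a CUDA device.
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
out = flash_attn_func(q, q, q)  # trivial self-attention call
print(out.shape)  # expected: torch.Size([1, 128, 8, 64])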

justin-echternach commented 3 months ago

The fix above worked for me as well.