bowang-lab / scGPT

https://scgpt.readthedocs.io/en/latest/
MIT License

NameError: name 'FlashMHA' is not defined #182

Open · lamasJose opened this issue 6 months ago

lamasJose commented 6 months ago

Hi, the installation in a conda environment completes without any problems, but when I try to run the multiomics integration example I get the error "name 'FlashMHA' is not defined". I also get these warnings at the beginning:

/home/jmlamas/miniforge3/envs/scgpt2/lib/python3.10/site-packages/scgpt/model/model.py:21: UserWarning: flash_attn is not installed
  warnings.warn("flash_attn is not installed")
/home/jmlamas/miniforge3/envs/scgpt2/lib/python3.10/site-packages/scgpt/model/multiomic_model.py:19: UserWarning: flash_attn is not installed
  warnings.warn("flash_attn is not installed")

However, pip list gives:

(scgpt2) [jmlamas@cluster1-head1 ~]$ pip list
Package Version


absl-py 2.1.0 aiohttp 3.9.4 aiosignal 1.3.1 anndata 0.10.7 appdirs 1.4.4 array_api_compat 1.6 asttokens 2.4.1 async-timeout 4.0.3 attrs 23.2.0 cached-property 1.5.2 cell-gears 0.0.2 certifi 2024.2.2 charset-normalizer 3.3.2 chex 0.1.86 click 8.1.7 contextlib2 21.6.0 contourpy 1.2.1 cycler 0.12.1 datasets 2.18.0 dcor 0.6 decorator 5.1.1 Deprecated 1.2.14 dill 0.3.8 docker-pycreds 0.4.0 docrep 0.3.2 einops 0.7.0 et-xmlfile 1.1.0 etils 1.7.0 exceptiongroup 1.2.0 executing 2.0.1 filelock 3.13.4 flash_attn 1.0.4 flax 0.8.2 fonttools 4.51.0 frozenlist 1.4.1 fsspec 2024.2.0 gitdb 4.0.11 GitPython 3.1.43 h5py 3.11.0 huggingface-hub 0.22.2 idna 3.7 igraph 0.11.4 importlib_resources 6.4.0 ipython 8.23.0 jax 0.4.26 jaxlib 0.4.26 jedi 0.19.1 Jinja2 3.1.3 joblib 1.4.0 kiwisolver 1.4.5 legacy-api-wrap 1.4 leidenalg 0.10.2 lightning-utilities 0.11.2 llvmlite 0.42.0 markdown-it-py 3.0.0 MarkupSafe 2.1.5 matplotlib 3.8.4 matplotlib-inline 0.1.7 mdurl 0.1.2 ml_collections 0.1.1 ml-dtypes 0.4.0 mpmath 1.3.0 msgpack 1.0.8 mudata 0.2.3 multidict 6.0.5 multipledispatch 1.0.0 multiprocess 0.70.16 natsort 8.4.0 nest-asyncio 1.6.0 networkx 3.3 numba 0.59.1 numpy 1.26.4 numpyro 0.14.0 nvidia-cublas-cu12 12.1.3.1 nvidia-cuda-cupti-cu12 12.1.105 nvidia-cuda-nvrtc-cu12 12.1.105 nvidia-cuda-runtime-cu12 12.1.105 nvidia-cudnn-cu12 8.9.2.26 nvidia-cufft-cu12 11.0.2.54 nvidia-curand-cu12 10.3.2.106 nvidia-cusolver-cu12 11.4.5.107 nvidia-cusparse-cu12 12.1.0.106 nvidia-nccl-cu12 2.18.1 nvidia-nvjitlink-cu12 12.4.127 nvidia-nvtx-cu12 12.1.105 openpyxl 3.1.2 opt-einsum 3.3.0 optax 0.2.2 orbax 0.1.7 orbax-checkpoint 0.5.9 packaging 24.0 pandas 2.2.2 parso 0.8.4 patsy 0.5.6 pexpect 4.9.0 pillow 10.3.0 pip 24.0 prompt-toolkit 3.0.43 protobuf 4.25.3 psutil 5.9.8 ptyprocess 0.7.0 pure-eval 0.2.2 pyarrow 15.0.2 pyarrow-hotfix 0.6 pydot 2.0.0 Pygments 2.17.2 pynndescent 0.5.12 pyparsing 3.1.2 pyro-api 0.1.2 pyro-ppl 1.9.0 python-dateutil 2.9.0.post0 pytorch-lightning 1.9.5 pytz 2024.1 PyYAML 6.0.1 requests 
2.31.0 rich 13.7.1 scanpy 1.10.1 scgpt 0.2.1 scib 1.1.5 scikit-learn 1.4.2 scikit-misc 0.3.1 scipy 1.13.0 scvi-tools 0.20.3 seaborn 0.13.2 sentry-sdk 1.45.0 session_info 1.0.0 setproctitle 1.3.3 setuptools 69.5.1 six 1.16.0 smmap 5.0.1 stack-data 0.6.3 statsmodels 0.14.1 stdlib-list 0.10.0 sympy 1.12 tensorstore 0.1.56 texttable 1.7.0 threadpoolctl 3.4.0 toolz 0.12.1 torch 2.1.2 torchdata 0.7.1 torchmetrics 1.3.2 torchtext 0.16.2 tqdm 4.66.2 traitlets 5.14.2 triton 2.1.0 typing_extensions 4.11.0 tzdata 2024.1 umap-learn 0.5.6 urllib3 2.2.1 wandb 0.16.6 wcwidth 0.2.13 wheel 0.43.0 wrapt 1.16.0 xxhash 3.4.1 yarl 1.9.4 zipp 3.18.1

So flash_attn seems to be installed correctly. Does anyone know what the problem is?

Ragagnin commented 6 months ago

I'm facing the same issue!

subercui commented 5 months ago

Hi, thank you for the question and for sharing your environment info. It looks like you have flash-attn 1.0.4 installed. The warning at model.py:21 basically means that the FlashMHA class could not be imported, which can happen even when the package itself is installed. See here: https://github.com/bowang-lab/scGPT/blob/706526a76d547de4ed711fa028c99be5bdf6ad8a/scgpt/model/model.py#L13-L21

And flash-attn 1.0.4 should include that class; see here: https://github.com/Dao-AILab/flash-attention/blob/v1.0.4/flash_attn/flash_attention.py#L74
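A minimal sketch of why the warning and the NameError appear together, assuming the linked lines wrap the import in a try/except ImportError block (as the warning text suggests). `build_attention` and its `use_fast_transformer` argument are hypothetical stand-ins for the model code that later uses the class, not scGPT's API:

```python
import warnings

# Guarded optional import (a sketch of the pattern, not scGPT's exact source).
try:
    from flash_attn.flash_attention import FlashMHA  # class exists in flash-attn 1.x
    flash_attn_available = True
except ImportError:
    # A missing package and a broken compiled extension both land here,
    # since both raise ImportError (ModuleNotFoundError is a subclass).
    warnings.warn("flash_attn is not installed")
    flash_attn_available = False


def build_attention(use_fast_transformer: bool):
    """Hypothetical stand-in for model code that picks an attention class."""
    if use_fast_transformer:
        # If the import above failed, the name FlashMHA was never bound,
        # so this line raises NameError: name 'FlashMHA' is not defined.
        return FlashMHA
    return None  # placeholder for the standard (non-flash) attention path
```

Calling `build_attention(True)` after a failed import reproduces the error reported in this issue, while the warning text alone cannot distinguish "not installed" from "installed but broken".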

Therefore, I think this is most likely still an installation issue. Could you try running the following in your environment:

python -c "import flash_attn"
python -c "from flash_attn.flash_attention import FlashMHA"

These should tell you whether the package has been properly installed.
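The two commands above can also be folded into a small standard-library script that separates "not installed" from "installed but failing to import" and prints the underlying exception that the warning hides. This is a sketch; the helper name `diagnose` is hypothetical, and the distribution name `flash_attn` is taken from the pip list output earlier in the thread:

```python
import importlib
import importlib.metadata
from typing import Optional


def diagnose(module_name: str, dist_name: str, attr_name: Optional[str] = None) -> dict:
    """Report pip's recorded version of `dist_name` and whether `module_name`
    (and optionally its attribute `attr_name`) can actually be imported."""
    try:
        version = importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        version = None  # no installed distribution with this name

    try:
        module = importlib.import_module(module_name)
        if attr_name is not None:
            getattr(module, attr_name)  # raises AttributeError if the class is absent
        error = None
    except Exception as exc:  # a broken compiled extension typically raises ImportError
        error = f"{type(exc).__name__}: {exc}"

    return {"module": module_name, "version": version, "import_error": error}


if __name__ == "__main__":
    # Names from this thread; adjust for your environment.
    print(diagnose("flash_attn", "flash_attn"))
    print(diagnose("flash_attn.flash_attention", "flash_attn", "FlashMHA"))
```

A non-None `version` together with a non-None `import_error` means pip sees the package but Python cannot load it, which is the situation the warning misreports as "not installed".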

Ragagnin commented 5 months ago

In my case, I'm running on Colab, and there appears to be an incompatibility between CUDA and flash_attn:

!python -c "import flash_attn"
!python -c "from flash_attn.flash_attention import FlashMHA"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attention.py", line 7, in <module>
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 5, in <module>
    import flash_attn_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/flash_attn_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda20CUDACachingAllocator9allocatorE

Do you have any suggestions on how to resolve it?
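The undefined symbol in this traceback is itself a clue. Demangled, `_ZN3c104cuda20CUDACachingAllocator9allocatorE` reads `c10::cuda::CUDACachingAllocator::allocator`, a symbol from PyTorch's `c10` library rather than from the CUDA toolkit. That usually points to the compiled `flash_attn_cuda` extension having been built against a different PyTorch version than the one installed (an inference from the symbol name, not something confirmed in this thread). `c++filt` from GNU binutils demangles such names in general; the sketch below is a hypothetical helper, `demangle_nested_name`, that handles only the plain `_ZN<len><id>...E` form seen here:

```python
import re


def demangle_nested_name(symbol: str) -> str:
    """Demangle a plain Itanium C++ nested name such as `_ZN3foo3barE`.

    Handles only `_ZN` + length-prefixed identifiers + `E` (namespace-qualified
    variables). Templates, function signatures and substitutions are NOT
    supported; use `c++filt` for those.
    """
    if not (symbol.startswith("_ZN") and symbol.endswith("E")):
        raise ValueError(f"not a simple nested name: {symbol!r}")
    body = symbol[3:-1]
    parts = []
    pos = 0
    while pos < len(body):
        match = re.match(r"\d+", body[pos:])
        if match is None:
            raise ValueError(f"expected a length prefix at offset {pos}")
        length = int(match.group())
        start = pos + len(match.group())
        end = start + length
        if end > len(body):
            raise ValueError("identifier runs past the end of the symbol")
        parts.append(body[start:end])
        pos = end
    return "::".join(parts)


print(demangle_nested_name("_ZN3c104cuda20CUDACachingAllocator9allocatorE"))
# c10::cuda::CUDACachingAllocator::allocator
```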

litxiaoyao commented 3 weeks ago

I have encountered the same problem.