YoshitakaMo / localcolabfold

ColabFold on your local PC
MIT License
ModuleNotFoundError: No module named 'jax.extend' related to #209, #210 #224

Closed ZubuNoShoshinsha closed 6 months ago

ZubuNoShoshinsha commented 6 months ago

Hello,

My question is related to #209 and #210. My environment is:

WSL2
OS: Ubuntu 22.04.4
GCC: 11.4.0
CUDA: 12.1
GPU: RTX 4090
LocalColabFold ver. 1.5.5

As instructed in #209, I checked whether the GPU was recognized, and it was not. So I downgraded jax and jaxlib to jax 0.4.7 and jaxlib 0.4.7+cuda11.cudnn82, also as instructed in #209.

Then I checked again by starting the interpreter with

$ /path/to/your/localcolabfold/colabfold-conda/bin/python3.10

and running

import jax
print(jax.local_devices()[0].platform)

and "gpu" was returned.
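The same check can be written as a small script, which is convenient when it is not clear whether JAX imports at all in the interpreter being tested. This is only a sketch: `detect_platform` is a hypothetical helper name, and the only JAX call it relies on is the `jax.local_devices()` call already used in this comment.

```python
def detect_platform():
    """Return the platform of the first local JAX device, or None if JAX is unusable."""
    try:
        import jax  # imported lazily so the script still runs without JAX
    except ImportError:
        return None
    try:
        devices = jax.local_devices()
    except Exception:
        # Backend initialisation problems surface as exceptions; treat them as "no device".
        return None
    return devices[0].platform if devices else None


if __name__ == "__main__":
    platform = detect_platform()
    if platform is None:
        print("JAX is not importable or no backend could be initialised")
    else:
        # This thread reports "gpu" here once the CUDA build of jaxlib is picked up.
        print(platform)
```

Run it with the Python binary inside `colabfold-conda`, not the system Python, so that it inspects the same environment that `colabfold_batch` uses.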

Then I ran localcolabfold, but it stopped with the error message below:

2024-04-01 15:14:35,452 Running colabfold 1.5.5 (61df3b853140ca79dbdf64349824beb14364ebfd)
2024-04-01 15:14:36,006 Running on GPU
Traceback (most recent call last):
  File "/mnt/d/Alphafold/localcolabfold/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 2037, in main
    run(
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1292, in run
    from colabfold.alphafold.models import load_models_and_params
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/alphafold/models.py", line 4, in <module>
    import haiku
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/__init__.py", line 20, in <module>
    from haiku import experimental
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/experimental/__init__.py", line 34, in <module>
    from haiku._src.dot import abstract_to_dot
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/_src/dot.py", line 29, in <module>
    from jax.extend import linear_util as lu
ModuleNotFoundError: No module named 'jax.extend'

Any instructions for solving this issue would be helpful.

YoshitakaMo commented 6 months ago

I suspect that the issue lies in the version of the dm-haiku module being 0.0.11 or later. In my environment:

$ localcolabfold/colabfold-conda/bin/python3.10 -m pip list

jax                          0.4.23
jaxlib                       0.4.23+cuda11.cudnn86
chex                         0.1.85
dm-haiku                     0.0.10

If CUDA 12.1 is installed, these versions should be fine.

Please set your dm-haiku to version 0.0.10. Otherwise, you may encounter the error ModuleNotFoundError: No module named 'jax.extend'.
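To see which of these packages an environment actually has, and whether `jax.extend` can be found at all, something like the following could be run with the environment's own Python. It is a sketch that uses only the standard library; the helper names `installed_versions` and `has_module` are made up for illustration.

```python
from importlib import metadata, util

PACKAGES = ("jax", "jaxlib", "chex", "dm-haiku")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def has_module(dotted_name):
    """Return True if `dotted_name` can be located (parent packages get imported)."""
    try:
        return util.find_spec(dotted_name) is not None
    except ImportError:
        # Raised when a parent package (for example `jax` itself) is missing.
        return False


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:<12} {version or 'not installed'}")
    print("jax.extend importable:", has_module("jax.extend"))
```

If `jax.extend` is reported as missing while dm-haiku is 0.0.11 or later, that matches the combination suspected in this comment.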

ZubuNoShoshinsha commented 6 months ago

Thank you for your suggestion (I only just noticed your response). My dm-haiku was actually 0.0.12, so I downgraded it to 0.0.10 as you suggested.

Then I ran localcolabfold 1.5.5. My environment is now:

jax                          0.4.7
jaxlib                       0.4.7+cuda11.cudnn82
chex                         0.1.82
dm-haiku                     0.0.10

The "ModuleNotFoundError: No module named 'jax.extend'" error is gone, but now a new message shows up and the program stops:

"Could not predict ProteinA. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details."

Do you have any other suggestions for solving this?

ZubuNoShoshinsha commented 6 months ago

My nvidia-smi result: [screenshot: nvidia-smi]

YoshitakaMo commented 6 months ago

If you are using WSL2, did you turn on the settings shown in https://github.com/YoshitakaMo/localcolabfold?tab=readme-ov-file#for-wsl2-in-windows ? Unfortunately, I can't figure out the cause because I don't have a WSL2 environment.

ZubuNoShoshinsha commented 6 months ago

Yes, I did. I restarted WSL2 and tried again, but it still did not work.

Thank you though.

ZubuNoShoshinsha commented 6 months ago

I wonder... when I downgraded dm-haiku, pip printed the message "colabfold 1.5.5 requires dm-haiku<0.0.13,>=0.0.12, but you have dm-haiku 0.0.10 which is incompatible." Is it still OK to run colabfold with this mismatch?
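The message quoted here is pip comparing the installed version against the range that colabfold declares. A rough illustration of that comparison follows, using plain tuple ordering instead of pip's real resolver. The helper names and the simplified parsing are assumptions, and it only handles purely numeric versions such as `0.0.10`.

```python
def parse_version(text):
    """Turn '0.0.10' into (0, 0, 10) so versions compare numerically, not as strings."""
    return tuple(int(part) for part in text.split("."))


def in_range(version, lower_inclusive, upper_exclusive):
    """Check lower_inclusive <= version < upper_exclusive."""
    v = parse_version(version)
    return parse_version(lower_inclusive) <= v < parse_version(upper_exclusive)


if __name__ == "__main__":
    # Range from the pip message, read here as >=0.0.12 and <0.0.13.
    for candidate in ("0.0.10", "0.0.12"):
        print(candidate, in_range(candidate, "0.0.12", "0.0.13"))
```

Under that reading, 0.0.10 falls outside the declared range and 0.0.12 falls inside it, which is all the message reports. It does not by itself say whether colabfold will fail at run time.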

ZubuNoShoshinsha commented 6 months ago

Finally,

I might have found the solution.

I downgraded nvidia-cudnn-cu11 from 9.0.0.312 to 8.5.0.96 with this command:

pip install --upgrade nvidia-cudnn-cu11==8.5.0.96

I then ran localcolabfold and it ran very smoothly on the GPU.

I was astonished.

Thank you.
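One possible reading of this fix: the jaxlib build in this environment (`0.4.7+cuda11.cudnn82`) names a cuDNN 8.x build, while the installed `nvidia-cudnn-cu11` wheel was 9.x. The thread does not confirm that this mismatch caused the "DNN library initialization failed" error, so the sketch below only shows how the two major versions could be compared. It assumes that the first digit after `cudnn` in the tag is the cuDNN major version; `cudnn_major_from_tag` and `cudnn_matches` are hypothetical helpers.

```python
import re


def cudnn_major_from_tag(jaxlib_version):
    """Extract the cuDNN major version from a tag like '0.4.7+cuda11.cudnn82'.

    Returns None when the version string carries no cudnn suffix.
    Assumption: the first digit after 'cudnn' is the major version.
    """
    match = re.search(r"cudnn(\d)\d*", jaxlib_version)
    return int(match.group(1)) if match else None


def cudnn_matches(jaxlib_version, cudnn_wheel_version):
    """True if the cuDNN wheel's major version equals the one in the jaxlib tag."""
    expected = cudnn_major_from_tag(jaxlib_version)
    installed = int(cudnn_wheel_version.split(".")[0])
    return expected is not None and expected == installed


if __name__ == "__main__":
    tag = "0.4.7+cuda11.cudnn82"
    for wheel in ("9.0.0.312", "8.5.0.96"):
        print(wheel, cudnn_matches(tag, wheel))
```

With the versions from this thread, the 9.0.0.312 wheel does not match the tag and the 8.5.0.96 wheel does.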

Vinaysukhesh98 commented 4 months ago

jax                          0.4.23
jaxlib                       0.4.23+cuda11.cudnn86
chex                         0.1.85
dm-haiku                     0.0.10

Requirement already satisfied: torch==1.13.1 in /usr/local/lib/python3.10/dist-packages (1.13.1)
Requirement already satisfied: transformers==4.24.0 in /usr/local/lib/python3.10/dist-packages (4.24.0)
Collecting diffusers==0.3.0
  Using cached diffusers-0.3.0-py3-none-any.whl (153 kB)
Collecting jax==0.4.23
  Downloading jax-0.4.23-py3-none-any.whl (1.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 6.3 MB/s eta 0:00:00
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.23+cuda11.cudnn86 (from versions: 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28)
ERROR: No matching distribution found for jaxlib==0.4.23+cuda11.cudnn86
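The candidate list in this pip error contains only plain versions (`0.4.23`, not `0.4.23+cuda11.cudnn86`), which suggests pip was searching only its default index. JAX's installation instructions of that period pointed pip at a separate find-links page for CUDA builds, so a command along these lines may have been needed. This is a sketch under that assumption: the URL may have changed since, and `build_pip_command` is a hypothetical helper that only assembles the command and does not run it.

```python
import sys

# Assumed find-links page for CUDA builds of jaxlib (taken from JAX install docs of that era).
JAX_CUDA_RELEASES = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def build_pip_command(jax_version, jaxlib_version, find_links=JAX_CUDA_RELEASES):
    """Return a pip install command (as an argv list) that also searches `find_links`."""
    return [
        sys.executable, "-m", "pip", "install",
        f"jax=={jax_version}",
        f"jaxlib=={jaxlib_version}",
        "-f", find_links,
    ]


if __name__ == "__main__":
    cmd = build_pip_command("0.4.23", "0.4.23+cuda11.cudnn86")
    print(" ".join(cmd))
```

Using `sys.executable -m pip` ties the install to whichever Python runs the script, which matters here because the log above shows packages going into `/usr/local/lib/python3.10/dist-packages`, not into a `colabfold-conda` environment.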

YoshitakaMo commented 4 months ago

I updated the installer and updater scripts for Linux two days ago, as JAX 0.4.23 no longer seems suitable for CUDA 12 and cuDNN 9. Please update CUDA to 12.4 and cuDNN to 9, and use the latest updater script.