Closed ZubuNoShoshinsha closed 6 months ago
I suspect that the issue lies in the version of the dm-haiku module being 0.0.11 or later. In my environment:
$ localcolabfold/colabfold-conda/bin/python3.10 -m pip list
jax 0.4.23
jaxlib 0.4.23+cuda11.cudnn86
chex 0.1.85
dm-haiku 0.0.10
If CUDA 12.1 is installed, these versions should be fine.
Please set your dm-haiku to version 0.0.10. Otherwise, you may encounter the error ModuleNotFoundError: No module named 'jax.extend'
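For anyone who wants to check this without reading the whole `pip list` output, here is a minimal sketch that prints the four packages discussed here and flags a dm-haiku version in the range suspected above. The helper names (`parse_version`, `haiku_is_suspect`, `installed_versions`) are made up for this example, and the 0.0.11 threshold is only the suspicion stated in this comment, not a confirmed boundary.

```python
import re
from importlib import metadata

# Packages discussed in this thread.
PACKAGES = ("jax", "jaxlib", "chex", "dm-haiku")


def parse_version(text):
    """Turn '0.4.23+cuda11.cudnn86' into (0, 4, 23), ignoring the '+...' local tag."""
    public = text.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        leading_digits = re.match(r"\d+", piece)
        parts.append(int(leading_digits.group()) if leading_digits else 0)
    return tuple(parts)


def haiku_is_suspect(haiku_version):
    """True when dm-haiku is 0.0.11 or later (the range suspected in this thread)."""
    return parse_version(haiku_version) >= (0, 0, 11)


def installed_versions(packages=PACKAGES):
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    versions = installed_versions()
    for name, version in versions.items():
        print(f"{name:10s} {version or 'not installed'}")
    haiku = versions["dm-haiku"]
    if haiku and haiku_is_suspect(haiku):
        print("dm-haiku is 0.0.11 or later; consider pinning it to 0.0.10.")
```

Run it with the interpreter inside the localcolabfold environment (the same `python3.10` used for `pip list` above) so that it inspects the right site-packages.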
Thank you for your suggestion (I only just noticed your response). My dm-haiku was actually 0.0.12, so I downgraded it to 0.0.10 as you suggested.
Then I ran localcolabfold 1.5.5. My environment is now: jax 0.4.7, jaxlib 0.4.7+cuda11.cudnn82, chex 0.1.82, dm-haiku 0.0.10.
The "ModuleNotFoundError: No module named 'jax.extend'" error is gone, but a new message now appears and the program stops:
"Could not predict ProteinA. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details."
Do you have any other suggestions for solving this?
my nvidia-smi result
If you are using WSL2, did you turn on the settings shown in https://github.com/YoshitakaMo/localcolabfold?tab=readme-ov-file#for-wsl2-in-windows ? Unfortunately, I can't figure out the cause because I don't have a WSL2 environment.
Yes, I did. I restarted WSL2 and tried again, but it still did not work.
Thank you though.
One question: when I downgraded dm-haiku, pip reported "colabfold 1.5.5 requires dm-haiku<0.013, >=0.0.12, but you have dm-haiku 0.0.10 which is incompatible." Is it still fine to run colabfold with this version?
Finally, I may have found the solution.
I downgraded nvidia-cudnn-cu11 from 9.0.0.312 to 8.5.0.96 with this command:
pip install --upgrade nvidia-cudnn-cu11==8.5.0.96
I then ran localcolabfold and it ran smoothly on the GPU.
I was astonished.
Thank you.
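A possible reading of why this helped, which this thread does not confirm: the jaxlib wheel name carries a cuDNN tag (`cudnn82` in `0.4.7+cuda11.cudnn82`), while the installed `nvidia-cudnn-cu11` wheel was 9.x. The sketch below compares the two major versions. The function names are hypothetical, and treating the first digit after `cudnn` as the major version is an assumption based on the tags seen here (`cudnn82`, `cudnn86`).

```python
import re


def cudnn_major_from_jaxlib(jaxlib_version):
    """Return the cuDNN major version from a tag such as 'cudnn82', or None if absent.

    Assumption: the first digit after 'cudnn' is the major version.
    """
    match = re.search(r"cudnn(\d)(\d+)", jaxlib_version)
    return int(match.group(1)) if match else None


def cudnn_matches(jaxlib_version, cudnn_wheel_version):
    """Compare the jaxlib cuDNN tag with the installed nvidia-cudnn wheel version.

    Returns True or False, or None when the jaxlib version has no cuDNN tag.
    """
    built_for = cudnn_major_from_jaxlib(jaxlib_version)
    if built_for is None:
        return None
    installed_major = int(cudnn_wheel_version.split(".", 1)[0])
    return built_for == installed_major


if __name__ == "__main__":
    # Version strings taken from this thread.
    print(cudnn_matches("0.4.7+cuda11.cudnn82", "9.0.0.312"))  # before the downgrade
    print(cudnn_matches("0.4.7+cuda11.cudnn82", "8.5.0.96"))   # after the downgrade
```

A matching major version is only a necessary condition in this sketch; it says nothing about minor-version compatibility or about the CUDA toolkit itself.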
jax 0.4.23 jaxlib 0.4.23+cuda11.cudnn86 chex 0.1.85 dm-haiku 0.0.10
Requirement already satisfied: torch==1.13.1 in /usr/local/lib/python3.10/dist-packages (1.13.1)
Requirement already satisfied: transformers==4.24.0 in /usr/local/lib/python3.10/dist-packages (4.24.0)
Collecting diffusers==0.3.0
  Using cached diffusers-0.3.0-py3-none-any.whl (153 kB)
Collecting jax==0.4.23
  Downloading jax-0.4.23-py3-none-any.whl (1.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 6.3 MB/s eta 0:00:00
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.23+cuda11.cudnn86 (from versions: 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28)
ERROR: No matching distribution found for jaxlib==0.4.23+cuda11.cudnn86
I updated the installer and updater scripts for Linux two days ago, as JAX 0.4.23 no longer seems suitable for CUDA 12 and cuDNN 9. Please update CUDA to 12.4 and cuDNN to 9, and use the latest updater script.
Hello,
My question is related to #209 and #210. My environment is:
WSL2
OS: Ubuntu 22.04.4
GCC: 11.4.0
CUDA: 12.1
GPU: RTX 4090
LocalColabFold Ver. 1.5.5
As instructed in #209, I checked whether the GPU was recognized, and it was not. So I downgraded jax and jaxlib to jax 0.4.7 and jaxlib 0.4.7+cuda11.cudnn82, as instructed in #209.
And then I checked again using $ /path/to/your/localcolabfold/colabfold-conda/bin/python3.10
and "gpu" was returned.
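For readers who want to repeat this kind of check, one way to do it is sketched below. It is not necessarily the exact command from #209. `jax.default_backend()` is a real JAX call; the wrapper name `detect_backend` is made up for this example, and it returns None when jax cannot be imported at all.

```python
def detect_backend():
    """Return JAX's default backend name (e.g. 'cpu' or 'gpu'), or None if jax is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    backend = detect_backend()
    if backend is None:
        print("jax is not importable in this interpreter")
    else:
        print(f"default backend: {backend}")
```

Run it with the `python3.10` binary inside `colabfold-conda/bin` so that it reports on the same jax installation that `colabfold_batch` uses.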
Then I ran localcolabfold, but it stopped with the error message below:
2024-04-01 15:14:35,452 Running colabfold 1.5.5 (61df3b853140ca79dbdf64349824beb14364ebfd)
2024-04-01 15:14:36,006 Running on GPU
Traceback (most recent call last):
  File "/mnt/d/Alphafold/localcolabfold/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 2037, in main
    run(
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1292, in run
    from colabfold.alphafold.models import load_models_and_params
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/alphafold/models.py", line 4, in <module>
    import haiku
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/__init__.py", line 20, in <module>
    from haiku import experimental
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/experimental/__init__.py", line 34, in <module>
    from haiku._src.dot import abstract_to_dot
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/_src/dot.py", line 29, in <module>
    from jax.extend import linear_util as lu
ModuleNotFoundError: No module named 'jax.extend'
Any instructions for solving this issue would be very helpful.
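The traceback above fails at `from jax.extend import linear_util`, so a quick way to reproduce the condition without running a full prediction is to ask Python whether that submodule can be found. This is a minimal sketch using only the standard library; `has_module` is a name invented for this example.

```python
import importlib.util


def has_module(dotted_name):
    """Return True if the dotted module name can be located, False otherwise.

    Note: find_spec imports the parent package (here 'jax') as a side effect.
    """
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except ModuleNotFoundError:
        # Raised when the parent package itself is not installed.
        return False


if __name__ == "__main__":
    for name in ("jax", "jax.extend", "haiku"):
        print(f"{name:12s} {'found' if has_module(name) else 'missing'}")
```

If `jax` is found but `jax.extend` is missing while `haiku` is installed, the environment matches the situation shown in the traceback.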