Hello,
I have the same problem two weeks later (Jun 13, 2024), running on a test FASTA with six peptides. Setup: Linux cluster with SLURM, CUDA tools v12.4.131, and gcc v11.2.0. The installation (and the subsequent update via update_linux.sh) completed successfully, but I also get:
```
RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.
```
Workaround 1 worked for me -- and CPU-only is okay for me for the short term. In batch.py at line 1278:
```python
try:
    # check if TPU is available
    import jax.tools.colab_tpu
    # jmagasin: For now force device to "cpu" because setup_tpu() is not
    # supported for newer JAX versions.
    #jax.tools.colab_tpu.setup_tpu()
    #logger.info('Running on TPU')
    #DEVICE = "tpu"
    logger.info('Running on CPU, if only to avoid calling jax.tools.colab_tpu.setup_tpu()')
    DEVICE = "cpu"
    use_gpu_relax = False
```
This thread from March 2023 suggests using the older jax and jaxlib v0.3.25, but it seems that led to other problems.
-Jonathan
Same here. I had a working ColabFold install about six months ago. For various reasons I had to wipe it and reinstall yesterday, and hit the same JAX issue. Sounds like maybe the JAX dependency needs to be pinned in the install_colabbatch_linux.sh and update_linux.sh scripts? Here is the full error:
```
2024-06-18 08:52:05,501 Running colabfold 1.5.5 (1648d2335943f9a483b6a803ebaea3e76162c788)
Traceback (most recent call last):
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1281, in run
    jax.tools.colab_tpu.setup_tpu()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/tools/colab_tpu.py", line 20, in setup_tpu
    raise RuntimeError(
RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version (0.51).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 2046, in main
    run(
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1286, in run
    if jax.local_devices()[0].platform == 'cpu':
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1135, in local_devices
    process_index = get_backend(backend).process_index()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
    return _get_backend_uncached(platform)
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
    bs = backends()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version (0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
```
@jecorn your issue is related to this one: https://github.com/YoshitakaMo/localcolabfold/issues/240#issuecomment-2175403799 . This is caused by JAX 0.4.29.
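If you want to confirm which jax and jaxlib versions an install actually picked up, something like this should print them (the path is taken from the tracebacks above; adjust it to your own install location):

```bash
# Print the jax and jaxlib versions inside the localcolabfold conda environment.
# The path below is illustrative -- point it at your own colabfold-conda.
/home/cornlab/bin/localcolabfold/colabfold-conda/bin/python -c \
  "import jax, jaxlib; print('jax', jax.__version__, 'jaxlib', jaxlib.__version__)"
```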
@jdmagasin thanks for sharing. I tried this and it works to some extent: it no longer throws my original error, but it instead fails with a CUDA out-of-memory error when predicting structures (i.e. it writes the MSA files fine but never reaches the prediction step), and that still occurred even when I dropped the sequence length to roughly 20 AAs. Maybe next time I'll try again with more CPU cores and see if it works.
And @jecorn, your error looks similar to, but slightly different from, mine; hopefully the solution @YoshitakaMo mentioned solves yours. My guess is that both errors arise from some JAX incompatibility, but I'm not sure yet.
@Immortals-33 I ran into the same error. Any solutions?
Yes, the solution by @YoshitakaMo worked for me; it is issue #240 mentioned above. To apply it, I modified the installation script, substituting `"jax[cuda12]==0.4.28"` for `"jax[cuda12]"`. (That is a small change to the trailing quote relative to the original solution, but it doesn't matter.) Actually, I use a SLURM job script that (1) downloads the install script, (2) replaces the jax version, and (3) runs the fixed install script; see the sketch below. I should probably run the upgrade script the same way. Hope this helps.
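For anyone who wants to automate it, a rough sketch of that kind of job script follows. The download URL is the standard localcolabfold installer location, and the `sed` pattern is illustrative; check what the pip line in the downloaded script actually looks like before patching.

```bash
#!/bin/bash
#SBATCH --job-name=install-colabfold   # adjust SLURM directives for your cluster

set -euo pipefail

# 1) Download the installer from the localcolabfold repository.
wget https://raw.githubusercontent.com/YoshitakaMo/localcolabfold/main/install_colabbatch_linux.sh

# 2) Pin jax to 0.4.28 (see issue #240). The pattern below assumes the script
#    installs "jax[cuda12]" unpinned; adjust it if the pip line differs.
sed -i 's/"jax\[cuda12\]"/"jax[cuda12]==0.4.28"/' install_colabbatch_linux.sh

# 3) Run the patched installer.
bash install_colabbatch_linux.sh
```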
Hi @ZyuanZhang, the solution proposed by @jdmagasin works for me only some of the time (i.e. it works on one machine but not on another). I suggest trying this workaround for now and seeing whether it works for you.
Thank you @Immortals-33, I tried installing an old version (1.5.0 beta) and it finally works.
Hello, thanks for making ColabFold runnable locally! I'm running into some issues when trying to run `colabfold_batch` on CPU.

**What is your installation issue?**
The installation went through successfully (no error messages), but running `colabfold_batch` in CPU mode fails. The main problem probably arises from `jax`.

**Computational environment**

**To Reproduce**
Steps to reproduce the behavior:

What I've tried but didn't work:

- Setting `DEVICE` in the `run()` function inside `batch.py` to CPU;
- `jax.config.update("jax_platform_name", "cpu")`;
- `os.environ['JAX_PLATFORMS'] = "cpu"` (from https://github.com/google/jax/discussions/14208); a shell-level version of this attempt is sketched at the end of this post.

**Expected behavior**
This seems to be a problem arising from `jax` rather than from ColabFold or LocalColabFold itself, as I installed LocalColabFold on another system about one year ago and it works (but that environment is no longer used). My hypothesis is that it is a `jax` version problem, but I've gone through some of the issues and forums on both sides, so I'm posting this to see if anyone has encountered the same issue. Since a GPU is not always available on my side, I'm trying to run inference on CPU. Thanks in advance!
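For completeness, the shell-level form of the `JAX_PLATFORMS` attempt looked roughly like this (the FASTA file and output directory names are placeholders):

```bash
# Ask JAX to skip the GPU/TPU backends and use only the CPU backend,
# then run a normal batch prediction. File names are placeholders.
export JAX_PLATFORMS=cpu
colabfold_batch test.fasta output_dir/
```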