YoshitakaMo / localcolabfold

ColabFold on your local PC

Question: LocalColabFold failed on CPU mode #238

Open Immortals-33 opened 4 months ago

Immortals-33 commented 4 months ago

Hello, thanks for making ColabFold runnable locally! I'm running into some issues when trying to run colabfold_batch on CPU.

What is your installation issue? The installation completed successfully (no error messages), but colabfold_batch fails when run in CPU mode. The main problem might arise from `jax`.

Computational environment

To Reproduce

Steps to reproduce the behavior:

Traceback (most recent call last):
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1283, in run
    jax.tools.colab_tpu.setup_tpu()
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/tools/colab_tpu.py", line 20, in setup_tpu
    raise RuntimeError(
RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 663, in factory
    return xla_client.make_c_api_client(plugin_name, updated_options, None)
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jaxlib/xla_client.py", line 199, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 2048, in main
    run(
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1288, in run
    if jax.local_devices()[0].platform == 'cpu':
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1135, in local_devices
    process_index = get_backend(backend).process_index()
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
    return _get_backend_uncached(platform)
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
    bs = backends()
  File "/dssg/home/acct-clschf/clschf/zzq/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

What I've tried so far, without success (items 2 and 3 are combined into a sketch below):

  1. Manually setting DEVICE to "cpu" in the run() function inside batch.py;
  2. Adding jax.config.update("jax_platform_name", "cpu");
  3. Setting os.environ['JAX_PLATFORMS'] = "cpu" (from https://github.com/google/jax/discussions/14208).
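
To make the intent concrete, items 2 and 3 combined look roughly like the standalone sketch below (my own illustration, not ColabFold code; note that JAX_PLATFORMS only has an effect if it is set before jax initializes a backend):

```python
# Standalone sketch of workarounds 2 and 3 above (illustration only, not the
# batch.py patch). JAX_PLATFORMS must be set before jax picks a backend, so it
# is exported before jax is imported.
import os
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
jax.config.update("jax_platform_name", "cpu")  # same intent as the env var

print(jax.local_devices())  # should report only CPU devices if this takes effect
```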

Expected behavior: This seems to be a problem arising from jax rather than from ColabFold or LocalColabFold itself, since I installed LocalColabFold on another system about a year ago and it worked (but that environment is no longer available). My hypothesis is that it is a jax version problem; I went through some of the issues and forums on both sides without finding an answer, so I'm posting this to see if anyone has encountered the same issue. Since a GPU is not always available on my side, I'm trying to run inference on CPU.

Thanks in advance!

jdmagasin commented 3 months ago

Hello,

I have the same problem two weeks later (Jun 13, 2024), running on a test FASTA with six peptides. Setup: Linux cluster with SLURM, CUDA tools v12.4.131, and gcc v11.2.0. LocalColabFold installed successfully and I also updated it (update_linux.sh), but I still get:

RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.

Workaround 1 worked for me, and CPU-only is okay for me in the short term. In batch.py at line 1278:

    try:
        # check if TPU is available
        import jax.tools.colab_tpu
        # jmagasin: For now force device to "cpu" because setup_tpu() not supported for newer JAX versions,
        #jax.tools.colab_tpu.setup_tpu()
        #logger.info('Running on TPU')
        #DEVICE = "tpu"
        logger.info('Running on CPU if only to avoid calling jax.tools.colab_tpu.setup_tpu()')
        DEVICE = "cpu"
        use_gpu_relax = False

This thread from March 2023 suggests using the older jax and jaxlib v0.3.25, but it seems that led to other problems.

-Jonathan

jecorn commented 3 months ago

Same here. I had a working ColabFold install about 6 months ago. For various reasons I had to wipe and reinstall yesterday, and hit the same JAX issue. Sounds like maybe the JAX dependency needs to be pinned in the install_colabbatch_linux.sh and update_linux.sh scripts?


2024-06-18 08:52:05,501 Running colabfold 1.5.5 (1648d2335943f9a483b6a803ebaea3e76162c788)
Traceback (most recent call last):
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1281, in run
    jax.tools.colab_tpu.setup_tpu()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/tools/colab_tpu.py", line 20, in setup_tpu
    raise RuntimeError(
RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 2046, in main
    run(
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1286, in run
    if jax.local_devices()[0].platform == 'cpu':
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1135, in local_devices
    process_index = get_backend(backend).process_index()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
    return _get_backend_uncached(platform)
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
    bs = backends()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

YoshitakaMo commented 3 months ago

@jecorn your issue is related to this one: https://github.com/YoshitakaMo/localcolabfold/issues/240#issuecomment-2175403799. This is caused by JAX 0.4.29.

Immortals-33 commented 3 months ago

@jdmagasin thanks for sharing. I tried this and it helps a bit. It no longer throws my original error, but it instead throws a CUDA out-of-memory error when predicting structures (i.e., it outputs the MSA files smoothly but fails at the prediction step), which still occurs even when I drop the sequence length to about ~20 amino acids. Maybe next time I'll try again with more CPU cores and see if it works.

And @jecorn, your error seems similar to but slightly different from mine; hopefully the solution @YoshitakaMo mentioned solves yours. My guess is that both errors arise from JAX incompatibilities, but I'm not sure yet.

ZyuanZhang commented 3 months ago

@Immortals-33 I ran into the same error. Any solutions to it?

jdmagasin commented 3 months ago

Yes, this solution by @YoshitakaMo worked for me. It is issue #240 mentioned above. To apply it I modified the installation script, substituting "jax[cuda12]==0.4.28" for "jax[cuda12]" (a minor difference from the exact string in the original solution, but it doesn't matter). In practice I use a SLURM job script that (1) downloads the install script, (2) replaces the jax version, and (3) runs the patched install script; a rough sketch follows. I should probably handle the update script the same way. Hope this helps.
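
The patching step looks roughly like this (a sketch in Python rather than my actual job script; the raw URL, branch, and exact replacement string are assumptions to adapt to your setup):

```python
# Hypothetical sketch: fetch the installer, pin jax as in issue #240, and write
# a patched copy that the SLURM job then runs with bash.
import urllib.request

url = ("https://raw.githubusercontent.com/YoshitakaMo/localcolabfold/"
       "main/install_colabbatch_linux.sh")  # assumed location of the installer
script = urllib.request.urlopen(url).read().decode()

# Pin jax to 0.4.28 so jaxlib and the CUDA PJRT plugin stay compatible.
script = script.replace("jax[cuda12]", "jax[cuda12]==0.4.28")

with open("install_colabbatch_linux_pinned.sh", "w") as fh:
    fh.write(script)
```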

Immortals-33 commented 2 months ago

Hi @ZyuanZhang, the solution proposed by @jdmagasin works for me only some of the time (i.e., it works on one machine but not on another). I suggest trying that workaround for now and seeing if it works for you.

ZyuanZhang commented 2 months ago

Thank you @Immortals-33. I tried installing an old version (1.5.0 beta) and it finally works.