chaidiscovery / chai-lab

Chai-1, SOTA model for biomolecular structure prediction
https://www.chaidiscovery.com

CUDA 11.2 compatibility? #38

Open alexberlaga opened 1 week ago

alexberlaga commented 1 week ago

This is the error I get when I run this on a GPU with CUDA version 11.2:

Traceback (most recent call last):
  File "/project2/andrewferguson/berlaga/drugdiscovery/run_chai.py", line 60, in <module>
    output_pdb_paths = run_inference(
                       ^^^^^^^^^^^^^^
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/chai_lab/chai1.py", line 276, in run_inference
    embedding_context = get_esm_embedding_context(chains, device=device)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 522, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 449, in wrapped_fn_impl
    out = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/chai_lab/data/dataset/embeddings/esm.py", line 79, in get_esm_embedding_context
    emb = embedding_context_from_sequence(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/chai_lab/data/dataset/embeddings/esm.py", line 50, in embedding_context_from_sequence
    esm_model.eval().to(device)
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1173, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/torch/nn/modules/module.py", line 779, in _apply
    module._apply(fn)
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/torch/nn/modules/module.py", line 804, in _apply
    param_applied = fn(param)
                    ^^^^^^^^^
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1159, in convert
    return t.to(
           ^^^^^
  File "/scratch/midway3/berlaga/miniconda3/envs/chai/lib/python3.12/site-packages/torch/cuda/__init__.py", line 293, in _lazy_init
    torch._C._cuda_init()
RuntimeError: The NVIDIA driver on your system is too old (found version 11020). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver.

Is there a set of dependency changes I can make so that this code works with CUDA 11.2?

Thanks!

arogozhnikov commented 1 week ago

Current torch versions support CUDA 11.8 and 12.4, and have for a while.
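The driver version in the error message (11020) uses CUDA's integer encoding, 1000 * major + 10 * minor, so it means CUDA 11.2. A minimal sketch of comparing that number against the 11.8 and 12.4 torch builds follows. The helper names `decode_cuda_version` and `classify_build` are hypothetical, and the "same-major" case assumes CUDA's minor-version compatibility applies, which is not guaranteed to cover every kernel torch ships.

```python
def decode_cuda_version(encoded: int) -> tuple[int, int]:
    # CUDA encodes versions as 1000 * major + 10 * minor, e.g. 11020 -> (11, 2).
    major, remainder = divmod(encoded, 1000)
    return major, remainder // 10


def classify_build(driver_encoded: int, build: str) -> str:
    """Compare a driver's CUDA version with a torch build tag such as "11.8"."""
    driver = decode_cuda_version(driver_encoded)
    build_major, build_minor = (int(part) for part in build.split("."))
    if driver >= (build_major, build_minor):
        return "native"  # driver is at least as new as the build's toolkit
    if driver[0] == build_major:
        return "same-major"  # may run via CUDA minor-version compatibility (assumption)
    return "older-major"  # driver predates the build's major CUDA release


if __name__ == "__main__":
    for build in ("11.8", "12.4"):
        print(build, classify_build(11020, build))
```

For the driver in this issue, the 12.4 build lands in "older-major", while the 11.8 build is "same-major", which is why an 11.8 build of torch is the more promising option here.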

As a suggestion, try the following:

  1. Install the package (chai_lab) as usual.
  2. Then install torch 2.3 built for CUDA 11.8 over it, using one of the commands from here: https://pytorch.org/get-started/previous-versions/
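The two steps above can be sketched as a small Python helper that builds the pip command for step 2. This is a sketch under assumptions, not an official recipe: the patch release 2.3.1 and the helper names `torch_reinstall_command` and `report_torch_build` are hypothetical, and the exact command should be taken from the previous-versions page. Pinning the `+cu118` local version label should make pip replace a plain `2.3.1` wheel that was pulled in when installing chai_lab, since a plain `2.3.1` does not satisfy `==2.3.1+cu118`.

```python
import sys


def torch_reinstall_command(torch_version: str = "2.3.1", cuda_tag: str = "cu118") -> list[str]:
    """Build (but do not run) a pip command for a specific CUDA build of torch."""
    return [
        sys.executable, "-m", "pip", "install",
        # Local version label pins the CUDA variant, e.g. torch==2.3.1+cu118.
        f"torch=={torch_version}+{cuda_tag}",
        "--index-url", f"https://download.pytorch.org/whl/{cuda_tag}",
    ]


def report_torch_build() -> None:
    # Imported lazily so the command builder above works without torch installed.
    import torch
    print("torch", torch.__version__, "built for CUDA", torch.version.cuda)
    print("CUDA usable:", torch.cuda.is_available())


if __name__ == "__main__":
    # Print the command; pass it to subprocess.check_call to actually run it.
    print(" ".join(torch_reinstall_command()))
```

After reinstalling, `report_torch_build()` should report 11.8 as the build's CUDA version. If `torch.cuda.is_available()` is still False, the 11.2 driver may be too old for this route, and updating the driver would be the remaining option.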