kalininalab / alphafold_non_docker

AlphaFold2 non-docker setup

How to run AlphaFold non-docker on an NVIDIA 4090 card? #77

Open longbow0 opened 4 months ago


The non-docker setup with AlphaFold 2.3.1 worked well on an NVIDIA V100 card, but it does not work on an RTX 4090 card.

I followed the instructions in this AlphaFold issue: https://github.com/google-deepmind/alphafold/issues/786

and made the following modifications:

# Install cudatoolkit 11.8
conda install -y -c conda-forge openmm==7.5.1 cudatoolkit==11.8 pdbfixer

# Install jaxlib 0.4.13
pip install --upgrade --no-cache-dir jax==0.4.13 jaxlib==0.4.13+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
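The pinned wheel encodes its build configuration in the local version tag after `+`: `cuda11.cudnn86` reads as CUDA 11 and cuDNN 8.6. Since the failure below happens while creating a cuDNN handle, it may help to confirm which versions actually ended up in the environment. The sketch below is illustrative only: `parse_jaxlib_tag` and `installed_version` are hypothetical helpers (not part of JAX), the parser assumes the `cuda<major>.cudnn<major><minor>` tag form seen in this issue, and installed versions are read with the standard-library `importlib.metadata` (available on Python 3.8, which the environment in the traceback uses).

```python
import re
from importlib import metadata
from typing import Dict, Optional


def parse_jaxlib_tag(version: str) -> Dict[str, Optional[str]]:
    """Split a jaxlib version such as '0.4.13+cuda11.cudnn86' into parts.

    Hypothetical helper: assumes the local tag looks like
    'cuda<major>.cudnn<major><minor>' with a single-digit cuDNN major.
    """
    base, _, local = version.partition("+")
    match = re.fullmatch(r"cuda(\d+)\.cudnn(\d)(\d+)", local)
    if match is None:
        # CPU-only wheels (or unfamiliar tags) carry no CUDA/cuDNN info.
        return {"jaxlib": base, "cuda": None, "cudnn": None}
    cuda, cudnn_major, cudnn_minor = match.groups()
    return {"jaxlib": base, "cuda": cuda, "cudnn": f"{cudnn_major}.{cudnn_minor}"}


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a package, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Report what pip actually installed in the active environment.
    for pkg in ("jax", "jaxlib"):
        print(pkg, installed_version(pkg))
    # Decode the tag of the wheel pinned in the pip command above.
    print(parse_jaxlib_tag("0.4.13+cuda11.cudnn86"))
```

Run inside the AlphaFold conda environment, this prints the installed `jax`/`jaxlib` versions; the decoded cuDNN version can then be compared with whatever cuDNN the environment provides (for example, the output of `conda list cudnn`).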

There were no errors during installation, but running a prediction fails with the following error:

...
I0703 12:56:00.237594 140343651034944 run_alphafold.py:191] Running model model_1_pred_0 on s_rbd
I0703 12:56:02.881171 140343651034944 model.py:165] Running predict with shape(feat) = {'aatype': (4, 222), 'residue_index': (4, 222), 'seq_length': (4,), 'template_aatype': (4, 4, 222), 'template_all_atom_masks': (4, 4, 222, 37), 'template_all_atom_positions': (4, 4, 222, 37, 3), 'template_sum_probs': (4, 4, 1), 'is_distillation': (4,), 'seq_mask': (4, 222), 'msa_mask': (4, 508, 222), 'msa_row_mask': (4, 508), 'random_crop_to_size_seed': (4, 2), 'template_mask': (4, 4), 'template_pseudo_beta': (4, 4, 222, 3), 'template_pseudo_beta_mask': (4, 4, 222), 'atom14_atom_exists': (4, 222, 14), 'residx_atom14_to_atom37': (4, 222, 14), 'residx_atom37_to_atom14': (4, 222, 37), 'atom37_atom_exists': (4, 222, 37), 'extra_msa': (4, 5120, 222), 'extra_msa_mask': (4, 5120, 222), 'extra_msa_row_mask': (4, 5120), 'bert_mask': (4, 508, 222), 'true_msa': (4, 508, 222), 'extra_has_deletion': (4, 5120, 222), 'extra_deletion_value': (4, 5120, 222), 'msa_feat': (4, 508, 222, 49), 'target_feat': (4, 222, 22)}
2024-07-03 12:56:02.906804: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-07-03 12:56:02.906866: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 22732734464 bytes free, 25390809088 bytes total.
Traceback (most recent call last):
  File "/home/server/alphafold-2.3.1/run_alphafold.py", line 432, in <module>
    app.run(main)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/server/alphafold-2.3.1/run_alphafold.py", line 408, in main
    predict_structure(
  File "/home/server/alphafold-2.3.1/run_alphafold.py", line 199, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/server/alphafold-2.3.1/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
    compiled = _pjit_lower(
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/server/alphafold-2.3.1/run_alphafold.py", line 432, in <module>
    app.run(main)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/server/alphafold-2.3.1/run_alphafold.py", line 408, in main
    predict_structure(
  File "/home/server/alphafold-2.3.1/run_alphafold.py", line 199, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/server/alphafold-2.3.1/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/home/server/miniconda3/envs/alphafold-ndd/lib/python3.8/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
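Two details in the log narrow things down. The traceback shows the error is raised inside `jax.random.PRNGKey`, while JAX compiles the seed computation, before `self.apply` runs the model. The cuDNN error line also reports GPU memory at the moment of failure. The minimal sketch below converts those numbers; the regular expression is an assumption that matches only the exact `Memory usage: N bytes free, M bytes total` wording shown in this log.

```python
import re
from typing import Tuple

# Line copied verbatim from the log in this issue.
LOG_LINE = (
    "2024-07-03 12:56:02.906866: E external/xla/xla/stream_executor/cuda/"
    "cuda_dnn.cc:443] Memory usage: 22732734464 bytes free, "
    "25390809088 bytes total."
)

GIB = 1024 ** 3


def parse_memory_usage(line: str) -> Tuple[int, int]:
    """Extract (free_bytes, total_bytes) from an XLA 'Memory usage' log line."""
    match = re.search(r"Memory usage: (\d+) bytes free, (\d+) bytes total", line)
    if match is None:
        raise ValueError("no 'Memory usage' entry in line")
    return int(match.group(1)), int(match.group(2))


free, total = parse_memory_usage(LOG_LINE)
print(f"free:  {free / GIB:.2f} GiB")        # free:  21.17 GiB
print(f"total: {total / GIB:.2f} GiB")       # total: 23.65 GiB
print(f"fraction free: {free / total:.1%}")  # fraction free: 89.5%
```

About 89.5% of the card's memory was still free when the handle creation failed, so memory exhaustion looks unlikely; a cuDNN library or version problem seems the more plausible lead, although this arithmetic alone cannot confirm it.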

How can I fix this error? Thanks!