jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed #15361

Closed Toni-SM closed 1 year ago

Toni-SM commented 1 year ago

Description

I have a Python virtual environment with a clean installation of JAX:

# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

My scripts usually run fine, but they fail intermittently with the error below: typically 2 to 10 successful runs are followed by 1 to 3 failed runs.

2023-04-02 16:00:19.964652: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-04-02 16:00:19.964737: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:438] Possibly insufficient driver version: 530.30.2
Traceback (most recent call last):
  File "ddpg_jax_gymnasium_pendulum.py", line 73, in <module>
    key = jax.random.PRNGKey(0)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/random.py", line 136, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 561, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
    return seed(seeds)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 813, in threefry_seed
    lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 458, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using?

jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.8.10, Ubuntu 20.04

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080 L...    On | 00000000:01:00.0 Off |                  N/A |
| N/A   38C    P3               N/A /  55W|     10MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1528      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A      2435      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+

CUDNN version (/usr/local/cuda/include/cudnn_version.h)

#define CUDNN_MAJOR 8
#define CUDNN_MINOR 8
#define CUDNN_PATCHLEVEL 1
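
Editor's sketch: the three defines above can be read programmatically, which makes it easier to compare the installed cuDNN against the `cudnn88` tag in the jaxlib version. This is a minimal sketch, assuming the header text follows the `#define CUDNN_MAJOR 8` layout shown above; `parse_cudnn_version` is a hypothetical helper, not part of JAX or cuDNN.

```python
import re
from typing import Tuple


def parse_cudnn_version(header_text: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from the text of cudnn_version.h."""
    values = []
    for name in ("MAJOR", "MINOR", "PATCHLEVEL"):
        # Match lines such as "#define CUDNN_MAJOR 8".
        match = re.search(
            rf"^\s*#define\s+CUDNN_{name}\s+(\d+)", header_text, re.MULTILINE
        )
        if match is None:
            raise ValueError(f"CUDNN_{name} not found in header text")
        values.append(int(match.group(1)))
    return values[0], values[1], values[2]


# Header text copied from the report above.
reported = "#define CUDNN_MAJOR 8\n#define CUDNN_MINOR 8\n#define CUDNN_PATCHLEVEL 1\n"
print(parse_cudnn_version(reported))
```

To check a real installation, pass the contents of the header file, for example `open("/usr/local/cuda/include/cudnn_version.h").read()`; that path is the one quoted in the report and may differ on other systems.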

nouiz commented 1 year ago

What is your OS? Can you confirm you run the scripts sequentially and so there is nothing that is using the GPU in parallel?

Toni-SM commented 1 year ago

Hi @nouiz

The OS is Ubuntu 20.04, as indicated above.

Btw, I think the problem may be VS Code. After running the script several times to try to reproduce the error, I see that it only appears (though not every time) right after I modify the script and save it.

There is also the following log. As you can see (by running the nvidia-smi command just before executing the script, right after saving it), GPU memory is already in use. The strange thing is that the memory is held by the Python environment (env_gym) configured in the VS Code bottom right pane, not by the Python of the sourced environment where JAX is installed (env_jax) 🤔

(env_jax) toni@HP-ZBook-Studio-G8:~/Documents/SKRL/skrl/docs/source/examples/gymnasium$ nvidia-smi; python ddpg_jax_gymnasium_pendulum.py 
Mon Apr  3 20:41:40 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080 L...    On | 00000000:01:00.0 Off |                  N/A |
| N/A   49C    P3               23W /  55W|  12453MiB / 16384MiB |      4%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1536      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A      2456      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A     27050      C   .../SKRL/envs/env_gym/bin/python          12440MiB |
+---------------------------------------------------------------------------------------+
[skrl:INFO] Environment class: gymnasium.core.Wrapper, gymnasium.utils.record_constructor.RecordConstructorArgs
[skrl:INFO] Environment wrapper: Gymnasium
2023-04-03 20:41:43.310989: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-04-03 20:41:43.311060: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:438] Possibly insufficient driver version: 530.30.2
Traceback (most recent call last):
...
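
Editor's sketch: the check done by hand above (running nvidia-smi right before the script) can be automated. This is a minimal sketch, assuming `nvidia-smi` is on the PATH and supports the `--query-compute-apps` CSV query; `GpuProcess`, `parse_compute_apps`, and `query_compute_apps` are hypothetical names, and the sample path is made up for illustration. Some systems report the memory column as `[N/A]`, which this sketch does not handle.

```python
import subprocess
from typing import List, NamedTuple


class GpuProcess(NamedTuple):
    pid: int
    name: str
    used_mib: int


def parse_compute_apps(csv_text: str) -> List[GpuProcess]:
    """Parse 'pid, process_name, used_memory' CSV rows (no header, no units)."""
    processes = []
    for line in csv_text.splitlines():
        if not line.strip():
            continue
        # Split from both ends so a comma inside the process name is tolerated.
        pid_text, rest = line.split(",", 1)
        name_text, used_text = rest.rsplit(",", 1)
        processes.append(GpuProcess(int(pid_text), name_text.strip(), int(used_text)))
    return processes


def query_compute_apps() -> List[GpuProcess]:
    """Ask nvidia-smi which compute processes currently hold GPU memory."""
    result = subprocess.run(
        [
            "nvidia-smi",
            "--query-compute-apps=pid,process_name,used_memory",
            "--format=csv,noheader,nounits",
        ],
        check=True,
        capture_output=True,
        text=True,
    )
    return parse_compute_apps(result.stdout)


# Hypothetical row shaped like the env_gym process in the log above.
sample = "27050, /home/user/envs/env_gym/bin/python, 12440\n"
for proc in parse_compute_apps(sample):
    print(f"PID {proc.pid} ({proc.name}) holds {proc.used_mib} MiB")
```

Calling `query_compute_apps()` at the top of a training script and printing the result would make a stale process like the one above visible before JAX tries to initialize cuDNN.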

nouiz commented 1 year ago

Thanks for the results. I think we need a way to give a better error to end users.

nouiz commented 1 year ago

Recently a few error messages got a little better. Closing, as I'm not sure what more to do. But if the issue appears again and the error isn't good enough, poke us again.

amacrutherford commented 1 year ago

I also got this error and it was due to GPU reaching its memory limit

nouiz commented 1 year ago

@amacrutherford Do you have the full error message you had? I would like to improve the error message in that case.

Bailey-24 commented 1 year ago

same error

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-25-5a28263ee724> in <cell line: 5>()
      3 
      4 # Initialize model weights using dummy tensors.
----> 5 rng = jax.random.PRNGKey(0)
      6 rng, key = jax.random.split(rng)
      7 init_img = jnp.ones((4, 224, 224, 5), jnp.float32)

22 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in backend_compile(backend, built_c, options, host_callbacks)
    469   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    470   # to take in `host_callbacks`
--> 471   return backend.compile(built_c, compile_options=options)
    472 
    473 _ir_dump_counter = itertools.count()

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

I think the GPU memory is also reaching its limit here:

Fri May 12 01:50:33 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   57C    P0    27W /  70W |  15101MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

amacrutherford commented 1 year ago

Yep I received the same error message as @Bailey-24

nouiz commented 1 year ago

Thanks for the extended error message. But can you share the full output without any truncation? It should contain information that will help me. Which JAX version did you use?

hosseybposh commented 1 year ago

I'm having the same problem, but for me it's consistent and I'm unable to run even simple JAX code. I only have this problem on my newest system with 4x RTX 4090 GPUs. I have an A100 server and a PC with a 3090 Ti that both work smoothly. Ubuntu 22 across all systems. I first installed CUDA 11 from conda-forge as suggested: same issue. Then I switched to a local installation of CUDA and cuDNN: same problem.

After a fresh installation of everything when I run a = jnp.ones((3,)) I get this error:

2023-05-13 09:04:27.790057: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-05-13 09:04:27.790140: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 5853872128 bytes free, 25393692672 bytes total.
Traceback (most recent call last):
  File "", line 1, in
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2122, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1203, in full
    return broadcast(fill_value, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 768, in broadcast
    return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
    return broadcast_in_dim_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 131, in apply_primitive
    compiled_fun = xla_primitive_callable(
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 222, in xla_primitive_callable
    compiled = _xla_callable_uncached(
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 252, in _xla_callable_uncached
    return computation.compile().unsafe_call
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
    xla_executable, compile_options = _cached_compilation(
                                      ^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 495, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 463, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
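
Editor's sketch: the second log line in this report carries the free and total GPU memory at the moment cuDNN failed to initialize, which is often the most useful number in the whole output. This is a minimal sketch that pulls it out and converts it to GiB, assuming the message keeps the "Memory usage: N bytes free, M bytes total" wording shown above; `parse_memory_usage` and `describe_memory` are hypothetical helper names.

```python
import re
from typing import Optional, Tuple

_MEMORY_RE = re.compile(r"Memory usage: (\d+) bytes free, (\d+) bytes total")


def parse_memory_usage(log_text: str) -> Optional[Tuple[int, int]]:
    """Return (free_bytes, total_bytes) from an XLA cuDNN log, or None if absent."""
    match = _MEMORY_RE.search(log_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def describe_memory(free_bytes: int, total_bytes: int) -> str:
    """Format the byte counts as GiB plus the free fraction."""
    gib = 1024 ** 3
    return (
        f"{free_bytes / gib:.2f} GiB free of {total_bytes / gib:.2f} GiB "
        f"({free_bytes / total_bytes:.0%} free)"
    )


# The log line quoted in the report above.
line = (
    "2023-05-13 09:04:27.790140: E external/xla/xla/stream_executor/cuda/"
    "cuda_dnn.cc:443] Memory usage: 5853872128 bytes free, 25393692672 bytes total."
)
usage = parse_memory_usage(line)
if usage is not None:
    print(describe_memory(*usage))
```

For the line above this reports roughly 5.45 GiB free out of 23.65 GiB, so most of the card was already allocated when the cuDNN handle was requested.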

hosseybposh commented 1 year ago

I've also tried swapping my PC's 3090 Ti (which works properly) with one of the 4090s, and I got the exact same error. This used to work in the past. I'm pretty sure the GPU hardware is not the problem and the cards are functional (I tested them on Windows machines).

nouiz commented 1 year ago

Did you cut the output? The error says to look above for more errors. If there is more output, give me everything you have. I'll filter out what is useful.

ampolloreno commented 1 year ago

To add a data point: I'm getting the same kind of error trying to install jax/jaxlib on an EC2 p2.xlarge (with K80s). I can provide more details if useful. Basically, I ran a vanilla Anaconda installation script and tried different variants of pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. JAX reports seeing the GPU when I check print(xla_bridge.get_backend().platform), but anything else gives the DNN error above.

ampolloreno commented 1 year ago

(I'm also unable to run any JAX code, e.g. a = jnp.ones((3,)).)

hawkinsp commented 1 year ago

@ampolloreno Please open new issues rather than appending onto closed ones.

However, I think the problem in your case is simple: JAX no longer supports Kepler GPUs in the wheels we release. You can probably rebuild jaxlib from source if you need Kepler support, but note NVIDIA has dropped Kepler support from CUDA 12 and CUDNN 8.9, so this may not remain true for long.
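
Editor's sketch: Kepler-generation GPUs such as the K80 report CUDA compute capability 3.x, so a quick check of that number can rule this cause in or out. This is a minimal sketch, assuming you already have the compute capability as a string such as "3.7"; recent drivers can print it with `nvidia-smi --query-gpu=compute_cap --format=csv,noheader`, but treat the availability of that query on your driver as an assumption. `is_kepler` is a hypothetical helper, and it says nothing about the exact minimum capability the released wheels require, which this thread does not state.

```python
def is_kepler(compute_capability: str) -> bool:
    """Return True for Kepler-generation GPUs (compute capability 3.x)."""
    major = int(compute_capability.strip().split(".")[0])
    return major == 3


# A K80 reports 3.7 and a V100 reports 7.0.
for name, capability in [("K80", "3.7"), ("V100", "7.0")]:
    status = "Kepler, not covered by the released wheels" if is_kepler(capability) else "not Kepler"
    print(f"{name} (compute capability {capability}): {status}")
```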

hosseybposh commented 1 year ago

@nouiz No, this is all the output. It's several lines of errors; are you seeing all of it?

I managed to resolve this though. I installed CUDA 11 and cudnn 8.6. In my experiments I also installed the latest version of everything but this was the only version combination that worked for me. Now I'm getting other errors but that's a different problem.

ampolloreno commented 1 year ago

@hawkinsp Point taken and... Thanks for the help! I just switched over to V100s and voila!

liyc-ai commented 1 year ago

I got the same error; it may be due to a mismatch between your CUDA version and the installed JAX. I use Ubuntu 20.04 with the CUDA version shown below: [image] At first, I installed the newest JAX with

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then I got the error as reported. So I switched to

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Everything works well!
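
Editor's sketch: the "CUDA Version" in the nvidia-smi header is the newest CUDA runtime the installed driver supports, so it gives a first hint about which of the two extras above can work. This is a rough heuristic of mine, not an official compatibility rule: it ignores minimum driver versions and cuDNN entirely. `driver_cuda_version` and `candidate_extras` are hypothetical helpers; only the extras names `cuda12_pip` and `cuda11_pip` come from the commands above.

```python
import re
from typing import List, Optional, Tuple


def driver_cuda_version(nvidia_smi_text: str) -> Optional[Tuple[int, int]]:
    """Read the 'CUDA Version: X.Y' field from nvidia-smi's header."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def candidate_extras(nvidia_smi_text: str) -> List[str]:
    """List the pip extras worth trying, newest first (heuristic only)."""
    version = driver_cuda_version(nvidia_smi_text)
    if version is None:
        return []
    major, _minor = version
    if major >= 12:
        # A CUDA 12 capable driver can also run CUDA 11 builds.
        return ["cuda12_pip", "cuda11_pip"]
    if major == 11:
        return ["cuda11_pip"]
    return []


# Header lines copied from two nvidia-smi outputs earlier in this thread.
print(candidate_extras("| NVIDIA-SMI 530.30.02   Driver Version: 530.30.02   CUDA Version: 12.1 |"))
print(candidate_extras("| NVIDIA-SMI 520.61.05   Driver Version: 520.61.05   CUDA Version: 11.8 |"))
```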

tuzhucheng commented 1 year ago

Thanks @hosseybposh, for a simple use case I was able to use JAX 0.4.13 and CUDA 11.8 with CUDNN 8.6. I needed to add /usr/lib/x86_64-linux-gnu to the LD_LIBRARY_PATH (installed libcudnn8 with apt-get).
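
Editor's sketch: the dynamic loader reads LD_LIBRARY_PATH when a process starts, so changing it from inside an already running Python interpreter is generally too late. One way to apply the fix above without editing shell profiles is to launch the script in a child process with an adjusted environment. This is a minimal sketch; `env_with_library_dir` is a hypothetical helper, and the directory is appended rather than prepended so that it does not shadow libraries found earlier on the path, which is a choice of mine rather than something stated above.

```python
import os
from typing import Dict, Mapping, Optional


def env_with_library_dir(
    directory: str, base_env: Optional[Mapping[str, str]] = None
) -> Dict[str, str]:
    """Return a copy of the environment with `directory` on LD_LIBRARY_PATH."""
    env = dict(os.environ if base_env is None else base_env)
    entries = [p for p in env.get("LD_LIBRARY_PATH", "").split(os.pathsep) if p]
    if directory not in entries:
        entries.append(directory)
    env["LD_LIBRARY_PATH"] = os.pathsep.join(entries)
    return env


# Directory taken from the comment above; the starting value is illustrative.
example = env_with_library_dir(
    "/usr/lib/x86_64-linux-gnu", {"LD_LIBRARY_PATH": "/usr/local/cuda/lib64"}
)
print(example["LD_LIBRARY_PATH"])
```

The result can be passed to a launcher such as `subprocess.run([sys.executable, "train.py"], env=env_with_library_dir("/usr/lib/x86_64-linux-gnu"))`, where `train.py` stands in for your own script.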

TInaWangxue commented 1 year ago

I also hit this error. The output:

2023-07-31 01:53:45.016563: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.

XlaRuntimeError                           Traceback (most recent call last)
Cell In[4], line 29
     26 model = trainer.make_model(nmask)
     28 lr_fn, opt = trainer.make_optimizer(steps_per_epoch=len(train_dl))
---> 29 state = trainer.create_train_state(jax.random.PRNGKey(0), model, opt)
     30 state = checkpoints.restore_checkpoint(ckpt.parent, state)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/random.py:137, in PRNGKey(seed)
    134 if np.ndim(seed):
    135   raise TypeError("PRNGKey accepts a scalar seed, but was given an array of"
    136                   f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 137 key = prng.seed_with_impl(impl, seed)
    138 return _return_prng_keys(True, key)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:320, in seed_with_impl(impl, seed)
    319 def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
--> 320   return random_seed(seed, impl=impl)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:734, in random_seed(seeds, impl)
    732 else:
    733   seeds_arr = jnp.asarray(seeds)
--> 734 return random_seed_p.bind(seeds_arr, impl=impl)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
    377 def bind(self, *args, **params):
    378   assert (not config.jax_enable_checks or
    379           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380   return self.bind_with_trace(find_top_trace(args), args, params)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
    382 def bind_with_trace(self, trace, args, params):
--> 383   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    384   return map(full_lower, out) if self.multiple_results else full_lower(out)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:790, in EvalTrace.process_primitive(self, primitive, tracers, params)
    789 def process_primitive(self, primitive, tracers, params):
--> 790   return primitive.impl(*tracers, **params)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:746, in random_seed_impl(seeds, impl)
    744 @random_seed_p.def_impl
    745 def random_seed_impl(seeds, *, impl):
--> 746   base_arr = random_seed_impl_base(seeds, impl=impl)
    747   return PRNGKeyArrayImpl(impl, base_arr)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:751, in random_seed_impl_base(seeds, impl)
    749 def random_seed_impl_base(seeds, *, impl):
    750   seed = iterated_vmap_unary(seeds.ndim, impl.seed)
--> 751   return seed(seeds)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:980, in threefry_seed(seed)
    968 def threefry_seed(seed: typing.Array) -> typing.Array:
    969   """Create a single raw threefry PRNG key from an integer seed.
    970
    971   Args:
   (...)
    978       first padding out with zeros).
    979   """
--> 980   return _threefry_seed(seed)

[... skipping hidden 12 frame]

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/dispatch.py:463, in backend_compile(backend, module, options, host_callbacks)
    458   return backend.compile(built_c, compile_options=options,
    459                          host_callbacks=host_callbacks)
    460 # Some backends don't have host_callbacks option yet
    461 # TODO(sharadmv): remove this fallback when all backends allow compile
    462 # to take in host_callbacks
--> 463 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using?

jax 0.4.10, jaxlib 0.4.10+cuda11.cudnn86

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.11.4, Ubuntu 22.04

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:18:00.0 Off |                  N/A |
| 30%   33C    P8    22W / 350W |   8688MiB / 12288MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:3B:00.0 Off |                  N/A |
| 30%   31C    P8    14W / 350W |      8MiB / 12288MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA GeForce ...  On   | 00000000:86:00.0 Off |                  N/A |
| 30%   34C    P8    24W / 350W |      8MiB / 12288MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA GeForce ...  On   | 00000000:AF:00.0 Off |                  N/A |
| 30%   30C    P8     8W / 350W |      8MiB / 12288MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2565      G   /usr/lib/xorg/Xorg                  4MiB |
|    0   N/A  N/A      2861    C+G   ...ome-remote-desktop-daemon      249MiB |
|    0   N/A  N/A   3213874      C   ...ransformer_117/bin/python     8430MiB |
|    1   N/A  N/A      2565      G   /usr/lib/xorg/Xorg                  4MiB |
|    2   N/A  N/A      2565      G   /usr/lib/xorg/Xorg                  4MiB |
|    3   N/A  N/A      2565      G   /usr/lib/xorg/Xorg                  4MiB |
+-----------------------------------------------------------------------------+
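
Editor's sketch: the cuDNN log line at the top of this report states the compatibility rule explicitly: the loaded library must have the same major version as the one jaxlib was compiled against, and an equal or higher minor version. The sketch below encodes just that rule as stated in the message; `cudnn_runtime_is_compatible` is a hypothetical helper, not an XLA or cuDNN API.

```python
from typing import Tuple


def _parse_version(text: str) -> Tuple[int, ...]:
    """Turn a dotted version string such as '8.6.0' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def cudnn_runtime_is_compatible(loaded: str, compiled: str) -> bool:
    """Apply the rule from the log: same major, equal or higher minor."""
    loaded_v = _parse_version(loaded)
    compiled_v = _parse_version(compiled)
    same_major = loaded_v[0] == compiled_v[0]
    minor_ok = loaded_v[1] >= compiled_v[1]
    return same_major and minor_ok


# The versions from the log above: runtime 8.5.0, compiled against 8.6.0.
print(cudnn_runtime_is_compatible("8.5.0", "8.6.0"))
```

For the versions in this report the check fails, which matches the error: the environment loads cuDNN 8.5.0 while the jaxlib build expects 8.6.0 or a later 8.x release.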

hawkinsp commented 1 year ago

@TInaWangxue's problem was resolved in https://github.com/google/jax/issues/16901.

cloudinging commented 1 year ago

Hi, I have a similar issue. Please help me!

The output:

2023-09-05 14:32:56.559501: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-09-05 14:32:56.559528: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 6081413120 bytes free, 25438126080 bytes total.
Traceback (most recent call last):
  File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 453, in
    app.run(main)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 428, in main
    predict_structure(
  File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 214, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/wangyun/pre/alphafold-multimer-main/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/random.py", line 137, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 320, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 732, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 744, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 749, in random_seed_impl_base
    return seed(seeds)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 978, in threefry_seed
    return _threefry_seed(seed)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 208, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 155, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 2633, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 1085, in _pjit_call_impl
    compiled = _pjit_lower(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/dispatch.py", line 494, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/dispatch.py", line 462, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using? jax 0.4.9, jaxlib 0.4.9+cuda11.cudnn86

My conda virtual environment: python3.9.0 cudatoolkit 11.8.0 h4ba93d1_12 conda-forge cudnn 8.6.0.163 hed8a83a_0 cudistas

But my OS environment: NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.2 ii cudnn-local-repo-ubuntu1804-8.9.3.28 1.0-1 amd64 cudnn-local repository configuration files

I have tried everything I can, but ……

April-ppigg commented 1 year ago

What is your OS? Can you confirm you run the scripts sequentially and so there is nothing that is using the GPU in parallel?

Hi, I encountered the same problem. When I run a single task on an A100, it runs normally, but when I submit two tasks at the same time, the above error is reported. Is the cause a conflict between the two tasks running on the A100 at the same time?

nouiz commented 1 year ago

Hi, I encountered the same problem. When I run a single task on an A100, it runs normally, but when I submit two tasks at the same time, the above error is reported. Is the cause a conflict between the two tasks running on the A100 at the same time?

I suppose 2 tasks means 2 processes. If not, tell us. By default, JAX reserves 75% of the GPU memory for each process: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

So the 2nd process will run out of GPU memory most of the time. Read that page to learn how to control the 75% memory allocation. If you can lower it to 45% and the first process still has enough memory, it will probably work. Otherwise, try a few other values.
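As a minimal sketch of the suggestion above: the linked page documents the environment variables `XLA_PYTHON_CLIENT_MEM_FRACTION` and `XLA_PYTHON_CLIENT_PREALLOCATE`. The helper name `configure_jax_gpu_memory` is made up for illustration, and 0.45 is only the starting value mentioned in the comment, not a recommended setting.

```python
import os


def configure_jax_gpu_memory(fraction=0.45, preallocate=True):
    """Set XLA client memory options. Call this BEFORE importing jax.

    Hypothetical helper for illustration; only the environment variable
    names come from the JAX GPU memory allocation documentation.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    # Fraction of total GPU memory this process preallocates.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"
    if not preallocate:
        # Allocate on demand instead of reserving a block up front.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


configure_jax_gpu_memory(fraction=0.45)
# import jax  # import jax only after the variables above are set
```

Each process that shares the GPU would need the same setup, and the fractions of all processes together should stay below 1.0. The same variables can also be exported in the shell before launching each script.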

cloudinging commented 11 months ago

Hi, I have a similar issue. Please help me!

The output:

2023-09-05 14:32:56.559501: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-09-05 14:32:56.559528: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 6081413120 bytes free, 25438126080 bytes total.
Traceback (most recent call last):
  File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 453, in <module>
    app.run(main)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 428, in main
    predict_structure(
  File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 214, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/wangyun/pre/alphafold-multimer-main/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/random.py", line 137, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 320, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 732, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 744, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 749, in random_seed_impl_base
    return seed(seeds)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 978, in threefry_seed
    return _threefry_seed(seed)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 208, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 155, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 2633, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 1085, in _pjit_call_impl
    compiled = _pjit_lower(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/dispatch.py", line 494, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/dispatch.py", line 462, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using? jax 0.4.9, jaxlib 0.4.9+cuda11.cudnn86

My conda virtual environment: python3.9.0 cudatoolkit 11.8.0 h4ba93d1_12 conda-forge cudnn 8.6.0.163 hed8a83a_0 cudistas

But my OS environment: NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.2 ii cudnn-local-repo-ubuntu1804-8.9.3.28 1.0-1 amd64 cudnn-local repository configuration files

I have tried everything I can, but ……

It works now. You can look at this link: https://blog.csdn.net/2201_75882736/article/details/132812927

William-HYWu commented 6 months ago

Hi, I also have the same issue. Could anyone please help me?

E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
Traceback (most recent call last):
  File "/bd_byt4090i1/users/state_space_model/DNN/S5/run_train.py", line 101, in <module>
    train(parser.parse_args())
  File "/bd_byt4090i1/users/state_space_model/DNN/S5/s5/train.py", line 41, in train
    key = random.PRNGKey(args.jax_seed)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1123, in _pjit_call_impl_python
    always_lower=False, lowering_platform=None).compile()
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/bd_byt4090i1/users/state_space_model/DNN/S5/run_train.py", line 101, in <module>
    train(parser.parse_args())
  File "/bd_byt4090i1/users/state_space_model/DNN/S5/s5/train.py", line 41, in train
    key = random.PRNGKey(args.jax_seed)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

My jax version: jax==0.4.13 jaxlib==0.4.13+cuda11.cudnn86 flax==0.7.4 chex==0.1.8

My GPU information:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.223.02   Driver Version: 470.223.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:83:00.0 Off |                  N/A |
| 21%   35C    P0    40W / 215W |      0MiB /  7982MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

It should be an RTX 2070. Thanks a lot for the help.

nouiz commented 6 months ago

You are using an old JAX version (0.4.13) and an old driver that only supports up to CUDA 11.4. Can you update your NVIDIA driver to one that supports at least CUDA 11.8, as this is the minimum version currently supported by JAX? JAX is dropping CUDA 11 in upcoming releases, so if you can update to CUDA 12, that would be better.
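The check described above can be sketched in a few lines. The "CUDA Version" field in the `nvidia-smi` header is the highest CUDA version the installed driver supports, so it can be compared against the 11.8 minimum named in this comment. The function names and the `MIN_CUDA` constant are made up for illustration, and the minimum will differ for other JAX releases.

```python
import re

# Minimum CUDA version named in the comment above; an assumption that
# only holds for the JAX releases discussed in this thread.
MIN_CUDA = (11, 8)


def parse_nvidia_smi_header(text):
    """Return (driver_version, (cuda_major, cuda_minor)) from nvidia-smi output."""
    driver = re.search(r"Driver Version:\s*([\d.]+)", text)
    cuda = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", text)
    if driver is None or cuda is None:
        raise ValueError("could not find driver/CUDA version in nvidia-smi output")
    return driver.group(1), (int(cuda.group(1)), int(cuda.group(2)))


def driver_supports(text, minimum=MIN_CUDA):
    """True if the driver's maximum supported CUDA version is >= minimum."""
    _, cuda = parse_nvidia_smi_header(text)
    return cuda >= minimum


# Header line taken from the nvidia-smi output posted in this thread.
header = "| NVIDIA-SMI 470.223.02 Driver Version: 470.223.02 CUDA Version: 11.4 |"
print(parse_nvidia_smi_header(header))  # ('470.223.02', (11, 4))
print(driver_supports(header))          # False
```

In practice the text would come from running `nvidia-smi` on the affected machine; here it is hard-coded so the sketch runs without a GPU.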

William-HYWu commented 6 months ago

You are using an old JAX version (0.4.13) and an old driver that only supports up to CUDA 11.4. Can you update your NVIDIA driver to one that supports at least CUDA 11.8, as this is the minimum version currently supported by JAX? JAX is dropping CUDA 11 in upcoming releases, so if you can update to CUDA 12, that would be better.

Thanks a lot. I tried that, and it worked!

crshin commented 5 months ago

Hi, guys. I'm here on a recommendation. I'm facing a similar issue: "Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed." I tried almost everything suggested on these pages to resolve the GPU memory problem:
https://github.com/YoshitakaMo/localcolabfold/issues/210
https://github.com/YoshitakaMo/localcolabfold/issues/224
https://github.com/YoshitakaMo/localcolabfold/issues/228

My current jax, cudnn, and nvidia-smi versions are as follows. (Linux Ubuntu 22.04.2 LTS and RTX 4090)

$nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0

$python3.10 -m pip list | grep jax
jax                      0.4.23
jax-cuda12-pjrt          0.4.23
jax-cuda12-plugin        0.4.23
jaxlib                   0.4.23+cuda12.cudnn89

$python3.10 -m pip list | grep cudnn
jaxlib                   0.4.23+cuda12.cudnn89
nvidia-cudnn-cu12        9.1.0.70

$nvidia-smi
Fri May 10 07:20:32 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+

Here's the problem I'm encountering:

$colabfold_batch --templates --amber test_A.fasta ./
2024-05-13 06:17:37,861 Running colabfold 1.5.5 (57b220e028610ba7331ebe1ef9c2d0419992469a)
2024-05-13 06:17:38,189 Running on GPU
2024-05-13 06:17:38,738 Found 9 citations for tools or databases
2024-05-13 06:17:38,738 Query 1/1: pdb_A (length 108)
2024-05-13 06:17:41,729 Sequence 0 found templates: ['1m4u_L', '2r52_B', '6oml_Y', '5vt2_B', '2r53_A', '1lxi_A', '4n1d_A', '7zjf_B', '7zjf_A', '6z3g_A', '3qb4_C', '3qb4_A', '6z3j_A', '2h64_A', '4uhy_A', '1reu_A', '4ui0_A', '2h62_B', '4mid_A', '3bk3_B']
2024-05-13 06:17:42,533 Setting max_seq=512, max_extra_seq=5120
2024-05-13 06:17:42,674 Could not predict pdb_A. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
2024-05-13 06:17:42,674 Done

I can't find "the errors above."

Can someone offer any guesses or solutions?

hawkinsp commented 5 months ago

JAX releases don't support cuDNN 9 yet. Downgrade to cuDNN 8.9 (or build jaxlib from source with cuDNN 9; that also works).
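A rough way to spot this kind of mismatch from `pip list` output is to compare the cuDNN tag in the jaxlib version string with the installed `nvidia-cudnn-cu12` version. This is only a sketch. It assumes that a suffix such as `cudnn89` encodes major version 8 and minor version 9, and that a build needs the same cuDNN major version with an equal or newer minor version. Both function names are hypothetical.

```python
import re


def cudnn_expected_by_jaxlib(jaxlib_version):
    """Parse e.g. '0.4.23+cuda12.cudnn89' into (8, 9); None if there is no tag.

    Assumption: the first digit after 'cudnn' is the major version and the
    remaining digits are the minor version.
    """
    match = re.search(r"cudnn(\d)(\d+)", jaxlib_version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def cudnn_looks_compatible(jaxlib_version, installed_cudnn):
    """Heuristic: same cuDNN major version and installed minor >= expected minor."""
    expected = cudnn_expected_by_jaxlib(jaxlib_version)
    if expected is None:
        return None  # nothing to compare against
    major, minor = (int(part) for part in installed_cudnn.split(".")[:2])
    return major == expected[0] and minor >= expected[1]


# Version strings taken from the pip output posted in this thread.
print(cudnn_looks_compatible("0.4.23+cuda12.cudnn89", "9.1.0.70"))  # False
```

A result of `False` only suggests where to look; the JAX installation notes for the release in use remain the reference for supported cuDNN versions.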

AWangji commented 5 months ago

I got the same error; maybe it is due to a mismatch between your CUDA version and the installed JAX. I use Ubuntu 20.04 with the CUDA version shown below: [image] At first, I installed the newest JAX with

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then I got the error as reported. So I switched to

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Everything works well!

But jaxlib has no cuda11 build. [image]

Lvchangze commented 3 months ago

I got the same error; maybe it is due to a mismatch between your CUDA version and the installed JAX. I use Ubuntu 20.04 with the CUDA version shown below: [image] At first, I installed the newest JAX with

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then I got the error as reported. So I switched to

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Everything works well!

Great!

agnikumar commented 1 month ago

I'm getting jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed.

Are there any ideas about how to resolve this? JAX version is 0.4.26, CUDA version is 12.4, and cuDNN version is 8.9.7.29, which should be compatible. [image]
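Since several reports in this thread turned out to involve a Python environment that differed from the system-wide CUDA or cuDNN install, one hedged way to double-check is to list what the active environment itself has installed and then run the smallest call that fails in the tracebacks above, `jax.random.PRNGKey(0)`. The function names are made up for illustration, and the package names are the ones that appear in the `pip list` output earlier in the thread.

```python
import importlib.metadata as metadata


def report_versions(packages=("jax", "jaxlib", "jax-cuda12-plugin", "nvidia-cudnn-cu12")):
    """Map each package name to its installed version, or None if it is absent."""
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


def smoke_test():
    """Run the smallest JAX call from the tracebacks above and describe the outcome."""
    try:
        import jax
    except ImportError:
        return "jax is not installed in this environment"
    try:
        jax.random.PRNGKey(0)
    except Exception as exc:  # e.g. XlaRuntimeError: FAILED_PRECONDITION
        return f"failed: {exc}"
    return f"ok on {jax.devices()[0].platform}"


if __name__ == "__main__":
    print(report_versions())
    print(smoke_test())
```

Running this directly in a terminal, outside any wrapper script, may also make the `cuda_dnn.cc` lines that precede the Python traceback easier to find.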