Toni-SM closed this issue 1 year ago.
What is your OS? Can you confirm you run the scripts sequentially and so there is nothing that is using the GPU in parallel?
Hi @nouiz
The OS is Ubuntu 20.04, as indicated above.
By the way, I think the problem may be VS Code. After running the script several times to try to reproduce the error, I see that it only appears when I modify the script and save it (and not always even then).
There is also the following log.
As you can see (by running the `nvidia-smi` command just before executing the script, after saving it), there is GPU memory consumption. The strange thing is that the consumption comes from the Python environment (`env_gym`) configured in the VS Code bottom-right pane, not from the Python of the sourced environment where JAX is installed (`env_jax`) 🤔
(env_jax) toni@HP-ZBook-Studio-G8:~/Documents/SKRL/skrl/docs/source/examples/gymnasium$ nvidia-smi; python ddpg_jax_gymnasium_pendulum.py
Mon Apr 3 20:41:40 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3080 L... On | 00000000:01:00.0 Off | N/A |
| N/A 49C P3 23W / 55W| 12453MiB / 16384MiB | 4% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 1536 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 2456 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 27050 C .../SKRL/envs/env_gym/bin/python 12440MiB |
+---------------------------------------------------------------------------------------+
[skrl:INFO] Environment class: gymnasium.core.Wrapper, gymnasium.utils.record_constructor.RecordConstructorArgs
[skrl:INFO] Environment wrapper: Gymnasium
2023-04-03 20:41:43.310989: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-04-03 20:41:43.311060: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:438] Possibly insufficient driver version: 530.30.2
Traceback (most recent call last):
...
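In the log above, the `env_gym` interpreter (PID 27050) holds 12440MiB while the JAX script starts. A minimal sketch for spotting such leftover processes before launching a JAX job, assuming `nvidia-smi` is on `PATH`; `GpuProcess`, `parse_compute_apps` and `list_gpu_processes` are hypothetical helper names, not part of JAX or NVIDIA tooling.

```python
import subprocess
from typing import List, NamedTuple, Optional


class GpuProcess(NamedTuple):
    pid: int
    name: str
    used_mib: Optional[int]  # None when nvidia-smi reports "[N/A]"


def parse_compute_apps(csv_text: str) -> List[GpuProcess]:
    """Parse the output of
    `nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv,noheader,nounits`.
    """
    procs = []
    for line in csv_text.splitlines():
        if not line.strip():
            continue  # skip blank lines
        # Split off the first and last fields so a comma inside the process
        # name does not break parsing.
        pid, rest = line.split(",", 1)
        name, used = rest.rsplit(",", 1)
        used = used.strip()
        procs.append(GpuProcess(int(pid), name.strip(), int(used) if used.isdigit() else None))
    return procs


def list_gpu_processes() -> List[GpuProcess]:
    """Ask nvidia-smi which compute processes currently hold GPU memory."""
    result = subprocess.run(
        [
            "nvidia-smi",
            "--query-compute-apps=pid,process_name,used_memory",
            "--format=csv,noheader,nounits",
        ],
        check=True,
        capture_output=True,
        text=True,
    )
    return parse_compute_apps(result.stdout)
```

Printing `list_gpu_processes()` right before launching the script would surface a stale interpreter like the one above.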
Thanks for the results. I think we need a way to give a better error to end users.
Recently, a few error messages got a little better. Closing, as I'm not sure what more to do. But if the issue appears again and the error isn't good enough, poke us again.
I also got this error, and it was due to the GPU reaching its memory limit.
@amacrutherford Do you have the full error message you had? I would like to improve the error message in that case.
Same error:
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
<ipython-input-25-5a28263ee724> in <cell line: 5>()
3
4 # Initialize model weights using dummy tensors.
----> 5 rng = jax.random.PRNGKey(0)
6 rng, key = jax.random.split(rng)
7 init_img = jnp.ones((4, 224, 224, 5), jnp.float32)
22 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in backend_compile(backend, built_c, options, host_callbacks)
469 # TODO(sharadmv): remove this fallback when all backends allow `compile`
470 # to take in `host_callbacks`
--> 471 return backend.compile(built_c, compile_options=options)
472
473 _ir_dump_counter = itertools.count()
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
I think the memory is also reaching the limit here:
Fri May 12 01:50:33 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 57C P0 27W / 70W | 15101MiB / 15360MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
+-----------------------------------------------------------------------------+
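The Tesla T4 in this table reports 15101MiB / 15360MiB in use (about 98%), with no process listed. A sketch that flags nearly full GPUs through `nvidia-smi`'s CSV query mode, assuming `nvidia-smi` is on `PATH`; the helper names and the 0.9 threshold are hypothetical. The number alone does not say who holds the memory; in a notebook it may be an earlier JAX session in the same runtime.

```python
import subprocess
from typing import List, NamedTuple


class GpuMemory(NamedTuple):
    index: int
    used_mib: int
    total_mib: int

    @property
    def used_fraction(self) -> float:
        return self.used_mib / self.total_mib


def parse_gpu_memory(csv_text: str) -> List[GpuMemory]:
    """Parse `nvidia-smi --query-gpu=index,memory.used,memory.total
    --format=csv,noheader,nounits` output (one line per GPU)."""
    rows = []
    for line in csv_text.splitlines():
        if line.strip():
            # int() tolerates the spaces nvidia-smi puts after each comma.
            index, used, total = (int(field) for field in line.split(","))
            rows.append(GpuMemory(index, used, total))
    return rows


def nearly_full(rows: List[GpuMemory], threshold: float = 0.9) -> List[int]:
    """Indices of GPUs whose used memory is at or above `threshold`."""
    return [row.index for row in rows if row.used_fraction >= threshold]


def query_gpu_memory() -> List[GpuMemory]:
    """Run nvidia-smi and return per-GPU memory usage."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=index,memory.used,memory.total",
         "--format=csv,noheader,nounits"],
        check=True, capture_output=True, text=True,
    )
    return parse_gpu_memory(result.stdout)
```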
Yep, I received the same error message as @Bailey-24.
Thanks for the extended error message. But can you share the full output without any truncation? There should be information in it that will help me. Which JAX version did you use?
I'm having the same problem, but for me it's consistent and I'm unable to run even simple JAX code. I only have this problem on my newest system with 4x RTX 4090 GPUs. I have an A100 server and a PC with a 3090 Ti that work smoothly. All systems run Ubuntu 22. I first installed CUDA 11 from conda-forge as suggested: same issue. Then I switched to a local installation of CUDA and cuDNN: same problem.
After a fresh installation of everything, when I run `a = jnp.ones((3,))` I get this error:
2023-05-13 09:04:27.790057: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-05-13 09:04:27.790140: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 5853872128 bytes free, 25393692672 bytes total.
Traceback (most recent call last):
File "
I've also tried swapping my PC's 3090 Ti (which works properly) with one of the 4090s, and I got the exact same error. This used to work in the past. I'm pretty sure the GPU hardware is not the problem; the cards are functional (I tested them on Windows machines).
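The `Memory usage: ... bytes free, ... bytes total` line that XLA prints next to the cuDNN error is worth reading. Here it says 5853872128 of 25393692672 bytes are free (about 23%). Since JAX preallocates 75% of GPU memory by default (per the GPU memory allocation page linked elsewhere in this thread), a figure this low may simply reflect JAX's own preallocation rather than another process. A sketch of a parser for that line; the regular expression assumes the exact wording shown in these logs, and the helper names are hypothetical.

```python
import re
from typing import Optional, Tuple

# Assumes the wording of the cuda_dnn.cc log lines quoted in this thread.
_MEMORY_RE = re.compile(r"Memory usage: (\d+) bytes free, (\d+) bytes total")


def parse_cudnn_memory_line(line: str) -> Optional[Tuple[int, int]]:
    """Return (free_bytes, total_bytes), or None if the line doesn't match."""
    match = _MEMORY_RE.search(line)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def describe(line: str) -> str:
    """Human-readable summary of a cuDNN memory report line."""
    parsed = parse_cudnn_memory_line(line)
    if parsed is None:
        return "no memory report found"
    free, total = parsed
    return f"{free / 2**30:.2f} GiB free of {total / 2**30:.2f} GiB ({free / total:.0%})"
```

Applied to the log line above, `describe` returns `"5.45 GiB free of 23.65 GiB (23%)"`.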
Did you cut the output? The error tells you to look above for more errors. If there is more output, give me everything you have. I'll filter what is useful.
I'm getting the same kind of error trying to install jax/jaxlib on an EC2 p2.xlarge (with K80s), to provide solidarity! I can provide more details if useful, but basically I run a vanilla Anaconda installation script and try different variants of `pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`. This leads JAX to report seeing the GPU when I check `print(xla_bridge.get_backend().platform)`, but otherwise it gives the DNN error above.
(I'm also unable to run any JAX code, e.g. `a = jnp.ones((3,))`.)
@ampolloreno Please open new issues rather than appending onto closed ones.
However, I think the problem in your case is simple: JAX no longer supports Kepler GPUs in the wheels we release. You can probably rebuild jaxlib from source if you need Kepler support, but note NVIDIA has dropped Kepler support from CUDA 12 and CUDNN 8.9, so this may not remain true for long.
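Whether a GPU is too old for the prebuilt wheels comes down to its CUDA compute capability: the K80 is 3.7 (Kepler), the V100 is 7.0. A sketch of that check; the 5.2 minimum is an assumption taken from recent JAX installation docs (Maxwell or newer), so confirm it for the release you install. On recent drivers, `nvidia-smi --query-gpu=compute_cap --format=csv,noheader` reports the value; older drivers may not support that field. The helper names are hypothetical.

```python
from typing import Tuple

# Assumption: minimum compute capability targeted by JAX's prebuilt CUDA wheels.
# Check the JAX installation guide for the release you actually install.
MIN_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 2)


def parse_compute_capability(value: str) -> Tuple[int, int]:
    """Turn a string such as "3.7" into the tuple (3, 7)."""
    major, minor = value.strip().split(".")
    return int(major), int(minor)


def prebuilt_wheels_support(value: str,
                            minimum: Tuple[int, int] = MIN_COMPUTE_CAPABILITY) -> bool:
    """Tuple comparison orders (major, minor) pairs correctly."""
    return parse_compute_capability(value) >= minimum
```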
@nouiz No, this is all the output. It's several lines of errors; are you seeing all of it?
I managed to resolve this, though. I installed CUDA 11 and cuDNN 8.6. In my experiments I also installed the latest version of everything, but this was the only version combination that worked for me. Now I'm getting other errors, but that's a different problem.
@hawkinsp Point taken and... Thanks for the help! I just switched over to V100s and voila!
I got the same error; maybe it is due to a mismatch between your CUDA version and the installed JAX. I use Ubuntu 20.04. At first, I installed the newest JAX with
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Then I got the error as reported. So I switched to
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Everything works well!
Thanks @hosseybposh; for a simple use case, I was able to use JAX 0.4.13 and CUDA 11.8 with cuDNN 8.6. I needed to add `/usr/lib/x86_64-linux-gnu` to `LD_LIBRARY_PATH` (I installed `libcudnn8` with `apt-get`).
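Because the dynamic loader generally reads `LD_LIBRARY_PATH` when a process starts, changing it from inside an already-running Python interpreter usually won't help JAX find `libcudnn`; it has to be set in the environment of the process that imports JAX. A sketch that builds such an environment for a child process; `with_library_dir` is a hypothetical helper, and the directory is the one from the comment above.

```python
import os
from typing import Dict, Mapping, Optional


def with_library_dir(libdir: str,
                     env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of `env` (default: os.environ) with `libdir` first on
    LD_LIBRARY_PATH, dropping any duplicate entry; the input is not modified."""
    new_env = dict(os.environ if env is None else env)
    entries = [p for p in new_env.get("LD_LIBRARY_PATH", "").split(":")
               if p and p != libdir]
    new_env["LD_LIBRARY_PATH"] = ":".join([libdir] + entries)
    return new_env
```

For example, `subprocess.run([sys.executable, "train.py"], env=with_library_dir("/usr/lib/x86_64-linux-gnu"), check=True)`, where `train.py` stands in for your own script. In a shell, exporting the variable before starting Python has the same effect.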
XlaRuntimeError Traceback (most recent call last)
Cell In[4], line 29
26 model = trainer.make_model(nmask)
28 lr_fn, opt = trainer.make_optimizer(steps_per_epoch=len(train_dl))
---> 29 state = trainer.create_train_state(jax.random.PRNGKey(0), model, opt)
30 state = checkpoints.restore_checkpoint(ckpt.parent, state)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/random.py:137, in PRNGKey(seed)
134 if np.ndim(seed):
135   raise TypeError("PRNGKey accepts a scalar seed, but was given an array of"
136                   f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 137 key = prng.seed_with_impl(impl, seed)
138 return _return_prng_keys(True, key)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:320, in seed_with_impl(impl, seed)
319 def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
--> 320   return random_seed(seed, impl=impl)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:734, in random_seed(seeds, impl)
732 else:
733   seeds_arr = jnp.asarray(seeds)
--> 734 return random_seed_p.bind(seeds_arr, impl=impl)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
377 def bind(self, *args, **params):
378   assert (not config.jax_enable_checks or
379           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380   return self.bind_with_trace(find_top_trace(args), args, params)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
382 def bind_with_trace(self, trace, args, params):
--> 383   out = trace.process_primitive(self, map(trace.full_raise, args), params)
384   return map(full_lower, out) if self.multiple_results else full_lower(out)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:790, in EvalTrace.process_primitive(self, primitive, tracers, params)
789 def process_primitive(self, primitive, tracers, params):
--> 790   return primitive.impl(*tracers, **params)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:746, in random_seed_impl(seeds, impl)
744 @random_seed_p.def_impl
745 def random_seed_impl(seeds, *, impl):
--> 746   base_arr = random_seed_impl_base(seeds, impl=impl)
747   return PRNGKeyArrayImpl(impl, base_arr)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:751, in random_seed_impl_base(seeds, impl)
749 def random_seed_impl_base(seeds, *, impl):
750   seed = iterated_vmap_unary(seeds.ndim, impl.seed)
--> 751   return seed(seeds)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:980, in threefry_seed(seed)
968 def threefry_seed(seed: typing.Array) -> typing.Array:
969   """Create a single raw threefry PRNG key from an integer seed.
970
971   Args:
(...)
978     first padding out with zeros).
979   """
--> 980   return _threefry_seed(seed)

[... skipping hidden 12 frame]

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/dispatch.py:463, in backend_compile(backend, module, options, host_callbacks)
458   return backend.compile(built_c, compile_options=options,
459                          host_callbacks=host_callbacks)
460 # Some backends don't have `host_callbacks` option yet
461 # TODO(sharadmv): remove this fallback when all backends allow `compile`
462 # to take in `host_callbacks`
--> 463 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
jax 0.4.10, jaxlib 0.4.10+cuda11.cudnn86
GPU
Python 3.11.4, Ubuntu 22.04
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:18:00.0 Off | N/A |
| 30% 33C P8 22W / 350W | 8688MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce ... On | 00000000:3B:00.0 Off | N/A |
| 30% 31C P8 14W / 350W | 8MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 2 NVIDIA GeForce ... On | 00000000:86:00.0 Off | N/A |
| 30% 34C P8 24W / 350W | 8MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 3 NVIDIA GeForce ... On | 00000000:AF:00.0 Off | N/A |
| 30% 30C P8 8W / 350W | 8MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 2861 C+G ...ome-remote-desktop-daemon 249MiB |
| 0 N/A N/A 3213874 C ...ransformer_117/bin/python 8430MiB |
| 1 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
| 3 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
+-----------------------------------------------------------------------------+
@TInaWangxue's problem was resolved in https://github.com/google/jax/issues/16901.
Hi, I encountered the same problem. When I use an A100 to run a single task, it runs normally, but when I submit two tasks at the same time, the above error is reported. Is the cause that the A100 runs two tasks at the same time? Will there be a conflict?
I suppose 2 tasks means 2 processes. If not, tell us. By default, JAX will reserve 75% of the GPU memory for the process: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
So the 2nd process will end up short of GPU memory most of the time. Read that web page to learn how to control that 75% memory allocation. If you can lower it to 45% and the first process has enough memory, it will probably work. Otherwise, try a few other values.
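The allocation page linked above documents two environment variables for this: `XLA_PYTHON_CLIENT_MEM_FRACTION` (fraction of GPU memory to preallocate) and `XLA_PYTHON_CLIENT_PREALLOCATE` (set to `false` to allocate on demand). A minimal sketch for two processes sharing one A100, assuming each process calls it before importing JAX; `configure_jax_gpu_memory` is a hypothetical helper and 0.45 is just the starting value suggested above.

```python
import os
import sys
import warnings
from typing import Optional


def configure_jax_gpu_memory(fraction: Optional[float] = None,
                             preallocate: bool = True) -> None:
    """Set JAX's GPU allocator environment variables.

    They must be in place before JAX initializes its GPU backend; calling
    this before `import jax` is the simplest way to be sure.
    """
    if "jax" in sys.modules:
        warnings.warn("jax is already imported; if its GPU backend has been "
                      "initialized, these settings will not take effect")
    if fraction is not None:
        if not 0.0 < fraction <= 1.0:
            raise ValueError(f"fraction must be in (0, 1], got {fraction}")
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)
    if not preallocate:
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


# Usage, at the very top of each of the two scripts:
#   configure_jax_gpu_memory(fraction=0.45)
#   import jax
```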
Hi, I have a similar issue. Please help me!
The output:
2023-09-05 14:32:56.559501: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-09-05 14:32:56.559528: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 6081413120 bytes free, 25438126080 bytes total.
Traceback (most recent call last):
  File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 453, in <module>
    app.run(main)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 428, in main
    predict_structure(
  File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 214, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/wangyun/pre/alphafold-multimer-main/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/random.py", line 137, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 320, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 732, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 744, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 749, in random_seed_impl_base
    return seed(seeds)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 978, in threefry_seed
    return _threefry_seed(seed)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 208, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 155, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 2633, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 1085, in _pjit_call_impl
    compiled = _pjit_lower(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/dispatch.py", line 494, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/dispatch.py", line 462, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
What jax/jaxlib version are you using? jax 0.4.9, jaxlib 0.4.9+cuda11.cudnn86
My conda virtual environment: python 3.9.0; cudatoolkit 11.8.0 h4ba93d1_12 conda-forge; cudnn 8.6.0.163 hed8a83a_0 cudistas
But my OS environment: NVIDIA-SMI 535.54.03, Driver Version: 535.54.03, CUDA Version: 12.2; ii cudnn-local-repo-ubuntu1804-8.9.3.28 1.0-1 amd64 cudnn-local repository configuration files
I've tried everything I can, but ……
Then it worked. You can look at this link: https://blog.csdn.net/2201_75882736/article/details/132812927
Hi, I also have the same issue. Could anyone please help me?
E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
Traceback (most recent call last):
File "/bd_byt4090i1/users/state_space_model/DNN/S5/run_train.py", line 101, in <module>
The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/bd_byt4090i1/users/state_space_model/DNN/S5/run_train.py", line 101, in <module>
My jax version: jax==0.4.13 jaxlib==0.4.13+cuda11.cudnn86 flax==0.7.4 chex==0.1.8
My GPU information:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.223.02 Driver Version: 470.223.02 CUDA Version: 11.4 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:83:00.0 Off | N/A |
| 21% 35C P0 40W / 215W | 0MiB / 7982MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
It should be an RTX 2070. Thanks a lot for the help!
You are using an old JAX version (0.4.13) and an old driver that only supports up to CUDA 11.4. Can you update your NVIDIA driver to one that supports at least CUDA 11.8, as this is the minimum version currently supported by JAX? JAX is dropping CUDA 11 in upcoming releases, so if you can update to CUDA 12, that would be better.
Thanks a lot. I tried that, and it worked!
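The `CUDA Version` in the `nvidia-smi` header is the newest CUDA version the installed driver supports, which can differ from the toolkit you installed; here driver 470.223.02 tops out at 11.4, below the 11.8 minimum mentioned above. A rough sketch of that comparison; the helper names are hypothetical, and the required version is a parameter because the minimum changes between JAX releases.

```python
import re
from typing import Optional, Tuple

_CUDA_VERSION_RE = re.compile(r"CUDA Version:\s*(\d+)\.(\d+)")


def driver_cuda_version(nvidia_smi_output: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from the header printed by plain `nvidia-smi`."""
    match = _CUDA_VERSION_RE.search(nvidia_smi_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def driver_supports(nvidia_smi_output: str, required: Tuple[int, int]) -> bool:
    """Coarse check: is the driver's maximum CUDA version >= `required`?"""
    found = driver_cuda_version(nvidia_smi_output)
    return found is not None and found >= required
```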
Hi, guys. I'm here on a recommendation. I'm facing a similar issue: "Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed." I tried almost everything suggested on these pages to resolve the GPU memory problem: https://github.com/YoshitakaMo/localcolabfold/issues/210 https://github.com/YoshitakaMo/localcolabfold/issues/224 https://github.com/YoshitakaMo/localcolabfold/issues/228
My current jax, cuDNN, and nvidia-smi versions are as follows (Linux Ubuntu 22.04.2 LTS and RTX 4090):
$nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0
$python3.10 -m pip list | grep jax
jax 0.4.23
jax-cuda12-pjrt 0.4.23
jax-cuda12-plugin 0.4.23
jaxlib 0.4.23+cuda12.cudnn89
$python3.10 -m pip list | grep cudnn
jaxlib 0.4.23+cuda12.cudnn89
nvidia-cudnn-cu12 9.1.0.70
$nvidia-smi
Fri May 10 07:20:32 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14 Driver Version: 550.54.14 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
Here's the problem I'm encountering:
$colabfold_batch --templates --amber test_A.fasta ./
2024-05-13 06:17:37,861 Running colabfold 1.5.5 (57b220e028610ba7331ebe1ef9c2d0419992469a)
2024-05-13 06:17:38,189 Running on GPU
2024-05-13 06:17:38,738 Found 9 citations for tools or databases
2024-05-13 06:17:38,738 Query 1/1: pdb_A (length 108)
2024-05-13 06:17:41,729 Sequence 0 found templates: ['1m4u_L', '2r52_B', '6oml_Y', '5vt2_B', '2r53_A', '1lxi_A', '4n1d_A', '7zjf_B', '7zjf_A', '6z3g_A', '3qb4_C', '3qb4_A', '6z3j_A', '2h64_A', '4uhy_A', '1reu_A', '4ui0_A', '2h62_B', '4mid_A', '3bk3_B']
2024-05-13 06:17:42,533 Setting max_seq=512, max_extra_seq=5120
2024-05-13 06:17:42,674 Could not predict pdb_A. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
2024-05-13 06:17:42,674 Done
I can't find "the errors above."
Can someone offer any guesses or solutions?
JAX releases don't support cuDNN 9 yet. Downgrade to cuDNN 8.9 (or build jaxlib from source with cuDNN 9; that also works).
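For wheels of this era, the jaxlib version string itself records what it was built against: `0.4.23+cuda12.cudnn89` means CUDA 12 and cuDNN 8.9, while the environment above has `nvidia-cudnn-cu12 9.1.0.70`. A sketch that compares the two, assuming the `cudnnXY` suffix encodes a one-digit major version followed by the minor version, and applying the heuristic that the loaded cuDNN should have the same major version and an equal or newer minor version. Newer plugin-based jaxlib releases may not carry this suffix, in which case the parser returns `None`. Helper names are hypothetical.

```python
import re
from typing import Optional, Tuple

# Assumes local version labels like "+cuda12.cudnn89" (CUDA 12, cuDNN 8.9).
_BUILD_RE = re.compile(r"\+cuda(\d+)\.cudnn(\d)(\d+)$")


def jaxlib_build(version: str) -> Optional[Tuple[int, Tuple[int, int]]]:
    """'0.4.23+cuda12.cudnn89' -> (12, (8, 9)); None if there is no such suffix."""
    match = _BUILD_RE.search(version)
    if match is None:
        return None
    return int(match.group(1)), (int(match.group(2)), int(match.group(3)))


def cudnn_looks_compatible(jaxlib_version: str, installed_cudnn: str) -> bool:
    """Heuristic: same cuDNN major version and an equal or newer minor version."""
    build = jaxlib_build(jaxlib_version)
    if build is None:
        raise ValueError(f"no +cudaN.cudnnXY suffix in {jaxlib_version!r}")
    built_major, built_minor = build[1]
    major, minor = (int(part) for part in installed_cudnn.split(".")[:2])
    return major == built_major and minor >= built_minor
```

The installed versions can be read with `importlib.metadata.version("jaxlib")` and `importlib.metadata.version("nvidia-cudnn-cu12")`, using the package names from the `pip list` output above.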
In reply to the `jax[cuda11_pip]` suggestion above: jaxlib has no CUDA 11 build.
Replying to the same `jax[cuda11_pip]` suggestion: Great!
I'm getting
jaxlib.xla_extension.XlaRuntimeError: FAILED PRECONDITION: DNN library initialization failed.
Are there any ideas about how to resolve this? The JAX version is 0.4.26, the CUDA version is 12.4, and the cuDNN version is 8.9.7.29, which should be compatible.
Description
I have a Python virtual environment with a clean installation of JAX.
When I run my scripts, they usually work perfectly, but sometimes I get the following error: roughly 2 to 10 successful executions for every 1 to 3 failed ones.
What jax/jaxlib version are you using?
jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.8.10, Ubuntu 20.04
NVIDIA GPU info
CUDNN version (`/usr/local/cuda/include/cudnn_version.h`)