google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed #22626

Closed NonsansWD closed 1 month ago

NonsansWD commented 1 month ago

Description

Hello, I'm having an issue using JAX and CUDA. My error message looks like this:

2024-07-19 15:49:01.341533: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-07-19 15:49:01.341571: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 2624061440 bytes free, 12524191744 bytes total.
Traceback (most recent call last):
  File "scripts/train_vqgan.py", line 202, in <module>
    main()
  File "scripts/train_vqgan.py", line 32, in main
    rng = jax.random.PRNGKey(config.seed)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
    compiled = _pjit_lower(
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "scripts/train_vqgan.py", line 202, in <module>
    main()
  File "scripts/train_vqgan.py", line 32, in main
    rng = jax.random.PRNGKey(config.seed)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/home/nonsans/miniforge3/envs/viper/lib/python3.8/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
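
For what it's worth, the failure does not depend on anything in my training code; a minimal script that hits the same error on my setup (assuming only that jax is installed with the CUDA wheels) is:

```python
import jax

# Creating a PRNG key already goes through compilation on the GPU
# backend, which is where the DNN library initialization fails for me.
key = jax.random.PRNGKey(0)
print(key)
```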

I already saw an issue about this which is now closed, but after trying everything mentioned there I still could not get it to run. I have already tried the following:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --force-reinstall "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
XLA_PYTHON_CLIENT_PREALLOCATE=false
XLA_PYTHON_CLIENT_MEM_FRACTION=.05
XLA_PYTHON_CLIENT_ALLOCATOR=platform

None of those changed the error message; I always got the error above no matter what I did. For system information and further setup details, see below. I hope someone can help me, because this is really important to me: I need JAX for an important project and would appreciate any help I can get. The hardware information below is the setup I have sudo access to, and since I get the same error there, I was hoping to fix it there first and then see what I can do, because the actual goal is to run this on a server where I don't have sudo access, so things like reinstalling the CUDA stack won't be possible.
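
For reference, setting those variables inside the script instead of exporting them in the shell would look roughly like this (they need to be set before jax initializes the GPU backend, so they go at the very top):

```python
import os

# These must be set before jax initializes the GPU backend.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".05"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
print(jax.devices())  # should list the GPU if the backend comes up
```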

System info (python version, jaxlib version, accelerator, etc.)

Hardware: GPU: RTX 3060 12 GB; CUDA version: 12.5; RAM: 16 GB

Python version: 3.8; JAX version: 0.4.13

jax.print_environment_info()
jax:    0.4.13
jaxlib: 0.4.13
numpy:  1.24.4
python: 3.8.19 | packaged by conda-forge | (default, Mar 20 2024, 12:47:35) [GCC 12.3.0]
jax.devices (1 total, 1 local): [gpu(id=0)]
process_count: 1

$ nvidia-smi
Fri Jul 19 16:06:06 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.58.02              Driver Version: 555.58.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3060        Off |   00000000:10:00.0  On |                  N/A |
|  0%   52C    P0             35W / 170W  |    461MiB / 12288MiB   |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory  |
|        ID   ID                                                              Usage       |
|=========================================================================================|
|    0   N/A  N/A      1421    G   /usr/lib/Xorg                                  142MiB  |
|    0   N/A  N/A      1581    G   /usr/bin/gnome-shell                            47MiB  |
|    0   N/A  N/A      2400    G   ...seed-version=20240715-180505.020000          37MiB  |
|    0   N/A  N/A      2688    G   ...yOnDemand --variations-seed-version          36MiB  |
|    0   N/A  N/A      4159    G   /usr/bin/kgx                                     9MiB  |
|    0   N/A  N/A      4672    G   /usr/bin/gnome-system-monitor                   12MiB  |
|    0   N/A  N/A      9935    G   /usr/bin/nautilus                               12MiB  |
|    0   N/A  N/A     10293    G   /usr/bin/gnome-text-editor                       7MiB  |
|    0   N/A  N/A     19157    C   python                                         104MiB  |
+-----------------------------------------------------------------------------------------+

jakevdp commented 1 month ago

This error can come up when JAX is used with an incompatible cuDNN version. You report using jax 0.4.13 with CUDA 12.5 – that's a fairly old JAX version paired with a much more recent CUDA version, so I wouldn't be surprised if there are some incompatibilities. We've certainly never tested jax 0.4.13 with CUDA 12.5... I'd suggest updating to a more recent JAX version and seeing if that helps.
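
A quick way to check whether an upgrade fixed things is a tiny script that forces a compilation on the GPU backend, for example:

```python
import jax
import jax.numpy as jnp

print(jax.__version__)   # should be a recent release after upgrading
print(jax.devices())     # the GPU should show up here

# Any jitted computation exercises the same compile path that failed above.
x = jnp.ones((1024, 1024))
print(jnp.dot(x, x).sum())
```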

NonsansWD commented 1 month ago

Thank you very much for pointing that out. I was sure I had tried adjusting the versions before, but I probably messed up by installing a jax CUDA build that didn't match the jax version I was testing. I tried it again, was extra careful to keep the versions in sync, and now it works completely fine. jax 0.4.13 is not compatible with the cuDNN 9.x major release. I hope anyone who runs into the same issue sees this and can get it working too. I will close this now, as it is solved.
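
In case it helps anyone debugging the same mismatch, this is roughly how I checked which versions ended up installed (assuming the CUDA/cuDNN libraries come from the pip wheels; the package is nvidia-cudnn-cu12 for CUDA 12 builds and nvidia-cudnn-cu11 for CUDA 11 builds):

```python
import importlib.metadata as metadata

# jax and jaxlib must match, and the cuDNN wheel has to be a major
# version the jax release was built against (8.x for jax 0.4.13, not 9.x).
for pkg in ("jax", "jaxlib", "nvidia-cudnn-cu12", "nvidia-cudnn-cu11"):
    try:
        print(pkg, metadata.version(pkg))
    except metadata.PackageNotFoundError:
        print(pkg, "not installed via pip")
```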