google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX gpu working only with torch import #16747

Closed gpadiolleau closed 1 year ago

gpadiolleau commented 1 year ago

Description

I have a conda environment named PyTorchGPU in which I first installed PyTorch with GPU support and then installed JAX and jaxlib with NVIDIA GPU support. When I run JAX alone, cuDNN fails to initialize on the GPU, but if I import torch first, it magically works!

Running JAX only

When I run a simple test on the GPU in the terminal, (PyTorchGPU) ubuntu@ubuntu-GL62MVR-7RFX:~$ JAX_PLATFORM_NAME=gpu python -c "from jax.random import PRNGKey; PRNGKey(10)", I get the following error:

2023-07-15 11:51:17.450237: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-07-15 11:51:17.450316: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 1513750528 bytes free, 6367477760 bytes total.
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
    compiled = _pjit_lower(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

However, running the following test, (PyTorchGPU) ubuntu@ubuntu-GL62MVR-7RFX:~$ JAX_PLATFORM_NAME=gpu python -c "import jax;print(f'JAX available devices :',jax.devices())", gives the following result, confirming that JAX detects my GPU:

JAX available devices : [gpu(id=0)]
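In the report above, jax.devices() succeeds while the first compiled computation fails during DNN library initialization, so listing devices alone is not a sufficient GPU check. A minimal sketch of a check that does both; smoke_test is a hypothetical helper name, and it treats any exception from the computation as a failure (the report above shows jaxlib.xla_extension.XlaRuntimeError):

```python
import jax


def smoke_test(seed: int = 10) -> dict:
    """Hypothetical helper: list devices, then run one compiled computation.

    In the report above, listing devices worked, but compiling PRNGKey failed
    while XLA was initializing cuDNN.
    """
    report = {
        "backend": jax.default_backend(),           # e.g. "gpu" or "cpu"
        "devices": [str(d) for d in jax.devices()],
    }
    try:
        key = jax.random.PRNGKey(seed)              # triggers an XLA compilation
        report["key"] = jax.device_get(key).tolist()
        report["ok"] = True
    except Exception as err:  # the report above saw jaxlib.xla_extension.XlaRuntimeError
        report["ok"] = False
        report["error"] = str(err).splitlines()[0]
    return report


if __name__ == "__main__":
    print(smoke_test())
```

Run with JAX_PLATFORM_NAME=gpu, a working setup would be expected to report ok=True and the same [0, 10] key shown later in this thread.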

First import torch and run JAX

Running the following test on the GPU in the terminal, (PyTorchGPU) ubuntu@ubuntu-GL62MVR-7RFX:~$ JAX_PLATFORM_NAME=gpu python -c "import torch;import jax;print(jax.random.PRNGKey(10))", gives the expected result:

[ 0 10]

It is not related to import order, since running JAX_PLATFORM_NAME=gpu python -c "import jax;import torch;print(jax.random.PRNGKey(10))" gives the same result.

Any ideas?

It really seems to be a problem with cuDNN somehow being "activated" through torch. Maybe because torch comes with its own cuDNN installation? Here is my conda list cuda output:

# packages in environment at /home/ubuntu/anaconda3/envs/PyTorchGPU:
#
# Name                    Version                   Build  Channel
cuda-cudart               11.8.89                       0    nvidia
cuda-cupti                11.8.87                       0    nvidia
cuda-libraries            11.8.0                        0    nvidia
cuda-nvrtc                11.8.89                       0    nvidia
cuda-nvtx                 11.8.86                       0    nvidia
cuda-runtime              11.8.0                        0    nvidia
cudatoolkit               11.8.0               h6a678d5_0  
cudnn                     8.9.2.26               cuda11_0  
pytorch-cuda              11.8                 h7e8668a_5    pytorch

Does pytorch-cuda overwrite my local cudatoolkit installation? And do I then need to import torch first so that JAX uses pytorch-cuda as its backend?

Thanks for any help, I need to understand this! :)
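One way to look into whether two cuDNN builds are involved is to ask the process itself which libcudnn files it has mapped. A minimal Linux-only sketch; loaded_libraries is a hypothetical helper that reads /proc/self/maps and returns an empty list on systems without that file:

```python
import os


def loaded_libraries(pattern: str = "cudnn") -> list:
    """Return sorted, de-duplicated paths of files mapped into this process
    whose file name contains `pattern` (Linux only: reads /proc/self/maps)."""
    maps = "/proc/self/maps"
    if not os.path.exists(maps):
        return []
    paths = set()
    with open(maps) as f:
        for line in f:
            parts = line.split(maxsplit=5)
            # The sixth field, when present, is the path of the mapped file.
            if len(parts) == 6:
                path = parts[5].strip()
                if pattern in os.path.basename(path):
                    paths.add(path)
    return sorted(paths)


if __name__ == "__main__":
    print(loaded_libraries("cudnn"))
```

For example, calling it after import torch, and separately after a JAX computation, would show which libcudnn paths end up loaded and from which directories.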

What jax/jaxlib version are you using?

jax 0.4.13 / jaxlib 0.4.13+cuda11.cudnn86

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.8.17, Ubuntu 22.04.2 LTS

NVIDIA GPU info


hawkinsp commented 1 year ago

I'm speculating that you have two versions of cuDNN installed. If you import torch first, you're probably getting the version that torch uses, and if you import jax first, you're getting jax's version. This might cause problems if the jax version is newer and uses more GPU memory, so that you run out of memory when you use it.

Can you try lowering XLA_PYTHON_CLIENT_MEM_FRACTION? https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

It defaults to 0.75; try 0.7 or lower.
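The variable can also be set from Python. It should be set before JAX creates its GPU backend, and doing it before import jax is the simplest way to guarantee that. A minimal sketch using the 0.7 value suggested here (the value is only an example):

```python
import os

# Example value only; it must be set before JAX initializes the GPU backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.70"
# The linked page also documents disabling preallocation altogether:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # noqa: E402  (imported only after the variable is set)

print(jax.devices())
```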

gpadiolleau commented 1 year ago

Hi, thanks for your help !

Yes, I have two CUDA installations in my environment. If I run conda list cuda and conda list cudnn, I get the following:

# packages in environment at /home/ubuntu/anaconda3/envs/PyTorchGPU:
#
# Name                    Version                   Build  Channel
cuda-cudart               11.8.89                       0    nvidia
cuda-cupti                11.8.87                       0    nvidia
cuda-libraries            11.8.0                        0    nvidia
cuda-nvrtc                11.8.89                       0    nvidia
cuda-nvtx                 11.8.86                       0    nvidia
cuda-runtime              11.8.0                        0    nvidia
cudatoolkit               11.8.0               h6a678d5_0  
pytorch-cuda              11.8                 h7e8668a_5    pytorch
# packages in environment at /home/ubuntu/anaconda3/envs/PyTorchGPU:
#
# Name                    Version                   Build  Channel
cudnn                     8.9.2.26               cuda11_0

But, as you can see, torch comes with its own CUDA installation, pytorch-cuda. However, both are CUDA 11.8 (and I can't find which cuDNN version torch uses).
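PyTorch exposes torch.version.cuda (the CUDA version it was built with) and torch.backends.cudnn.version() (the cuDNN version it uses, as an integer, or None when cuDNN is unavailable). A small sketch that collects these alongside the installed jaxlib version read from package metadata; gpu_stack_versions is a hypothetical name, and packages that are not installed are reported as None:

```python
from importlib import metadata


def gpu_stack_versions() -> dict:
    """Hypothetical helper: report jaxlib's wheel version and PyTorch's CUDA/cuDNN versions."""
    report = {}
    try:
        # Includes the local tag, e.g. "0.4.13+cuda11.cudnn86".
        report["jaxlib"] = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        report["jaxlib"] = None
    try:
        import torch
    except ImportError:
        report["torch_cuda"] = report["torch_cudnn"] = None
    else:
        report["torch_cuda"] = torch.version.cuda  # None for CPU-only builds
        report["torch_cudnn"] = torch.backends.cudnn.version()  # integer or None
    return report


if __name__ == "__main__":
    print(gpu_stack_versions())
```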

As I said in my first message, the problem does not seem to come from import order, since importing torch first or jax first does not change anything: it works as long as torch is imported before any JAX operation runs.

I do not think decreasing XLA_PYTHON_CLIENT_MEM_FRACTION will help, since the memory usage in the error message is below the total, but I may have misunderstood this information. I ran the same test with XLA_PYTHON_CLIENT_MEM_FRACTION=0.10:

JAX_PLATFORM_NAME=gpu XLA_PYTHON_CLIENT_MEM_FRACTION=0.10 python -c "from jax.random import PRNGKey; print(PRNGKey(0));"

but I got the same error:

2023-07-18 11:50:19.639897: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-07-18 11:50:19.639960: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 5653528576 bytes free, 6367477760 bytes total.
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
    compiled = _pjit_lower(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/random.py", line 160, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 690, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
    return seed(seeds)
  File "/home/ubuntu/anaconda3/envs/PyTorchGPU/lib/python3.8/site-packages/jax/_src/prng.py", line 936, in threefry_seed
    return _threefry_seed(seed)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Is it possible that the problem is caused by my jaxlib installation? I have jaxlib 0.4.13+cuda11.cudnn86, the latest version for CUDA 11, but my environment runs cuDNN 8.9 for CUDA 11. I chose it because, according to https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, jaxlib builds for cuDNN 8.9 are only available for CUDA 12, while torch is only available for CUDA 11 (hahahaha....)
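Assuming the local version tag of a jaxlib wheel encodes the CUDA major version and the cuDNN major.minor it was built against (cuda11.cudnn86 read as CUDA 11 and cuDNN 8.6), the tag can be compared with the cuDNN in the environment. A sketch with hypothetical helper names; it assumes a cudnn86 wheel accepts cuDNN 8.6 or newer, so check the jaxlib installation notes for the exact rule:

```python
import re
from typing import NamedTuple, Tuple


class WheelCudaTag(NamedTuple):
    cuda_major: int
    cudnn: Tuple[int, int]  # (major, minor) the wheel was built against


def parse_jaxlib_tag(version: str) -> WheelCudaTag:
    """Parse a local version tag such as '0.4.13+cuda11.cudnn86'.

    Assumes a one-digit cuDNN major version followed by the minor version.
    """
    m = re.search(r"\+cuda(\d+)\.cudnn(\d)(\d+)$", version)
    if m is None:
        raise ValueError(f"no +cudaX.cudnnXY tag in {version!r}")
    return WheelCudaTag(int(m.group(1)), (int(m.group(2)), int(m.group(3))))


def cudnn_at_least(installed: str, required: Tuple[int, int]) -> bool:
    """Compare an installed cuDNN version string (e.g. conda's '8.9.2.26') by major.minor."""
    major, minor = (int(p) for p in installed.split(".")[:2])
    return (major, minor) >= required


if __name__ == "__main__":
    tag = parse_jaxlib_tag("0.4.13+cuda11.cudnn86")
    print(tag, cudnn_at_least("8.9.2.26", tag.cudnn))
```

The installed version string can be obtained with importlib.metadata.version("jaxlib"). A passing version comparison does not by itself rule out the mixed-installation problem discussed in this thread.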

P.S.: Is it normal that decreasing XLA_PYTHON_CLIENT_MEM_FRACTION increases the number in the "Memory usage" line of the error message?
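The "Memory usage" line reports free bytes, so a smaller fraction leaving more memory free is what one would expect if JAX preallocates XLA_PYTHON_CLIENT_MEM_FRACTION times the total GPU memory, as the linked memory-allocation page describes. A quick check with the numbers from the two error messages, assuming the first run used the default 0.75:

```python
TOTAL = 6_367_477_760  # "bytes total" in both error messages

# (XLA_PYTHON_CLIENT_MEM_FRACTION, "bytes free" reported when cuDNN init failed)
RUNS = [(0.75, 1_513_750_528), (0.10, 5_653_528_576)]


def used_vs_preallocated(fraction: float, free: int, total: int = TOTAL):
    """Return (bytes in use, expected preallocation, difference)."""
    used = total - free
    prealloc = round(fraction * total)
    return used, prealloc, used - prealloc


for fraction, free in RUNS:
    used, prealloc, extra = used_vs_preallocated(fraction, free)
    print(f"fraction={fraction:.2f}  used={used}  prealloc={prealloc}  "
          f"extra={extra / 2**20:.1f} MiB")
```

In both runs the memory in use exceeds the expected preallocation by a similar amount (74.5 MiB and 73.6 MiB), which fits the free figure simply tracking the fraction.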

hawkinsp commented 1 year ago

I suspect your multiple CUDA installations are confusing something. Note that torch and jax actually have conflicting CUDA requirements.

I recommend using separate virtual environments for torch and jax, and installing jax using the cuda12_pip wheels per the README.

I'm not sure there's any action we can take here.
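With the cuda12_pip extra, the CUDA and cuDNN libraries come from pip wheels (for example nvidia-cudnn-cu12) rather than from conda or a system install. A small sketch to list those wheels in the active environment, which can help confirm what CUDA stack a fresh environment actually contains; nvidia_wheels is a hypothetical name:

```python
from importlib import metadata


def nvidia_wheels() -> dict:
    """Map each installed distribution whose name starts with 'nvidia-' to its version."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if name and name.lower().startswith("nvidia-"):
            found[name] = dist.version
    return dict(sorted(found.items()))


if __name__ == "__main__":
    for name, version in nvidia_wheels().items():
        print(f"{name}=={version}")
```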

gpadiolleau commented 1 year ago

I thought that installing jax and torch for the same CUDA and cuDNN version would work.

I created a new environment and installed jax with pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html.

I re-ran JAX_PLATFORM_NAME=gpu python -c "from jax.random import PRNGKey; print(PRNGKey(0));" and it works like a charm!

I'll try later to set up an environment with local CUDA and cuDNN for both torch and jax. Thanks for your help, and by the way, thanks for making DL libraries with NVIDIA GPU support simpler to install! I remember my long past battles against TensorFlow and NVIDIA drivers! haha

patrick-kidger commented 11 months ago

FWIW I've bumped into the same error, whilst needing both PyTorch and JAX in the same environment. (I'm using PyTorch just because JAX doesn't provide any dataloaders.)

The fix seems to be running the following before the rest of my program:

import jax
jax.random.PRNGKey(0)  # Initialise JAX's backend by performing a nontrivial JAX operation.
import torch  # Afterwards!

Notably this is not the same fix described above, so I wanted to share it here in case it helped anyone out.

hawkinsp commented 11 months ago

@patrick-kidger Yes, that's likely because JAX wants a newer CUDA/CUDNN release than PyTorch does, and by loading the newer one first, things probably work.

The safer option is to install a CPU-only pytorch, though.

There's not a ton we can do about this, other than perhaps adding some better errors if you get it wrong.
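If PyTorch is only needed for data loading, a CPU-only PyTorch can hand batches to JAX as NumPy arrays, so only JAX loads CUDA libraries. A sketch of a collate function in the style of the one in JAX's PyTorch data-loading tutorial; it is meant to be passed as DataLoader(..., collate_fn=numpy_collate), and only the NumPy part is shown here:

```python
import numpy as np


def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays instead of torch tensors."""
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        # A list of (x, y) samples becomes [stacked_xs, stacked_ys].
        return [numpy_collate(samples) for samples in zip(*batch)]
    return np.array(batch)


if __name__ == "__main__":
    xs, ys = numpy_collate([(np.zeros(3), 1), (np.ones(3), 0)])
    print(xs.shape, ys)
```

The resulting arrays can be passed directly to jax.numpy functions or jitted code.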

hawkinsp commented 11 months ago

I strongly suspect that https://github.com/google/jax/commit/9404518201c5ac8af6c85ecdf12a7cc34c102585 will help catch this case. We will now raise an exception if we detect that the versions of the CUDA libraries loaded are too old.