jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Latest JAX `0.4.24` does not detect GPUs #20133

Open epignatelli opened 7 months ago

epignatelli commented 7 months ago

Description

The latest JAX, 0.4.24, does not detect GPUs, with either a local CUDA installation or the pip-installed CUDA. The latest working version for me is 0.4.23.

Reproduce

pip install "jax[cuda12_pip]"==0.4.24 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

With python -c "import jax; jax.devices('gpu')", I get:

CUDA backend failed to initialize: Found cuBLAS version 120205, but JAX was built against version 120304, which is newer. The copy of cuBLAS that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "~/miniforge3/envs/navix/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 872, in devices
    return get_backend(backend).devices()
  File "~/miniforge3/envs/navix/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 806, in get_backend
    return _get_backend_uncached(platform)
  File "~/miniforge3/envs/navix/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 788, in _get_backend_uncached
    platform = canonicalize_platform(platform)
  File "~/miniforge3/envs/navix/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 614, in canonicalize_platform
    raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: cpu

System info (python version, jaxlib version, accelerator, etc.)

nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0

nvidia-smi

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+

jakevdp commented 7 months ago

The issue is in your error message:

Found cuBLAS version 120205, but JAX was built against version 120304, which is newer.

You'll need to install cuBLAS version 12.3 or later when using the pre-built wheels for jax v0.4.24, as those wheels are built against cuda 12.3.

Since you used jax[cuda12_pip], you should get compatible cuda automatically. My guess is that you have another CUDA installation on your system that is taking precedence. I would suggest figuring out where those other CUDA sources are installed, and either removing them, or ensuring that the pip-installed CUDA sources take precedence in your path.
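
The two numbers in the error message are easier to compare once unpacked. The sketch below assumes the usual cuBLAS packing of major * 10000 + minor * 100 + patch; the helper name `decode_cuda_lib_version` is made up for illustration and is not part of JAX.

```python
def decode_cuda_lib_version(version: int) -> tuple[int, int, int]:
    """Split a packed version such as 120205 into (major, minor, patch).

    Assumes the packing major * 10000 + minor * 100 + patch.
    """
    major, rest = divmod(version, 10000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


found = decode_cuda_lib_version(120205)  # the cuBLAS that was actually loaded
built = decode_cuda_lib_version(120304)  # the cuBLAS the wheel was built against

print("found:", found)
print("built against:", built)
print("new enough:", found >= built)  # tuples compare element by element
```

Under that assumption, the loaded cuBLAS is 12.2.5 and the wheel expects at least 12.3.4, which matches the CUDA 12.2 toolkit reported by nvcc above.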

epignatelli commented 7 months ago

Since you used jax[cuda12_pip], you should get compatible cuda automatically.

Yeah, that's why I thought the problem was not on my side somehow, thanks very much for the info.

I would expect that jax[cuda12_pip] prioritises the CUDA installed from pip, though, no?

Also, I didn't know which CUDA version v0.4.24 was targeting. Perhaps it would be useful to have the minimum CUDA version in https://github.com/google/jax/?tab=readme-ov-file#installation? It used to be there, IIRC.

jakevdp commented 7 months ago

jax[cuda12_pip] installs the correct CUDA sources in your Python site-packages.

However, if you have other CUDA installations on your system, and your system is set up to load those other sources, they may be loaded before the ones installed with pip. There is nothing that JAX or pip can do about this: it is a property of your system.

jax 0.4.23 targets CUDA 12.2, and jax 0.4.24 targets CUDA 12.3, if that helps.
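
To see which CUDA components pip actually placed in the environment, one option is to list the installed distributions whose names start with `nvidia-` (the pip CUDA wheels use names such as `nvidia-cublas-cu12`). This is only a sketch: `pip_nvidia_packages` is a hypothetical helper, and it reports what is installed, not what gets loaded at runtime.

```python
from importlib import metadata


def pip_nvidia_packages() -> dict[str, str]:
    """Map each installed distribution named 'nvidia-*' to its version."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        # Broken installs can lack a name; skip those.
        if name and name.lower().startswith("nvidia-"):
            found[name] = dist.version
    return found


for name, version in sorted(pip_nvidia_packages().items()):
    print(f"{name}=={version}")
```

An empty listing would suggest that no pip-provided CUDA is present in this environment at all.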

epignatelli commented 7 months ago

Okay, great! Thanks, Jake!

surak commented 7 months ago

jax[cuda12_pip] installs the correct CUDA sources in your Python site-packages. However, if you have other CUDA installations on your system, and your system is set up to load those other sources, they may be loaded before the ones installed with pip. There is nothing that JAX or pip can do about this: it is a property of your system. jax 0.4.23 targets CUDA 12.2, and jax 0.4.24 targets CUDA 12.3, if that helps.

That is it. Unlike most other things in Python venvs, CUDA seems to add itself at the end of the path. So if you have a supercomputer with modules, strangely, the system's CUDA will take precedence.

Other stuff installed with pip in a venv takes precedence over the versions from the modules, because the venv adds them at the beginning of the Python path, for example. Except for CUDA.

Unloading the CUDA module forces it to find the one from site-packages, and things work.

Go figure.

epignatelli commented 6 months ago

Is there a proper way to check the location of the set of CUDA libraries that jax is loading? TF_CPP_MIN_LOG_LEVEL=0 is not showing anything more than usual.
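
One generic, Linux-only way to answer this kind of question is to read `/proc/self/maps` after the backend has been initialised and list the CUDA shared objects mapped into the process. This is a sketch, not a JAX feature: the prefix list is illustrative rather than exhaustive, and `loaded_cuda_libraries` and `report` are hypothetical helper names.

```python
import os

# Illustrative prefixes of CUDA-related shared objects ("libcuda" also covers libcudart).
CUDA_LIB_PREFIXES = (
    "libcuda", "libcublas", "libcudnn", "libcufft", "libcusolver",
    "libcusparse", "libcupti", "libnccl", "libnvrtc", "libnvJitLink",
)


def loaded_cuda_libraries(maps_text: str) -> list[str]:
    """Return sorted, de-duplicated CUDA library paths found in /proc/<pid>/maps text."""
    paths = set()
    for line in maps_text.splitlines():
        # Columns: address, perms, offset, dev, inode, pathname (pathname is optional).
        fields = line.split(maxsplit=5)
        if len(fields) < 6:
            continue
        path = fields[5]
        if os.path.basename(path).startswith(CUDA_LIB_PREFIXES):
            paths.add(path)
    return sorted(paths)


def report() -> None:
    """Initialise JAX's backends, then print the CUDA libraries mapped into this process."""
    import jax  # imported lazily so the parser above works without JAX installed

    try:
        jax.devices("gpu")  # forces backend initialisation, which loads the CUDA libraries
    except RuntimeError as err:
        print("GPU backend unavailable:", err)
    with open("/proc/self/maps") as maps_file:
        for path in loaded_cuda_libraries(maps_file.read()):
            print(path)
```

Calling `report()` inside the affected environment prints one path per loaded library, which shows whether each came from site-packages or from a system directory.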

epignatelli commented 6 months ago

Sorry, I have got to reopen this. Here's the issue.

Different packages both install CUDA binaries (e.g., JAX and pytorch). These CUDA binaries have different versions.

Using jax[cuda12_pip] or similar makes the two requirements conflict. Using jax[cuda12_local] and having a local version of CUDA (e.g., in a cluster) does not help because JAX prefers loading the binaries installed via pip by some other package.

This happens despite setting XLA_FLAGS="--xla_gpu_cuda_data_dir=<path-to-CUDA>".

Is there a way to better diagnose the issue? For example, printing the folder that JAX uses to load the CUDA binaries? (TF_CPP_MIN_LOG_LEVEL=0 does not help.) Is there any other way to force JAX to load from a different folder?

hawkinsp commented 6 months ago

@epignatelli It's currently not possible to override the search path, because it's set by an RPATH on the .so files in jaxlib. We could perhaps fix that with some cunning.

However, note that JAX 0.4.26 relaxed its CUDA version constraints. You should be able to just use the same version of CUDA that Pytorch is using (12.1, last I checked).

epignatelli commented 6 months ago

Thanks very much, @hawkinsp. I resorted to that in the meantime, but JAX v0.4.26 introduces breaking changes to the API. I could access those libraries and align them, but that's not a very general case.

epignatelli commented 2 months ago

Hey @hawkinsp still issues here.

I migrated to the latest JAX version with pip install "jax[cuda12]". I removed all cuda path references in LD_LIBRARY_PATH and PATH itself, but for some reason JAX still tries to load an older version of the drivers.

I get

The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

with

echo $LD_LIBRARY_PATH
/lib64:/lib:/usr/local/lib:/usr/local/lib64:/usr/lib64:/usr/lib:

and the following prints nothing:

echo $PATH | grep -i cuda

My nvidia-smi shows:

NVIDIA-SMI 550.78                 Driver Version: 550.78         CUDA Version: 12.4

But I guess JAX does not use this CUDA, but installs its own, correct? Where should I look to address this?

jakevdp commented 2 months ago

But I guess JAX does not use this CUDA, but installs its own, correct?

With pip install jax[cuda12], it will download its own CUDA; with pip install jax[cuda12_local], it will use the local CUDA. See https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu for details.

If you are seeing this error, and have both a locally-installed CUDA and a pip-installed CUDA on your machine and have adjusted LD_LIBRARY_PATH, then I suspect the issue is cross-talk between your two CUDA installations.

epignatelli commented 2 months ago

Thanks, @jakevdp! That's what I don't understand.

Where could the locally installed CUDA be set as higher priority than the pip-installed one, if LD_LIBRARY_PATH and PATH are clean? Do you have any idea?
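
One place worth checking when LD_LIBRARY_PATH and PATH are clean is the dynamic linker cache (/etc/ld.so.cache, built from /etc/ld.so.conf and /etc/ld.so.conf.d), which the Linux loader also consults. The sketch below parses the output of `ldconfig -p` for CUDA libraries. It is an assumption that this is relevant to the case above; the helper names and the prefix list are illustrative only.

```python
import shutil
import subprocess

# Illustrative prefixes; extend as needed ("libcuda" also covers libcudart).
CUDA_NAME_PREFIXES = ("libcuda", "libcublas", "libcudnn", "libnvrtc", "libnvJitLink")


def cuda_entries_in_linker_cache(ldconfig_output: str) -> list[tuple[str, str]]:
    """Return (library name, resolved path) pairs for CUDA libraries in `ldconfig -p` output."""
    entries = []
    for line in ldconfig_output.splitlines():
        # Entry lines look like: "\tlibfoo.so.1 (libc6,x86-64) => /path/libfoo.so.1"
        name_part, sep, path = line.partition(" => ")
        parts = name_part.split()
        if not sep or not parts:
            continue  # header line or anything without a mapping
        if parts[0].startswith(CUDA_NAME_PREFIXES):
            entries.append((parts[0], path.strip()))
    return entries


def read_linker_cache() -> str:
    """Run `ldconfig -p` and return its output (ldconfig is often in /sbin, off a user's PATH)."""
    ldconfig = shutil.which("ldconfig") or "/sbin/ldconfig"
    result = subprocess.run([ldconfig, "-p"], capture_output=True, text=True, check=True)
    return result.stdout
```

Running `cuda_entries_in_linker_cache(read_linker_cache())` lists any system-wide CUDA libraries the loader knows about independently of the environment variables.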