epignatelli opened this issue 7 months ago
The issue is in your error message:
Found cuBLAS version 120205, but JAX was built against version 120304, which is newer.
You'll need to install cuBLAS version 12.3 or later when using the pre-built wheels for jax v0.4.24, as those wheels are built against CUDA 12.3.
Since you used jax[cuda12_pip], you should get compatible CUDA automatically. My guess is that you have another CUDA installation on your system that is taking precedence. I would suggest figuring out where those other CUDA sources are installed, and either removing them or ensuring that the pip-installed CUDA sources take precedence in your path.
Since you used jax[cuda12_pip], you should get compatible CUDA automatically.
Yeah, that's why I thought the problem was not on my side somehow, thanks very much for the info.
I would expect that jax[cuda12_pip] prioritises the CUDA installed from pip, though, no?
Also, I didn't know which CUDA version v0.4.24 was targeting. Perhaps it would be useful to have the minimum CUDA version in https://github.com/google/jax/?tab=readme-ov-file#installation? It used to be there, IIRC.
jax[cuda12_pip] installs the correct CUDA sources in your Python site_packages. However, if you have other CUDA installations on your system, and your system is set up to load those other sources, they may be loaded before the ones installed with pip. There is nothing that JAX or pip can do about this: it is a property of your system.
jax 0.4.24 targets CUDA 12.2, and jax 0.4.25 targets CUDA 12.3, if that helps.
Okay, great! Thanks, Jake!
jax[cuda12_pip] installs the correct CUDA sources in your Python site_packages. However, if you have other CUDA installations on your system, and your system is set up to load those other sources, they may be loaded before the ones installed with pip. There is nothing that JAX or pip can do about this: it is a property of your system. jax 0.4.24 targets CUDA 12.2, and jax 0.4.25 targets CUDA 12.3, if that helps.
That is it. Unlike most other things in Python venvs, CUDA seems to add itself at the end of the path. So if you are on a supercomputer with environment modules, strangely, the system's CUDA will take precedence.
Other packages installed with pip in a venv take priority over the ones from the modules, because the venv adds them at the beginning of the Python path, for example. Except for CUDA.
Unloading the CUDA module forces JAX to find the one from site_packages, and things work.
Go figure.
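For anyone hitting the same thing, a quick way to spot module-added CUDA entries is to split the loader path one entry per line; a shell sketch, nothing here is JAX-specific:

```shell
# Print each LD_LIBRARY_PATH entry on its own line and keep only the
# ones that mention CUDA; fall back to a message if there are none.
echo "$LD_LIBRARY_PATH" | tr ':' '\n' | grep -i cuda \
    || echo "no CUDA entries in LD_LIBRARY_PATH"
```

If this prints a cluster path such as a module root, that entry is what shadows the pip-installed libraries.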
Is there a proper way to check the location of the set of CUDA libraries that JAX is loading? TF_CPP_MIN_LOG_LEVEL=0 is not showing anything more than usual.
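One way to answer this yourself on Linux is to inspect which shared objects the process has actually mapped after importing jax; /proc/self/maps lists them. A sketch (Linux-only; the library names in the example pattern are the usual CUDA wheel names, not anything JAX documents):

```python
import re

def loaded_libs(pattern):
    """Return the sorted, de-duplicated paths of shared objects mapped
    into the current process whose path matches `pattern`.
    Linux-only: parses /proc/self/maps."""
    paths = set()
    with open("/proc/self/maps") as maps:
        for line in maps:
            parts = line.split()
            # The mapped file path, when present, is the last column.
            if len(parts) >= 6 and re.search(pattern, parts[-1]):
                paths.add(parts[-1])
    return sorted(paths)

# After `import jax; jax.devices()`, printing
# loaded_libs(r"cublas|cudart|cudnn|cusolver") shows exactly which
# CUDA libraries were picked up, and from which directory.
```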
Sorry, I have got to reopen this. Here's the issue.
Different packages both install CUDA binaries (e.g., JAX and PyTorch), and these CUDA binaries have different versions.
Using jax[cuda12_pip] or similar makes the two requirements conflict.
Using jax[cuda12_local] and having a local version of CUDA (e.g., on a cluster) does not help, because JAX prefers loading the binaries installed via pip by some other package.
This happens despite setting XLA_ARGS="--xla_gpu_cuda_data_dir=<path-to-CUDA>".
Is there a way to better diagnose the issue? For example, printing the folder that JAX uses to load the CUDA binaries? (TF_CPP_MIN_LOG_LEVEL=0 does not help.)
Is there any other way to force JAX to load from a different folder?
@epignatelli It's currently not possible to override the search path, because it's set by an RPATH on the .so files in jaxlib. We could perhaps fix that with some cunning.
However, note that JAX 0.4.26 relaxed its CUDA version constraints. You should be able to just use the same version of CUDA that PyTorch is using (12.1, last I checked).
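To see the RPATH in question, one can ask readelf directly. A sketch; the glob over jaxlib's install directory is an assumption about the wheel layout, so adjust the path to your environment:

```shell
# Locate the installed jaxlib package, then dump any RPATH/RUNPATH
# entries baked into its shared objects.
JAXLIB_DIR=$(python -c 'import jaxlib, os; print(os.path.dirname(jaxlib.__file__))')
for so in "$JAXLIB_DIR"/*.so; do
    echo "== $so"
    readelf -d "$so" | grep -Ei 'rpath|runpath' || echo "   (none)"
done
```

Note that RUNPATH (unlike RPATH) is searched after LD_LIBRARY_PATH, which is why which tag the wheel uses matters for overriding it.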
Thanks very much, @hawkinsp. I resorted to that in the meantime, but JAX v0.4.23 introduces breaking changes to the API. I could access those libraries and align them, but that's not a very general case.
Hey @hawkinsp, still issues here.
I migrated to the latest JAX version with pip install "jax[cuda12]".
I removed all CUDA path references in LD_LIBRARY_PATH and PATH itself, but for some reason JAX still tries to load an older version of the drivers.
I get
The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
with
echo $LD_LIBRARY_PATH
/lib64:/lib:/usr/local/lib:/usr/local/lib64:/usr/lib64:/usr/lib:
and an empty output from
echo $PATH | grep -i cuda
My nvidia-smi shows:
NVIDIA-SMI 550.78 Driver Version: 550.78 CUDA Version: 12.4
But I guess JAX does not use this CUDA, but installs its own, correct? Where should I look to address this?
But I guess JAX does not use this CUDA, but installs its own, correct?
With pip install jax[cuda12], it will download its own CUDA; with pip install jax[cuda12_local], it will use the local CUDA. See https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu for details.
If you are seeing this error, and have both a locally-installed CUDA and a pip-installed CUDA on your machine and have adjusted LD_LIBRARY_PATH, then I suspect the issue is cross-talk between your two CUDA installations.
Thanks, @jakevdp! That's what I don't understand.
Where could the locally installed CUDA be set as higher priority than the pip-installed one, if LD_LIBRARY_PATH and PATH are clean? Do you have any idea?
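One thing worth checking is where the pip side actually lives: the CUDA wheels (nvidia-cublas-cu12 and friends) all install under a nvidia namespace package in site-packages, so locating that package shows which copy the pip-installed JAX would resolve. A sketch (the nvidia package name is the wheels' convention, not a JAX API):

```python
import importlib.util

def pip_cuda_roots():
    """Return the directories of the pip-installed `nvidia` namespace
    package (where wheels such as nvidia-cublas-cu12 put their
    libraries), or an empty list if none are installed."""
    spec = importlib.util.find_spec("nvidia")
    if spec is not None and spec.submodule_search_locations:
        return list(spec.submodule_search_locations)
    return []

# Example: print each root so it can be compared against whatever the
# process actually loads at runtime.
for root in pip_cuda_roots():
    print("pip-installed CUDA root:", root)
```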
Description
The latest JAX 0.4.24 does not detect GPUs, both using local CUDA and pip-installed CUDA. The latest working version for me is 0.4.23.
Reproduce
pip install "jax[cuda12_pip]"==0.4.24 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
With python -c "import jax; jax.devices('gpu')", I get:
System info (python version, jaxlib version, accelerator, etc.)
nvcc --version
nvidia-smi