jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX does not recognise my NVIDIA GPU when installed via conda #24604

Open zairving opened 1 month ago

zairving commented 1 month ago

Description

I previously had a working installation of JAX (installed via conda) that recognised my NVIDIA GPU without issue. However, I recently migrated to a new machine, and now I cannot get JAX to recognise my GPU when I install via conda. I'm using Miniforge to manage my conda environments, as I did on my old machine, and I installed JAX according to the docs:

conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia

When I then try to import JAX and check my available devices using:

import jax

print(jax.devices())

I get the following output:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]
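
The warning above does not say why the CUDA backend was skipped. One way to surface the underlying error, sketched here under the assumption that the documented `JAX_PLATFORMS` setting behaves as described (an explicitly requested platform that fails to initialize raises instead of silently falling back to CPU), is the following. The function name `probe_cuda_backend` is made up for this example.

```python
import os


def probe_cuda_backend() -> str:
    """Describe what happens when JAX is told to use only the CUDA platform."""
    # This only takes effect if it is set before jax is first imported
    # in the current process.
    os.environ["JAX_PLATFORMS"] = "cuda"
    try:
        import jax
    except ImportError:
        return "jax is not installed in this environment"
    try:
        return f"devices: {jax.devices()}"
    except Exception as err:
        # With an explicit platform list, an initialization failure is raised
        # here rather than being hidden behind the CPU fallback.
        return f"CUDA backend failed to initialize: {err}"


if __name__ == "__main__":
    print(probe_cuda_backend())
```

Run it in a fresh interpreter inside the affected environment; the text of the raised error is usually more specific than the fallback warning.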

TensorFlow, however, does recognise my GPU, so I tried the suggestion from #15268 to install using pip. I created a new environment and ran:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

When I then ran my JAX code above, I got the output:

[CudaDevice(id=0)]

Voilà, my GPU has been found!

It therefore appears that the conda section of the docs might need updating.

System info (python version, jaxlib version, accelerator, etc.)

Conda installation:

jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.13.0 | packaged by conda-forge | (main, Oct  8 2024, 20:04:32) [GCC 13.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='Merlin', release='6.8.0-47-generic', version='#47-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 27 21:40:26 UTC 2024', machine='x86_64')

$ nvidia-smi
Tue Oct 29 22:50:56 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
|  0%   50C    P3             84W /  420W |    1077MiB /  24576MiB |     33%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2582      G   /usr/lib/xorg/Xorg                            432MiB |
|    0   N/A  N/A      2850      G   /usr/bin/gnome-shell                          108MiB |
|    0   N/A  N/A      4140      G   ...irefox/4173/usr/lib/firefox/firefox        155MiB |
|    0   N/A  N/A      5602      G   ...erProcess --variations-seed-version        163MiB |
|    0   N/A  N/A     28658      G   /usr/bin/nautilus                              41MiB |
|    0   N/A  N/A     28840      G   /usr/bin/gnome-text-editor                     12MiB |
|    0   N/A  N/A     36942      G   /usr/bin/nvidia-settings                        0MiB |
+-----------------------------------------------------------------------------------------+

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
binutils_impl_linux-64    2.43                 h4bf12b8_2    conda-forge
binutils_linux-64         2.43                 h4852527_2    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
c-ares                    1.34.2               heb4867d_0    conda-forge
ca-certificates           2024.8.30            hbcca054_0    conda-forge
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
cuda-cccl                 12.6.77                       0    nvidia
cuda-cccl_linux-64        12.6.77                       0    nvidia
cuda-crt-dev_linux-64     12.6.77                       0    nvidia
cuda-crt-tools            12.6.77                       0    nvidia
cuda-cudart               12.6.77              h5888daf_0    conda-forge
cuda-cudart-dev           12.4.127                      0    nvidia
cuda-cudart-dev_linux-64  12.6.77                       0    nvidia
cuda-cudart-static_linux-64 12.6.77                       0    nvidia
cuda-cudart_linux-64      12.6.77              h3f2d84a_0    conda-forge
cuda-cupti                12.6.80              hbd13f7d_0    conda-forge
cuda-driver-dev_linux-64  12.6.77                       0    nvidia
cuda-nvcc                 12.6.77                       0    nvidia
cuda-nvcc-dev_linux-64    12.6.77                       0    nvidia
cuda-nvcc-impl            12.6.77                       0    nvidia
cuda-nvcc-tools           12.6.77                       0    nvidia
cuda-nvcc_linux-64        12.6.77                       0    nvidia
cuda-nvrtc                12.6.77              hbd13f7d_0    conda-forge
cuda-nvtx                 12.6.77              hbd13f7d_0    conda-forge
cuda-nvvm-dev_linux-64    12.6.77                       0    nvidia
cuda-nvvm-impl            12.6.77                       0    nvidia
cuda-nvvm-tools           12.6.77                       0    nvidia
cuda-version              12.6                 h7480c83_3    conda-forge
cudnn                     9.3.0.75             h93bb076_0    conda-forge
debugpy                   1.8.7           py313h46c70d0_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
exceptiongroup            1.2.2              pyhd8ed1ab_0    conda-forge
executing                 2.1.0              pyhd8ed1ab_0    conda-forge
gcc_impl_linux-64         12.4.0               hb2e57f8_1    conda-forge
gcc_linux-64              12.4.0               h6b7512a_5    conda-forge
gxx_impl_linux-64         12.4.0               h613a52c_1    conda-forge
gxx_linux-64              12.4.0               h8489865_5    conda-forge
importlib-metadata        8.5.0              pyha770c72_0    conda-forge
ipykernel                 6.29.5             pyh3099207_0    conda-forge
ipython                   8.29.0             pyh707e725_0    conda-forge
jax                       0.4.34             pyhd8ed1ab_0    conda-forge
jaxlib                    0.4.34          cuda120py313h3b1fb80_200    conda-forge
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jupyter_client            8.6.3              pyhd8ed1ab_0    conda-forge
jupyter_core              5.7.2              pyh31011fe_1    conda-forge
kernel-headers_linux-64   3.10.0              he073ed8_18    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
krb5                      1.21.3               h659f571_0    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_2    conda-forge
libabseil                 20240722.0      cxx17_h5888daf_1    conda-forge
libblas                   3.9.0           25_linux64_openblas    conda-forge
libcblas                  3.9.0           25_linux64_openblas    conda-forge
libcublas                 12.6.3.3             hbd13f7d_1    conda-forge
libcufft                  11.3.0.4             hbd13f7d_0    conda-forge
libcurand                 10.3.7.77            hbd13f7d_0    conda-forge
libcusolver               11.7.1.2             hbd13f7d_0    conda-forge
libcusparse               12.5.4.2             hbd13f7d_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libexpat                  2.6.3                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-devel_linux-64     12.4.0             ha4f9413_101    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgfortran               14.2.0               h69a702a_1    conda-forge
libgfortran-ng            14.2.0               h69a702a_1    conda-forge
libgfortran5              14.2.0               hd5240d6_1    conda-forge
libgomp                   14.2.0               h77fa898_1    conda-forge
libgrpc                   1.65.5               hf5c653b_0    conda-forge
liblapack                 3.9.0           25_linux64_openblas    conda-forge
libmpdec                  4.0.0                h4bc722e_0    conda-forge
libnvjitlink              12.6.77              hbd13f7d_1    conda-forge
libopenblas               0.3.28          pthreads_h94d23a6_0    conda-forge
libprotobuf               5.27.5               h5b01275_2    conda-forge
libre2-11                 2024.07.02           hbbce691_1    conda-forge
libsanitizer              12.4.0               h46f95d5_1    conda-forge
libsodium                 1.0.20               h4ab18f5_0    conda-forge
libsqlite                 3.47.0               hadc24fc_1    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libstdcxx-devel_linux-64  12.4.0             ha4f9413_101    conda-forge
libstdcxx-ng              14.2.0               h4852527_1    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libzlib                   1.3.1                hb9d3cd8_2    conda-forge
matplotlib-inline         0.1.7              pyhd8ed1ab_0    conda-forge
ml_dtypes                 0.5.0           py313ha87cce1_0    conda-forge
nccl                      2.23.4.1             h52f6c39_1    conda-forge
ncurses                   6.5                  he02047a_1    conda-forge
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
numpy                     2.1.2           py313h4bf6692_0    conda-forge
openssl                   3.3.2                hb9d3cd8_0    conda-forge
opt-einsum                3.4.0                hd8ed1ab_0    conda-forge
opt_einsum                3.4.0              pyhd8ed1ab_0    conda-forge
packaging                 24.1               pyhd8ed1ab_0    conda-forge
parso                     0.8.4              pyhd8ed1ab_0    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pip                       24.3.1             pyh145f28c_0    conda-forge
platformdirs              4.3.6              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.48             pyha770c72_0    conda-forge
psutil                    6.1.0           py313h536fd9c_0    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_0    conda-forge
pygments                  2.18.0             pyhd8ed1ab_0    conda-forge
python                    3.13.0          h9ebbce0_100_cp313    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python_abi                3.13                    5_cp313    conda-forge
pyzmq                     26.2.0          py313h8e95178_3    conda-forge
re2                       2024.07.02           h77b4e00_1    conda-forge
readline                  8.2                  h8228510_1    conda-forge
scipy                     1.14.1          py313h27c5614_1    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
sysroot_linux-64          2.17                h4a8ded7_18    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tornado                   6.4.1           py313h536fd9c_1    conda-forge
traitlets                 5.14.3             pyhd8ed1ab_0    conda-forge
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024b                hc8b5060_0    conda-forge
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zeromq                    4.3.5                h3b0a872_6    conda-forge
zipp                      3.20.2             pyhd8ed1ab_0    conda-forge
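
The listing above draws its CUDA packages from both the nvidia and conda-forge channels, and cuda-cudart-dev sits at 12.4.127 while its neighbours are at 12.6.77. Whether that mix is related to the problem is not established in this thread. The sketch below is only a convenience for summarising such a listing by channel; the helper name `cuda_packages_by_channel` and the choice of name prefixes are assumptions made for illustration.

```python
from collections import defaultdict

# Name prefixes treated as CUDA-related; an illustrative, non-exhaustive choice.
CUDA_PREFIXES = ("cuda-", "libcublas", "libcusparse", "libcusolver",
                 "libcufft", "cudnn", "nccl", "jaxlib")


def cuda_packages_by_channel(conda_list_text: str) -> dict[str, list[tuple[str, str]]]:
    """Group CUDA-related rows of `conda list` output by channel."""
    grouped = defaultdict(list)
    for line in conda_list_text.splitlines():
        fields = line.split()
        # Data rows have four columns: name, version, build, channel.
        if line.startswith("#") or len(fields) != 4:
            continue
        name, version, _build, channel = fields
        if name.startswith(CUDA_PREFIXES):
            grouped[channel].append((name, version))
    return dict(grouped)


# A few rows copied from the listing above.
sample = """\
# Name                    Version                   Build  Channel
cuda-cudart               12.6.77              h5888daf_0    conda-forge
cuda-cudart-dev           12.4.127                      0    nvidia
cuda-nvcc                 12.6.77                       0    nvidia
jaxlib                    0.4.34          cuda120py313h3b1fb80_200    conda-forge
numpy                     2.1.2           py313h4bf6692_0    conda-forge
"""

for channel, packages in cuda_packages_by_channel(sample).items():
    print(channel, packages)
```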

pip installation:

jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.13.0 | packaged by conda-forge | (main, Oct  8 2024, 20:04:32) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 3090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='Merlin', release='6.8.0-47-generic', version='#47-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 27 21:40:26 UTC 2024', machine='x86_64')

$ nvidia-smi
Tue Oct 29 22:50:01 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
|  0%   48C    P3             84W /  420W |    1310MiB /  24576MiB |     20%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2582      G   /usr/lib/xorg/Xorg                            432MiB |
|    0   N/A  N/A      2850      G   /usr/bin/gnome-shell                          108MiB |
|    0   N/A  N/A      4140      G   ...irefox/4173/usr/lib/firefox/firefox        151MiB |
|    0   N/A  N/A      5602      G   ...erProcess --variations-seed-version        141MiB |
|    0   N/A  N/A     28658      G   /usr/bin/nautilus                              41MiB |
|    0   N/A  N/A     28840      G   /usr/bin/gnome-text-editor                     12MiB |
|    0   N/A  N/A     36942      G   /usr/bin/nvidia-settings                        0MiB |
|    0   N/A  N/A     42238      C   .../miniforge3/envs/jax_ml6/bin/python        256MiB |
+-----------------------------------------------------------------------------------------+

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
brotli                    1.1.0                hb9d3cd8_2    conda-forge
brotli-bin                1.1.0                hb9d3cd8_2    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
ca-certificates           2024.8.30            hbcca054_0    conda-forge
certifi                   2024.8.30          pyhd8ed1ab_0    conda-forge
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
contourpy                 1.3.0           py313h33d0bda_2    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
debugpy                   1.8.7           py313h46c70d0_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
exceptiongroup            1.2.2              pyhd8ed1ab_0    conda-forge
executing                 2.1.0              pyhd8ed1ab_0    conda-forge
fonttools                 4.54.1          py313h8060acc_1    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
importlib-metadata        8.5.0              pyha770c72_0    conda-forge
ipykernel                 6.29.5             pyh3099207_0    conda-forge
ipympl                    0.9.4              pyhd8ed1ab_0    conda-forge
ipython                   8.29.0             pyh707e725_0    conda-forge
ipython_genutils          0.2.0              pyhd8ed1ab_1    conda-forge
ipywidgets                8.1.5              pyhd8ed1ab_0    conda-forge
jax                       0.4.35                   pypi_0    pypi
jax-cuda12-pjrt           0.4.35                   pypi_0    pypi
jax-cuda12-plugin         0.4.35                   pypi_0    pypi
jaxlib                    0.4.34                   pypi_0    pypi
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jupyter_client            8.6.3              pyhd8ed1ab_0    conda-forge
jupyter_core              5.7.2              pyh31011fe_1    conda-forge
jupyterlab_widgets        3.0.13             pyhd8ed1ab_0    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.7           py313h33d0bda_0    conda-forge
krb5                      1.21.3               h659f571_0    conda-forge
lcms2                     2.16                 hb7c19ff_0    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_2    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libblas                   3.9.0           25_linux64_openblas    conda-forge
libbrotlicommon           1.1.0                hb9d3cd8_2    conda-forge
libbrotlidec              1.1.0                hb9d3cd8_2    conda-forge
libbrotlienc              1.1.0                hb9d3cd8_2    conda-forge
libcblas                  3.9.0           25_linux64_openblas    conda-forge
libdeflate                1.22                 hb9d3cd8_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libexpat                  2.6.3                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgfortran               14.2.0               h69a702a_1    conda-forge
libgfortran-ng            14.2.0               h69a702a_1    conda-forge
libgfortran5              14.2.0               hd5240d6_1    conda-forge
libgomp                   14.2.0               h77fa898_1    conda-forge
libjpeg-turbo             3.0.0                hd590300_1    conda-forge
liblapack                 3.9.0           25_linux64_openblas    conda-forge
libmpdec                  4.0.0                h4bc722e_0    conda-forge
libopenblas               0.3.28          pthreads_h94d23a6_0    conda-forge
libpng                    1.6.44               hadc24fc_0    conda-forge
libsodium                 1.0.20               h4ab18f5_0    conda-forge
libsqlite                 3.47.0               hadc24fc_1    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libstdcxx-ng              14.2.0               h4852527_1    conda-forge
libtiff                   4.7.0                he137b08_1    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libwebp-base              1.4.0                hd590300_0    conda-forge
libxcb                    1.17.0               h8a09558_0    conda-forge
libzlib                   1.3.1                hb9d3cd8_2    conda-forge
matplotlib-base           3.9.2           py313h129903b_1    conda-forge
matplotlib-inline         0.1.7              pyhd8ed1ab_0    conda-forge
ml-dtypes                 0.5.0                    pypi_0    pypi
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
ncurses                   6.5                  he02047a_1    conda-forge
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
numpy                     2.1.2           py313h4bf6692_0    conda-forge
nvidia-cublas-cu12        12.6.3.3                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.6.80                  pypi_0    pypi
nvidia-cuda-nvcc-cu12     12.6.77                  pypi_0    pypi
nvidia-cuda-runtime-cu12  12.6.77                  pypi_0    pypi
nvidia-cudnn-cu12         9.5.1.17                 pypi_0    pypi
nvidia-cufft-cu12         11.3.0.4                 pypi_0    pypi
nvidia-cusolver-cu12      11.7.1.2                 pypi_0    pypi
nvidia-cusparse-cu12      12.5.4.2                 pypi_0    pypi
nvidia-nccl-cu12          2.23.4                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.6.77                  pypi_0    pypi
openjpeg                  2.5.2                h488ebb8_0    conda-forge
openssl                   3.3.2                hb9d3cd8_0    conda-forge
opt-einsum                3.4.0                    pypi_0    pypi
packaging                 24.1               pyhd8ed1ab_0    conda-forge
parso                     0.8.4              pyhd8ed1ab_0    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    11.0.0          py313h2d7ed13_0    conda-forge
pip                       24.3.1             pyh145f28c_0    conda-forge
platformdirs              4.3.6              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.48             pyha770c72_0    conda-forge
psutil                    6.1.0           py313h536fd9c_0    conda-forge
pthread-stubs             0.4               hb9d3cd8_1002    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_0    conda-forge
pygments                  2.18.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.2.0              pyhd8ed1ab_1    conda-forge
python                    3.13.0          h9ebbce0_100_cp313    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python_abi                3.13                    5_cp313    conda-forge
pyzmq                     26.2.0          py313h8e95178_3    conda-forge
qhull                     2020.2               h434a139_5    conda-forge
readline                  8.2                  h8228510_1    conda-forge
scipy                     1.14.1                   pypi_0    pypi
six                       1.16.0             pyh6c4a22f_0    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tornado                   6.4.1           py313h536fd9c_1    conda-forge
traitlets                 5.14.3             pyhd8ed1ab_0    conda-forge
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024b                hc8b5060_0    conda-forge
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
widgetsnbextension        4.0.13             pyhd8ed1ab_0    conda-forge
xorg-libxau               1.0.11               hb9d3cd8_1    conda-forge
xorg-libxdmcp             1.1.5                hb9d3cd8_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zeromq                    4.3.5                h3b0a872_6    conda-forge
zipp                      3.20.2             pyhd8ed1ab_0    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge

hawkinsp commented 1 month ago

We (JAX) only support the pip packages, which should work under conda as well. For the native conda-forge packages, you'd need to follow up on the conda-forge feedstock project for jax.

hartikainen commented 3 weeks ago

I briefly looked into this a while ago as well when working on #24139 and #24684. I don't have time or interest to fully resolve the issue as I'm not a conda user myself, but here are a couple of pointers from my notes that might help you debug this.

First, in _src/xla_bridge.py, the GPU backend doesn't get initialized, I believe because of outdated cu{BLAS,SPARSE} modules, as noted by the error message below, which is suppressed by default:

(Pdb) print(err_msg)
Unable to initialize backend 'cuda': Unable to use CUDA because of the following issues with CUDA components:
Outdated cuBLAS installation found.
Version JAX was built against: 120001
Minimum supported: 120100
Installed version: 120002
The local installation version must be no lower than 120100.
--------------------------------------------------
Outdated cuSPARSE installation found.
Version JAX was built against: 12000
Minimum supported: 12100
Installed version: 12001
The local installation version must be no lower than 12100.
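
The integers in this message are easier to compare once decoded. The sketch below assumes the usual encodings from NVIDIA's headers (cuBLAS: major * 10000 + minor * 100 + patch; cuSPARSE: major * 1000 + minor * 100 + patch); treat those formulas as an assumption, not something confirmed in this thread. The function names are made up for the example.

```python
def decode_cublas(version: int) -> tuple[int, int, int]:
    # Assumed encoding: major * 10000 + minor * 100 + patch.
    return version // 10000, (version % 10000) // 100, version % 100


def decode_cusparse(version: int) -> tuple[int, int, int]:
    # Assumed encoding: major * 1000 + minor * 100 + patch.
    return version // 1000, (version % 1000) // 100, version % 100


# Values copied from the error message above: (minimum supported, installed).
reported = {
    "cuBLAS": (decode_cublas, 120100, 120002),
    "cuSPARSE": (decode_cusparse, 12100, 12001),
}

for name, (decode, minimum, installed) in reported.items():
    ok = decode(installed) >= decode(minimum)
    print(f"{name}: installed {decode(installed)}, minimum {decode(minimum)}, ok={ok}")
```

Under that reading, both installed libraries are 12.0.x builds while the stated minimum is 12.1.0.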

Even if you fix these, you might have to explicitly set the CUDA_PATH environment variable so that lib.cuda_path comes out correctly. There are two branches in that function: one checks CUDA_PATH and the other checks the nvidia.cuda_nvcc module path. Despite installing cuda-nvcc via conda, the nvidia.cuda_nvcc module is not available, so the second branch doesn't work.
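
To make the two branches concrete, here is a rough sketch of that kind of lookup. It is not JAX's actual code: the function name `guess_cuda_path` and its exact behaviour are assumptions for illustration, based only on the description above.

```python
import importlib.util
import os
from typing import Mapping, Optional


def guess_cuda_path(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Hypothetical two-branch lookup: CUDA_PATH first, then nvidia.cuda_nvcc."""
    # Branch 1: an explicitly set CUDA_PATH wins.
    path = env.get("CUDA_PATH")
    if path:
        return path
    # Branch 2: look for the nvidia.cuda_nvcc Python package. The listings
    # above show it arriving with the pip wheel nvidia-cuda-nvcc-cu12; the
    # conda cuda-nvcc package does not provide this module.
    try:
        spec = importlib.util.find_spec("nvidia.cuda_nvcc")
    except ModuleNotFoundError:
        spec = None  # the parent package `nvidia` is itself missing
    if spec is not None and spec.submodule_search_locations:
        return list(spec.submodule_search_locations)[0]
    return None


print(guess_cuda_path())
```

In a conda-only environment like the one described, the second branch would be expected to find nothing, so the result is `None` unless CUDA_PATH is set.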