zairving opened 1 month ago
We (JAX) only support the pip packages, which should work under conda as well. For the native conda-forge packages, you'd need to follow up on the conda-forge feedstock project for jax.
I briefly looked into this a while ago as well when working on #24139 and #24684. I don't have time or interest to fully resolve the issue as I'm not a conda user myself, but here are a couple of pointers from my notes that might help you debug this.
First, in _src/xla_bridge.py, the GPU backend (I believe) doesn't get initialized because of outdated cuBLAS and cuSPARSE installations, as noted by the error messages, which are suppressed by default:
(Pdb) print(err_msg)
Unable to initialize backend 'cuda': Unable to use CUDA because of the following issues with CUDA components:
Outdated cuBLAS installation found.
Version JAX was built against: 120001
Minimum supported: 120100
Installed version: 120002
The local installation version must be no lower than 120100.
--------------------------------------------------
Outdated cuSPARSE installation found.
Version JAX was built against: 12000
Minimum supported: 12100
Installed version: 12001
The local installation version must be no lower than 12100.
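The sketch below reproduces both observations outside a debugger. It rests on two assumptions: that pinning JAX_PLATFORMS to a single backend makes JAX raise the "Unable to initialize backend" error instead of quietly falling back to CPU, and that cuBLAS and cuSPARSE encode their versions as major*10000 + minor*100 + patch and major*1000 + minor*100 + patch respectively. The function names are hypothetical helpers, not part of JAX.

```python
from __future__ import annotations

import os
import subprocess
import sys


def surface_backend_error(platform: str = "cuda") -> str:
    """Ask a fresh interpreter for JAX's devices with JAX_PLATFORMS pinned.

    Assumption: with JAX_PLATFORMS set explicitly, a failing backend raises
    its initialization error rather than being skipped silently.
    """
    env = dict(os.environ, JAX_PLATFORMS=platform)
    result = subprocess.run(
        [sys.executable, "-c", "import jax; print(jax.devices())"],
        env=env,
        capture_output=True,
        text=True,
    )
    # Success: the device list is on stdout. Failure: the traceback,
    # including the backend error text, is on stderr.
    return result.stdout if result.returncode == 0 else result.stderr


def decode_cublas(v: int) -> tuple[int, int, int]:
    # Assumed cuBLAS encoding: major * 10000 + minor * 100 + patch.
    return v // 10000, (v // 100) % 100, v % 100


def decode_cusparse(v: int) -> tuple[int, int, int]:
    # Assumed cuSPARSE encoding: major * 1000 + minor * 100 + patch.
    return v // 1000, (v // 100) % 10, v % 100


if __name__ == "__main__":
    # Numbers copied from the Pdb output above.
    for name, installed, minimum, decode in [
        ("cuBLAS", 120002, 120100, decode_cublas),
        ("cuSPARSE", 12001, 12100, decode_cusparse),
    ]:
        status = "ok" if installed >= minimum else "outdated"
        print(f"{name}: installed {decode(installed)}, need >= {decode(minimum)}: {status}")
```

Under those encodings the log reads as cuBLAS 12.0.2 installed against a 12.1.0 minimum and cuSPARSE 12.0.1 against 12.1.0, so both checks fail. surface_backend_error() uses a separate interpreter because JAX_PLATFORMS has to be set before JAX initializes its backends.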
Even if you fix these, you might have to set the CUDA_PATH environment variable explicitly so that lib.cuda_path comes out correctly. The function that computes it has two branches: one checks CUDA_PATH, and the other checks the path of the nvidia.cuda_nvcc module. Despite installing cuda-nvcc via conda, the nvidia.cuda_nvcc module is not available, so the second branch doesn't work.
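To make the two-branch lookup concrete, here is a hypothetical re-creation of it. This is not JAX's actual code; resolve_cuda_path and its environ parameter are illustrative names, and the sketch only mirrors the behaviour described above: an explicit CUDA_PATH wins, otherwise the location of the nvidia.cuda_nvcc module is used.

```python
from __future__ import annotations

import importlib
import os
from collections.abc import Mapping


def resolve_cuda_path(environ: Mapping[str, str] | None = None) -> str | None:
    """Hypothetical sketch of the CUDA_PATH / nvidia.cuda_nvcc lookup."""
    env = os.environ if environ is None else environ

    # Branch 1: an explicitly set CUDA_PATH is returned as-is.
    explicit = env.get("CUDA_PATH")
    if explicit:
        return explicit

    # Branch 2: fall back to where the nvidia.cuda_nvcc module lives.
    try:
        cuda_nvcc = importlib.import_module("nvidia.cuda_nvcc")
    except ImportError:
        # The conda setup in this thread lands here: cuda-nvcc installed
        # via conda did not make nvidia.cuda_nvcc importable.
        return None

    # Packages (regular or namespace) expose __path__; plain modules __file__.
    paths = list(getattr(cuda_nvcc, "__path__", []))
    if paths:
        return paths[0]
    file = getattr(cuda_nvcc, "__file__", None)
    return os.path.dirname(file) if file else None
```

With nothing exported and no nvidia.cuda_nvcc module, the sketch returns None, which is why setting CUDA_PATH explicitly is the workaround suggested above.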
Description
I previously had a working installation of JAX (installed via conda) that recognised my NVIDIA GPU without issue. However, I recently migrated to a new machine and now I cannot get JAX to recognise my GPU when I install via conda. I'm using Miniforge to manage my conda environments, as I did on my old machine, and I installed JAX according to the docs:
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
When I then try to import JAX and check my available devices using:
I get the following output:
TensorFlow, however, does recognise my GPU, so I tried the suggestion from #15268 to install using pip. I created a new environment and ran:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
When I then ran my JAX code above, I got the output:
Voilà, my GPU has been found!
It therefore appears that the conda section of the docs might need updating.
System info (python version, jaxlib version, accelerator, etc.)
Conda installation:
pip installation: