Open carlosgmartin opened 1 week ago
I also tried uninstalling all existing nvidia and jax-cuda packages before re-installing JAX:
$ python3 -m pip freeze --all | grep -e nvidia -e jax-cuda | xargs python3 -m pip uninstall -y jax jaxlib
...
$ conda list | awk '{ print $1 }' | grep -e nvidia -e jax-cuda | xargs conda remove -y jax jaxlib
...
$ mamba list | awk '{ print $1 }' | grep -e nvidia -e jax-cuda | xargs mamba remove -y jax jaxlib
...
$ python3 -m pip install --upgrade "jax[cuda12]"
...
$ echo $CUDA_VISIBLE_DEVICES
0,1,2,3,4,5,6,7
$ echo $LD_LIBRARY_PATH
$ python3 -c "import jax; jax.numpy.array(0)"
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
but still get the same error message.
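To narrow down where the failure comes from, the checks above can be gathered into one script. This is a minimal sketch, not part of JAX: the helper names `gpu_environment_report`, `driver_sees_gpus`, and `jax_devices` are hypothetical. It assumes `nvidia-smi` is on `PATH` when the NVIDIA driver is installed, and it only relies on `nvidia-smi -L` and `jax.devices()`.

```python
import os
import shutil
import subprocess


def gpu_environment_report(env=None):
    """Collect settings that commonly affect whether JAX can see a GPU."""
    env = os.environ if env is None else env
    return {
        "CUDA_VISIBLE_DEVICES": env.get("CUDA_VISIBLE_DEVICES"),
        "LD_LIBRARY_PATH": env.get("LD_LIBRARY_PATH"),
        "JAX_PLATFORMS": env.get("JAX_PLATFORMS"),
        "nvidia_smi_path": shutil.which("nvidia-smi"),
    }


def driver_sees_gpus(timeout=30):
    """Return the output of `nvidia-smi -L`, or None if it is missing or fails."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return None
    try:
        result = subprocess.run(
            [exe, "-L"], capture_output=True, text=True, timeout=timeout
        )
    except (OSError, subprocess.TimeoutExpired):
        return None
    return result.stdout.strip() if result.returncode == 0 else None


def jax_devices():
    """Return the devices JAX reports, or the error text if initialization fails."""
    try:
        import jax  # imported lazily so the other checks run without JAX

        return [str(d) for d in jax.devices()]
    except Exception as exc:  # ImportError or a backend RuntimeError
        return f"{type(exc).__name__}: {exc}"


if __name__ == "__main__":
    for key, value in gpu_environment_report().items():
        print(f"{key}={value!r}")
    print("nvidia-smi -L:", driver_sees_gpus())
    print("jax.devices():", jax_devices())
```

If `nvidia-smi -L` lists no GPUs either, the problem is likely below JAX (driver or node configuration) rather than in the installed packages.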
Description
I used
to install JAX on a GPU node, but am getting a CUDA_ERROR_SYSTEM_NOT_READY error:

Here's some additional output:
System info (python version, jaxlib version, accelerator, etc.)