Closed: seishiroono closed this issue 3 months ago.
Can you describe exactly how you installed jax (and any relevant drivers, etc.) so that we can try to reproduce?
@dfm Thanks for your message. I am using Miniconda for virtual environments. The following is what I did in a fresh environment:
$ conda create -n jax_cuda12_test python=3.11.7
$ conda activate jax_cuda12_test
$ python3 -m pip install nvidia-cuda-cccl-cu12==12.4.127 nvidia-cuda-cupti-cu12==12.4.127 nvidia-cuda-nvcc-cu12==12.4.131 nvidia-cuda-opencl-cu12==12.4.127 nvidia-cuda-nvrtc-cu12==12.4.127 nvidia-cublas-cu12==12.4.5.8 nvidia-cuda-sanitizer-api-cu12==12.4.127 nvidia-cufft-cu12 nvidia-curand-cu12 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.4.1.24 nvidia-npp-cu12 nvidia-nvfatbin-cu12==12.4.127 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvjpeg-cu12 nvidia-nvml-dev-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 nvidia-cuda-runtime-cu12==12.4.127
$ pip install --upgrade pip
$ pip install --upgrade "jax[cuda12]"
The resulting `conda list` output is as follows.
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
bzip2 1.0.8 h5eee18b_6
ca-certificates 2024.3.11 h06a4308_0
jax 0.4.29 pypi_0 pypi
jax-cuda12-pjrt 0.4.29 pypi_0 pypi
jax-cuda12-plugin 0.4.29 pypi_0 pypi
jaxlib 0.4.29 pypi_0 pypi
ld_impl_linux-64 2.38 h1181459_1
libffi 3.4.4 h6a678d5_1
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
ml-dtypes 0.4.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
numpy 1.26.4 pypi_0 pypi
nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
nvidia-cuda-cccl-cu12 12.4.127 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
nvidia-cuda-nvcc-cu12 12.4.131 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
nvidia-cuda-opencl-cu12 12.4.127 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
nvidia-cuda-sanitizer-api-cu12 12.4.127 pypi_0 pypi
nvidia-cudnn-cu12 9.1.1.17 pypi_0 pypi
nvidia-cufft-cu12 11.2.3.18 pypi_0 pypi
nvidia-curand-cu12 10.3.6.39 pypi_0 pypi
nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
nvidia-cusparse-cu12 12.4.1.24 pypi_0 pypi
nvidia-nccl-cu12 2.21.5 pypi_0 pypi
nvidia-npp-cu12 12.3.0.116 pypi_0 pypi
nvidia-nvfatbin-cu12 12.4.127 pypi_0 pypi
nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
nvidia-nvjpeg-cu12 12.3.2.38 pypi_0 pypi
nvidia-nvml-dev-cu12 12.4.127 pypi_0 pypi
nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
openssl 3.0.13 h7f8727e_2
opt-einsum 3.3.0 pypi_0 pypi
pip 24.0 pypi_0 pypi
python 3.11.7 h955ad1f_0
readline 8.2 h5eee18b_0
scipy 1.13.1 pypi_0 pypi
setuptools 69.5.1 pypi_0 pypi
sqlite 3.45.3 h5eee18b_0
tk 8.6.14 h39e8969_0
tzdata 2024a h04d1e81_0
wheel 0.43.0 pypi_0 pypi
xz 5.4.6 h5eee18b_1
zlib 1.2.13 h5eee18b_1
I'll look into this a little later, but you shouldn't need to install all those nvidia pip packages manually. What happens if you just `pip install "jax[cuda12]"` in a fresh environment?
@dfm Thanks for your message. In a new environment, I ran `pip install --upgrade "jax[cuda12]"`. It looks like `pip install --upgrade "jax[cuda12]"` cannot install the correct versions of the nvidia packages. Actually, the reason I installed the nvidia packages myself is that I hit the same warning.
>>> import jax.numpy as jnp
>>> a = jnp.array(1.)
2024-06-19 10:47:17.191520: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
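The condition the warning describes can be sketched as a plain version comparison. This is a hedged illustration using the version strings quoted in the warning above, not XLA's actual implementation:

```python
def parse_cuda_version(v: str) -> tuple:
    """Parse a dotted CUDA version string like '12.5.40' into a tuple of ints."""
    return tuple(int(part) for part in v.split("."))

# Versions reported in the warning above.
driver_cuda = parse_cuda_version("12.4")    # the NVIDIA driver's CUDA version
ptxas_cuda = parse_cuda_version("12.5.40")  # ptxas shipped by nvidia-cuda-nvcc-cu12

# Per the warning, XLA disables parallel compilation when the driver
# is older than the ptxas release (comparing major.minor).
parallel_compilation_disabled = driver_cuda < ptxas_cuda[: len(driver_cuda)]
print(parallel_compilation_disabled)  # → True for driver 12.4 vs ptxas 12.5
```

This is why pinning `nvidia-cuda-nvcc-cu12` below 12.5 (so that ptxas matches the 12.4 driver) silences the warning.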
On the other hand, the original error did not appear.
>>> import jax.numpy as jnp
>>> c = jnp.array([[ 0., 0., 0., 1.],[-0., 0., -1., 0.],[ 0., 1., 0., 0.],[-1., 0., -0., 0.]])
>>> jnp.linalg.det(c)
Array(1., dtype=float32)
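For reference, the expected value can be double-checked on the CPU without JAX or a GPU at all. This is just a small pure-Python cofactor-expansion sketch applied to the same matrix, not how `jnp.linalg.det` computes it:

```python
def det(m):
    """Determinant by Laplace (cofactor) expansion along the first row."""
    if len(m) == 1:
        return m[0][0]
    return sum(
        (-1) ** j * m[0][j] * det([row[:j] + row[j + 1:] for row in m[1:]])
        for j in range(len(m))
    )

c = [[0.0, 0.0, 0.0, 1.0],
     [-0.0, 0.0, -1.0, 0.0],
     [0.0, 1.0, 0.0, 0.0],
     [-1.0, 0.0, -0.0, 0.0]]
print(det(c))  # → 1.0, matching the JAX result above
```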
Just in case, I also show my `conda list`.
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
bzip2 1.0.8 h5eee18b_6
ca-certificates 2024.3.11 h06a4308_0
jax 0.4.30 pypi_0 pypi
jax-cuda12-pjrt 0.4.30 pypi_0 pypi
jax-cuda12-plugin 0.4.30 pypi_0 pypi
jaxlib 0.4.30 pypi_0 pypi
ld_impl_linux-64 2.38 h1181459_1
libffi 3.4.4 h6a678d5_1
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
ml-dtypes 0.4.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
numpy 2.0.0 pypi_0 pypi
nvidia-cublas-cu12 12.5.2.13 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.5.39 pypi_0 pypi
nvidia-cuda-nvcc-cu12 12.5.40 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.5.39 pypi_0 pypi
nvidia-cudnn-cu12 9.1.1.17 pypi_0 pypi
nvidia-cufft-cu12 11.2.3.18 pypi_0 pypi
nvidia-cusolver-cu12 11.6.2.40 pypi_0 pypi
nvidia-cusparse-cu12 12.4.1.24 pypi_0 pypi
nvidia-nccl-cu12 2.21.5 pypi_0 pypi
nvidia-nvjitlink-cu12 12.5.40 pypi_0 pypi
openssl 3.0.14 h5eee18b_0
opt-einsum 3.3.0 pypi_0 pypi
pip 24.0 pypi_0 pypi
python 3.11.7 h955ad1f_0
readline 8.2 h5eee18b_0
scipy 1.13.1 pypi_0 pypi
setuptools 69.5.1 pypi_0 pypi
sqlite 3.45.3 h5eee18b_0
tk 8.6.14 h39e8969_0
tzdata 2024a h04d1e81_0
wheel 0.43.0 pypi_0 pypi
xz 5.4.6 h5eee18b_1
zlib 1.2.13 h5eee18b_1
P.S. Yesterday, JAX was updated, so the output of `jax.print_environment_info()` also changed.
>>> jax.print_environment_info()
jax: 0.4.30
jaxlib: 0.4.30
numpy: 2.0.0
python: 3.11.7 (main, Dec 15 2023, 18:12:31) [GCC 11.2.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='gnode05', release='3.10.0-1160.53.1.el7.x86_64', version='#1 SMP Fri Jan 14 13:59:45 UTC 2022', machine='x86_64')
$ nvidia-smi
Wed Jun 19 10:45:13 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100-PCIE-40GB Off | 00000000:2B:00.0 Off | 0 |
| N/A 28C P0 36W / 250W | 574MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA A100-PCIE-40GB Off | 00000000:A2:00.0 Off | 0 |
| N/A 27C P0 36W / 250W | 425MiB / 40960MiB | 3% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 14770 C python 416MiB |
| 0 N/A N/A 15535 G /usr/bin/X 108MiB |
| 0 N/A N/A 15637 G /usr/bin/gnome-shell 22MiB |
| 1 N/A N/A 14770 C python 416MiB |
+-----------------------------------------------------------------------------------------+
Thanks for the info. I can reproduce the warning that you're seeing about the ptxas version. I was able to work around this by simply downgrading the `nvidia-cuda-nvcc-cu12` pip package. So, from a fresh virtual environment, I was able to get a working installation with:
pip install "jax[cuda12]" "nvidia-cuda-nvcc-cu12<12.5"
Want to see if that fixes the issue for you?
Edited to add: I also wouldn't worry too much about the warning. It may make `jit` compilation a little bit slower, but I wouldn't expect it to be a major issue!
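If you want to confirm which `nvidia-cuda-nvcc-cu12` version actually ended up in the environment after the pinned install, a small helper using the standard library's `importlib.metadata` can report it (the helper function and its `None` fallback are my own convention, not part of JAX):

```python
from importlib.metadata import version, PackageNotFoundError

def installed_version(package: str):
    """Return the installed version string of a pip package, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

# After `pip install "jax[cuda12]" "nvidia-cuda-nvcc-cu12<12.5"`, this should
# report a 12.4.x version (or None if run in an environment without the package).
print(installed_version("nvidia-cuda-nvcc-cu12"))
```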
@dfm Thank you for your reply. The command you provided works in my environment, and my program seems to run correctly. I will keep checking and report back if anything comes up.
I'm going to close this as completed - please feel free to comment if there are other issues.
Description

I am trying to use JAX version 0.4.29 with CUDA 12.4. When I computed a simple linear algebra calculation, I got an error:

RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

Error

When I did the following, I found the above error. On the other hand, when I tried the following command, it works well.

System info (python version, jaxlib version, accelerator, etc.)