conda-forge / jaxlib-feedstock

A conda-smithy repository for jaxlib.
BSD 3-Clause "New" or "Revised" License

CUDA not working with jax 0.4.31, as cudnn >= 9.1 is required but cudnn 8.9.7 is installed #277

Closed by traversaro 1 month ago

traversaro commented 2 months ago

Comment:

In jax==0.4.31, the runtime requirement for cudnn changed to cudnn>=9.1, but jaxlib in conda-forge does not carry this constraint, so the latest jax no longer works with CUDA:

Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907
The local installation version must be no lower than 9100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
conda list output:

~~~
(jax) traversaro@IITBMP014LW012:~$ conda list
# packages in environment at /home/traversaro/miniforge3/envs/jax:
#
# Name                                Version     Build                     Channel
_libgcc_mutex                         0.1         conda_forge               conda-forge
_openmp_mutex                         4.5         2_gnu                     conda-forge
_sysroot_linux-64_curr_repodata_hack  3           h69a702a_16               conda-forge
binutils_impl_linux-64                2.40        ha1999f0_7                conda-forge
binutils_linux-64                     2.40        hb3c18ed_0                conda-forge
bzip2                                 1.0.8       h4bc722e_7                conda-forge
c-ares                                1.32.3      h4bc722e_0                conda-forge
ca-certificates                       2024.7.4    hbcca054_0                conda-forge
cuda-cccl_linux-64                    12.5.39     ha770c72_0                conda-forge
cuda-crt-dev_linux-64                 12.5.82     ha770c72_0                conda-forge
cuda-crt-tools                        12.5.82     ha770c72_0                conda-forge
cuda-cudart                           12.5.82     he02047a_0                conda-forge
cuda-cudart-dev                       12.5.82     he02047a_0                conda-forge
cuda-cudart-dev_linux-64              12.5.82     h85509e4_0                conda-forge
cuda-cudart-static                    12.5.82     he02047a_0                conda-forge
cuda-cudart-static_linux-64           12.5.82     h85509e4_0                conda-forge
cuda-cudart_linux-64                  12.5.82     h85509e4_0                conda-forge
cuda-cupti                            12.5.82     he02047a_0                conda-forge
cuda-driver-dev_linux-64              12.5.82     h85509e4_0                conda-forge
cuda-nvcc                             12.5.82     hcdd1206_0                conda-forge
cuda-nvcc-dev_linux-64                12.5.82     ha770c72_0                conda-forge
cuda-nvcc-impl                        12.5.82     hd3aeb46_0                conda-forge
cuda-nvcc-tools                       12.5.82     hd3aeb46_0                conda-forge
cuda-nvcc_linux-64                    12.5.82     h8a487aa_0                conda-forge
cuda-nvrtc                            12.5.82     he02047a_0                conda-forge
cuda-nvtx                             12.5.82     he02047a_0                conda-forge
cuda-nvvm-dev_linux-64                12.5.82     ha770c72_0                conda-forge
cuda-nvvm-impl                        12.5.82     h59595ed_0                conda-forge
cuda-nvvm-tools                       12.5.82     h59595ed_0                conda-forge
cuda-version                          12.5        hd4f0392_3                conda-forge
cudnn                                 8.9.7.29    h092f7fd_3                conda-forge
gcc_impl_linux-64                     13.3.0      hfea6d02_0                conda-forge
gcc_linux-64                          13.3.0      hc28eda2_0                conda-forge
gxx_impl_linux-64                     13.3.0      hffce095_0                conda-forge
gxx_linux-64                          13.3.0      h6834431_0                conda-forge
importlib-metadata                    8.2.0       pyha770c72_0              conda-forge
importlib_metadata                    8.2.0       hd8ed1ab_0                conda-forge
jax                                   0.4.31      pyhd8ed1ab_0              conda-forge
jaxlib                                0.4.30      cuda120py312h4008524_200  conda-forge
kernel-headers_linux-64               3.10.0      h4a8ded7_16               conda-forge
ld_impl_linux-64                      2.40        hf3520f5_7                conda-forge
libabseil                             20240116.2  cxx17_he02047a_1          conda-forge
libblas                               3.9.0       23_linux64_openblas       conda-forge
libcblas                              3.9.0       23_linux64_openblas       conda-forge
libcublas                             12.5.3.2    he02047a_0                conda-forge
libcufft                              11.2.3.61   he02047a_0                conda-forge
libcurand                             10.3.6.82   he02047a_0                conda-forge
libcusolver                           11.6.3.83   he02047a_0                conda-forge
libcusparse                           12.5.1.3    he02047a_0                conda-forge
libexpat                              2.6.2       h59595ed_0                conda-forge
libffi                                3.4.2       h7f98852_5                conda-forge
libgcc-devel_linux-64                 13.3.0      h84ea5a7_100              conda-forge
libgcc-ng                             14.1.0      h77fa898_0                conda-forge
libgfortran-ng                        14.1.0      h69a702a_0                conda-forge
libgfortran5                          14.1.0      hc5f4f2c_0                conda-forge
libgomp                               14.1.0      h77fa898_0                conda-forge
libgrpc                               1.62.2      h15f2491_0                conda-forge
liblapack                             3.9.0       23_linux64_openblas       conda-forge
libnsl                                2.0.1       hd590300_0                conda-forge
libnvjitlink                          12.5.82     he02047a_0                conda-forge
libopenblas                           0.3.27      pthreads_hac2b453_1       conda-forge
libprotobuf                           4.25.3      h08a7969_0                conda-forge
libre2-11                             2023.09.01  h5a48ba9_2                conda-forge
libsanitizer                          13.3.0      heb74ff8_0                conda-forge
libsqlite                             3.46.0      hde9e2c9_0                conda-forge
libstdcxx-devel_linux-64              13.3.0      h84ea5a7_100              conda-forge
libstdcxx-ng                          14.1.0      hc0a3c3a_0                conda-forge
libuuid                               2.38.1      h0b41bf4_0                conda-forge
libxcrypt                             4.4.36      hd590300_1                conda-forge
libzlib                               1.3.1       h4ab18f5_1                conda-forge
ml_dtypes                             0.4.0       py312h1d6d2e6_1           conda-forge
nccl                                  2.22.3.1    hbc370b7_1                conda-forge
ncurses                               6.5         h59595ed_0                conda-forge
numpy                                 2.0.1       py312h1103770_0           conda-forge
openssl                               3.3.1       h4bc722e_2                conda-forge
opt-einsum                            3.3.0       hd8ed1ab_2                conda-forge
opt_einsum                            3.3.0       pyhc1e730c_2              conda-forge
pip                                   24.1.2      pyhd8ed1ab_0              conda-forge
python                                3.12.4      h194c7f8_0_cpython        conda-forge
python_abi                            3.12        4_cp312                   conda-forge
re2                                   2023.09.01  h7f4b329_2                conda-forge
readline                              8.2         h8228510_1                conda-forge
scipy                                 1.14.0      py312hc2bc53b_1           conda-forge
sysroot_linux-64                      2.17        h4a8ded7_16               conda-forge
tk                                    8.6.13      noxft_h4845f30_101        conda-forge
tzdata                                2024a       h0c530f3_0                conda-forge
xz                                    5.2.6       h166bdaf_0                conda-forge
zipp                                  3.19.2      pyhd8ed1ab_0              conda-forge
~~~
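The error message above boils down to comparing the installed cuDNN version number against the minimum that jax accepts. Here is a minimal sketch of that comparison, using the three integers exactly as they appear in the message. The function name `cudnn_issues` is hypothetical, and this is a simplified stand-in, not the actual check in jax's `xla_bridge.py`.

```python
def cudnn_issues(built_against: int, minimum: int, installed: int) -> list[str]:
    """Return human-readable problems, or an empty list if nothing is flagged.

    Hypothetical helper: it compares the raw integers from the error
    message and makes no attempt to decode them into major/minor/patch.
    """
    issues = []
    if installed < minimum:
        issues.append(
            f"Outdated cuDNN installation found: installed {installed}, "
            f"minimum supported {minimum} (built against {built_against})."
        )
    return issues


# Numbers from the report: jaxlib was built against 8907, jax 0.4.31 asks for 9100.
print(cudnn_issues(built_against=8907, minimum=9100, installed=8907))
```

With the numbers from this report, the list is non-empty, which corresponds to the CPU-only fallback seen in `jax.devices()`.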

Related lines of code:

Related jax PRs:

traversaro commented 2 months ago

For now, a simple workaround is to install an earlier version of jaxlib instead of using 0.4.30.

Originally reported in https://github.com/google/jax/issues/18032#issuecomment-2259383091 by @lucascolley.

traversaro commented 2 months ago

The problem is that there is no cudnn 9 on conda-forge, see .

Furthermore, there is something that I am missing, as it seems that the cudnn actually required is 8.9 in jaxlib 0.4.30, see https://github.com/google/jax/blob/jaxlib-v0.4.30/jax/_src/xla_bridge.py#L375 .

flferretti commented 2 months ago

The problem is that there is no cudnn 9 on conda-forge, see .

Furthermore, there is something that I am missing, as it seems that the cudnn actually required is 8.9 in jaxlib 0.4.30, see https://github.com/google/jax/blob/jaxlib-v0.4.30/jax/_src/xla_bridge.py#L375 .

Here it says 9100: https://github.com/google/jax/blob/35ba6f78bb23d11e59b10c15455f63997f4e7124/jax/_src/xla_bridge.py#L384 , as changed in https://github.com/google/jax/commit/d1c0d993fc97107d7e1cffa0c7bb8f3f2217095f

traversaro commented 2 months ago

Furthermore, there is something that I am missing, as it seems that the cudnn actually required is 8.9 in jaxlib 0.4.30, see https://github.com/google/jax/blob/jaxlib-v0.4.30/jax/_src/xla_bridge.py#L375 .

Ah, I got it: the check is actually in jax, so the problem is not using jaxlib 0.4.30 but using jax 0.4.31. The proper workaround is therefore to install a jax version earlier than or equal to 0.4.28 (as that is the latest version before 0.4.31 available in conda-forge).

traversaro commented 2 months ago

So the regression was in https://github.com/conda-forge/jax-feedstock/pull/148.

traversaro commented 2 months ago

So to solve this we need https://github.com/conda-forge/cudnn-feedstock/pull/83.

lucascolley commented 2 months ago

thanks for the quick action! much appreciated :)

traversaro commented 2 months ago

So to solve this we need conda-forge/cudnn-feedstock#83.

That PR was closed; the new PR that provides something similar is https://github.com/conda-forge/cudnn-feedstock/pull/84.

ngam commented 1 month ago

Fixed now

lucascolley commented 1 month ago

thanks all! Will test this out with SciPy

traversaro commented 1 month ago

Fixed now

Thanks a lot!

rgommers commented 1 month ago

This doesn't look fixed to me, since when @lucascolley tried to unpin jax in our SciPy dev environment, it yielded this combination of packages:

% pixi ls -e array-api-cuda --platform linux-64 | rg "cudnn|jax"
Environment: array-api-cuda
cudnn                                 8.9.7.29     h092f7fd_3                 446.6 MiB  conda  cudnn-8.9.7.29-h092f7fd_3.conda
jax                                   0.4.31       pyhd8ed1ab_0               1.3 MiB    conda  jax-0.4.31-pyhd8ed1ab_0.conda
jaxlib                                0.4.31       cuda120py312h4008524_200   89.2 MiB   conda  jaxlib-0.4.31-cuda120py312h4008524_200.conda

So the marking of jaxlib 0.4.31 as broken or the constraint of cuDNN>=9.1 isn't applied.
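The by-eye check on the listing above can be sketched in a few lines of Python. This assumes whitespace-separated rows in the order name, version, build, as in the output shown. The helper names (`parse_listing`, `version_tuple`, `has_cudnn_mismatch`) are hypothetical. The thresholds (jax 0.4.31, cudnn 9.1) are the ones discussed in this issue.

```python
def parse_listing(text: str) -> dict[str, tuple[str, str]]:
    """Map package name -> (version, build) from whitespace-separated rows."""
    packages = {}
    for line in text.splitlines():
        fields = line.split()
        # Skip headers and anything that is not at least "name version build".
        if len(fields) < 3 or not fields[1][0].isdigit():
            continue
        packages[fields[0]] = (fields[1], fields[2])
    return packages


def version_tuple(version: str) -> tuple[int, ...]:
    """Turn '8.9.7.29' into (8, 9, 7, 29) for numeric comparison."""
    return tuple(int(part) for part in version.split("."))


def has_cudnn_mismatch(packages: dict[str, tuple[str, str]]) -> bool:
    """True if jax >= 0.4.31 is paired with cudnn < 9.1."""
    if "jax" not in packages or "cudnn" not in packages:
        return False
    jax_version = version_tuple(packages["jax"][0])
    cudnn_version = version_tuple(packages["cudnn"][0])
    return jax_version >= (0, 4, 31) and cudnn_version < (9, 1)


# Rows copied from the listing above (size and file columns are ignored).
LISTING = """\
Environment: array-api-cuda
cudnn   8.9.7.29  h092f7fd_3                446.6 MiB  conda
jax     0.4.31    pyhd8ed1ab_0              1.3 MiB    conda
jaxlib  0.4.31    cuda120py312h4008524_200  89.2 MiB   conda
"""

print(has_cudnn_mismatch(parse_listing(LISTING)))
```

A check like this only looks at the resolved versions. It says nothing about why the solver picked them, which is the open question here.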

ngam commented 1 month ago

linux-64/jaxlib-0.4.31-cuda120py312h4008524_200.conda was marked broken in https://github.com/conda-forge/admin-requests/pull/1050 (see). The SciPy dev env is picking up packages marked as broken, which seems more like a problem on the SciPy side than here. I recommend an audit there.

On a machine with 4 A100s, micromamba create -n test_jaxs jax resolves to:

  + cudnn                                   9.2.1.18  hbc370b7_0                conda-forge     Cached
  + jaxlib                                    0.4.31  cuda120py312h3995614_201  conda-forge     Cached
  + jax                                       0.4.31  pyhd8ed1ab_1              conda-forge     Cached

On the same machine, micromamba create -n test_jaxs jax==0.4.31=*pyhd8ed1ab_0* (which is the jax build quoted above) resolves to:

  + cudnn                                   9.2.1.18  hbc370b7_0                conda-forge     Cached
  + jaxlib                                    0.4.31  cuda120py312h3995614_201  conda-forge     Cached
  + jax                                       0.4.31  pyhd8ed1ab_0              conda-forge        1MB

After creating either of the two envs above, I can verify that the error quoted in the main post is now resolved:

Python 3.12.5 | packaged by conda-forge | (main, Aug  8 2024, 18:36:51) [GCC 12.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]
>>>

rgommers commented 1 month ago

which seems more like a problem on the SciPy side than here

It's very well possible that it's a pixi problem (not a SciPy problem) - that's being discussed at https://github.com/rgommers/pixi-dev-scipystack/pull/5.

However, this isn't quite convincing:

... micromamba create -n test_jaxs jax resolves

That's too easy, since the latest jax and the latest cudnn are compatible, so it's obviously going to be the first thing the solver tries. The problematic environment uses CUDA builds of both jax and pytorch, and that's when you'll see if the issue is resolved or not since they have incompatible cudnn constraints. Could you try on that same machine to solve again, including pytorch (I don't have a Linux machine at hand, sorry)? I think you should get latest pytorch, and the last jax version (0.4.28) that supports the same cudnn version that pytorch supports.

traversaro commented 1 month ago

I think you should get latest pytorch, and the last jax version (0.4.28) that supports the same cudnn version that pytorch supports.

I am afraid that this will not happen, as the solver is not aware that you want cuda versions of pytorch and jax, so it may install the latest cpu version of either pytorch or jax to avoid the cudnn conflict. If you want pytorch and jax with cuda support, I am afraid it is necessary to add the appropriate constraint on the build string.
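A small sketch of what such a build-string constraint could look like. It assumes that CUDA builds carry `cuda` in their build string, as the jaxlib build `cuda120py312h4008524_200` in this thread does. That is a heuristic, not a guarantee for every package. The helper names are hypothetical.

```python
def is_cuda_build(build: str) -> bool:
    """Heuristic: treat a build string as CUDA-enabled if it contains 'cuda'."""
    return "cuda" in build.lower()


def cuda_match_spec(name: str) -> str:
    """Build a 'name=version=build' match spec that only accepts CUDA builds.

    Assumption: the package's CUDA variants have 'cuda' in the build string.
    """
    return f"{name}=*=*cuda*"


# Specs that could be handed to the solver instead of the bare package names.
for package in ("jaxlib", "pytorch"):
    print(cuda_match_spec(package))
```

With specs like these, the solver cannot fall back to a CPU build to avoid the cudnn conflict. It has to either find a compatible pair of CUDA builds or fail to solve.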

ngam commented 1 month ago

I think you should get latest pytorch, and the last jax version (0.4.28) that supports the same cudnn version that pytorch supports.

I am afraid that this will not happen, as the solver is not aware that you want cuda versions of pytorch and jax, so it may install the latest cpu version of either pytorch or jax to avoid the cudnn conflict. If you want pytorch and jax with cuda support, I am afraid it is necessary to add the appropriate constraint on the build string.

@traversaro is correct. On the same machine, micromamba create -n test_jaxs jax pytorch resolves jax, jaxlib, and cudnn like above, but pytorch 2.4.0 as the cpu_mkl version, build 1.

Again, the two larger points here are:

  1. The fix here addressed the issue that was reported, namely that we should have had proper constraints on cudnn 9. That has all been addressed to the best of our ability and should work as expected from the point of view of jax and jaxlib.
  2. There's an outstanding issue elsewhere that led to weird env resolution in the scipy devstack repo (e.g., @traversaro points out deleting the lock file can potentially help). Nonetheless, please don't expect to get pytorch and jaxlib built with cudnn9 working right away (a PR is pending in the pytorch feedstock).

rgommers commented 1 month ago

Thanks for confirming @ngam!

ngam commented 1 month ago

Taking back some of the above comment. There was an issue in labeling the packages as broken. This is being fixed in https://github.com/conda-forge/admin-requests/pull/1065. Sorry about the inconvenience!