conda-forge / admin-requests


Mark cuda builds of jaxlib 0.4.30 and 0.4.31 as broken due to jax runtime check on cudnn version #1050

Closed: traversaro closed this 1 month ago

traversaro commented 1 month ago

jax==0.4.31 performs a runtime check requiring cudnn>=9. This requirement is not reflected in the package metadata, so trying to use CUDA with jax==0.4.31 results in an error like:

CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907
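The version numbers in this error are packed integers. Below is a minimal sketch of the comparison behind the message. It assumes the cuDNN 8-style encoding major*1000 + minor*100 + patch, which matches the log (8907 = 8.9.7, 9100 = 9.1.0). It is not JAX's actual implementation, and the helper names are hypothetical.

```python
def decode_cudnn_version(v: int) -> tuple[int, int, int]:
    """Split a packed version assumed to be major*1000 + minor*100 + patch."""
    major, rest = divmod(v, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


def cudnn_is_supported(installed: int, minimum: int) -> bool:
    """Hypothetical stand-in for the runtime check: installed must be >= minimum."""
    return installed >= minimum


# Values copied from the error message above.
BUILT_AGAINST = 8907
MIN_SUPPORTED = 9100
INSTALLED = 8907

print(decode_cudnn_version(INSTALLED))              # (8, 9, 7)
print(decode_cudnn_version(MIN_SUPPORTED))          # (9, 1, 0)
print(cudnn_is_supported(INSTALLED, MIN_SUPPORTED))  # False
```

Because jaxlib itself was built against 8907, no cudnn 8.* installation can ever pass this check.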

See https://github.com/conda-forge/jaxlib-feedstock/issues/277 and https://github.com/conda-forge/jax-feedstock/issues/149 .

Given their mutually constrained dependencies, the only combinations of jaxlib and jax that can be installed together are:

I think it makes sense to mark the CUDA builds of jaxlib==0.4.31 and jaxlib==0.4.30 as broken: they are constrained to cudnn 8.*, so they can't be used with jax==0.4.31.

This ensures that users who install jaxlib==*=*cuda* to get a CUDA-enabled jax actually end up with a working jax+CUDA setup. A plain conda install jax will still install a CPU-only version of jax. To actually fix that, we need to rebuild jaxlib with cudnn==9 instead; see https://github.com/conda-forge/conda-forge-pinning-feedstock/pull/6310 .
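Here is a rough sketch of what the request does from the solver's point of view. It assumes that marking a build as broken removes it from the set of candidates the solver sees, and it treats the *cuda* part of the match spec as a glob on the build string. The build strings and the `visible` helper are hypothetical; real conda-forge build strings differ.

```python
from fnmatch import fnmatchcase

# Hypothetical (name, version, build) candidates; build strings are made up.
CANDIDATES = [
    ("jaxlib", "0.4.31", "cuda120py312h_hypothetical_0"),
    ("jaxlib", "0.4.31", "cpu_py312h_hypothetical_0"),
    ("jaxlib", "0.4.30", "cuda120py312h_hypothetical_0"),
    ("jaxlib", "0.4.30", "cpu_py312h_hypothetical_0"),
]

# The builds this request asks to mark as broken: CUDA builds of 0.4.30 and 0.4.31.
BROKEN = {
    c for c in CANDIDATES
    if c[1] in {"0.4.30", "0.4.31"} and fnmatchcase(c[2], "*cuda*")
}


def visible(candidates, broken, build_glob="*"):
    """Candidates that are not broken and whose build string matches the glob."""
    return [c for c in candidates if c not in broken and fnmatchcase(c[2], build_glob)]


print(visible(CANDIDATES, BROKEN, "*cuda*"))  # []
print(len(visible(CANDIDATES, BROKEN)))        # 2
```

With the CUDA builds of these two versions hidden, a request for a *cuda* build can no longer resolve to a cudnn 8.* jaxlib at these versions. Unconstrained requests still see the CPU builds.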

For more details, see the discussion in https://github.com/conda-forge/jax-feedstock/issues/149#issuecomment-2303924496 .

As we discussed this with @ngam, it would be great to have his feedback before merging.

ping @conda-forge/jaxlib @conda-forge/jax

Checklist: