`jax==0.4.31` has a runtime check (not reflected in the package metadata) that requires `cudnn>=9`, so trying to use CUDA with `jax==0.4.31` results in an error like:
```
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907
```
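The numbers in this error are easier to read once decoded. Here is a minimal Python sketch of the comparison the message reports. It is not JAX's actual code: the helper names `decode_cudnn_version` and `check_cudnn` are hypothetical, and the sketch assumes the integers use the cuDNN 8-style encoding `MAJOR * 1000 + MINOR * 100 + PATCH`. Under that encoding, `8907` is cuDNN 8.9.7 and `9100` would be 9.1.0.

```python
# Hypothetical sketch of the check behind the error above (NOT JAX's code).
# Assumes the cuDNN 8-style encoding MAJOR * 1000 + MINOR * 100 + PATCH.


def decode_cudnn_version(encoded: int) -> tuple:
    """Split an encoded cuDNN version into (major, minor, patch)."""
    major, rest = divmod(encoded, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


def check_cudnn(installed: int, minimum: int) -> None:
    """Raise if the installed cuDNN is older than the minimum supported one."""
    if installed < minimum:
        raise RuntimeError(
            "Outdated cuDNN installation found.\n"
            f"Minimum supported: {minimum}\n"
            f"Installed version: {installed}"
        )


print(decode_cudnn_version(8907))  # (8, 9, 7)
print(decode_cudnn_version(9100))  # (9, 1, 0)
try:
    check_cudnn(installed=8907, minimum=9100)
except RuntimeError as err:
    print(err)
```

Because the `cudnn 8.*` pin caps the installed value below `9000`, no cuDNN that the solver can pick for these builds passes the check.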
Given their mutual dependency constraints, the only combinations of `jaxlib` and `jax` that can be installed together are:

- `jax==0.4.31` + `jaxlib==0.4.31`
- `jax==0.4.31` + `jaxlib==0.4.30`
The CUDA builds of `jaxlib==0.4.31` and `jaxlib==0.4.30` are constrained to `cudnn 8.*`, so they can't be used with `jax==0.4.31`. For that reason, I think it makes sense to mark them as broken.

This ensures that users who install `jaxlib==*=*cuda*` to get a CUDA-enabled jax actually get a working jax+CUDA setup. A plain `conda install jax` will still install a CPU-only jax. Actually fixing that requires rebuilding jaxlib with `cudnn==9` instead (see https://github.com/conda-forge/conda-forge-pinning-feedstock/pull/6310).
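A toy Python model can show how marking the CUDA builds as broken affects a `jaxlib==*=*cuda*` request. Everything in it is an assumption for illustration. The records and build strings are hypothetical placeholders, and conda's build-string matching is approximated with `fnmatch.fnmatchcase`. "Marking broken" is modelled as moving a build from the `main` label to the `broken` label. This is not how the conda solver is implemented.

```python
from fnmatch import fnmatchcase

# Hypothetical (version, build string, label) records. The build strings are
# illustrative placeholders, not real conda-forge jaxlib builds.
RECORDS = [
    ("0.4.31", "cuda120py312habc1234_200", "main"),
    ("0.4.31", "cpu_py312habc1234_0", "main"),
    ("0.4.30", "cuda120py312habc1234_200", "main"),
    ("0.4.30", "cpu_py312habc1234_0", "main"),
]


def mark_broken(records, versions, build_glob):
    """Move records with a listed version and a matching build to 'broken'."""
    return [
        (version, build, "broken")
        if version in versions and fnmatchcase(build, build_glob)
        else (version, build, label)
        for version, build, label in records
    ]


def query(records, build_glob):
    """Return (version, build) pairs on the 'main' label matching build_glob."""
    return [
        (version, build)
        for version, build, label in records
        if label == "main" and fnmatchcase(build, build_glob)
    ]


patched = mark_broken(RECORDS, {"0.4.30", "0.4.31"}, "*cuda*")
print(query(RECORDS, "*cuda*"))  # both CUDA builds are candidates
print(query(patched, "*cuda*"))  # []
print(query(patched, "*"))       # only the CPU builds remain
```

In this toy model, once the 0.4.30/0.4.31 CUDA builds are marked broken, the CUDA query no longer returns them. A real solver would then have to pick other (older) CUDA builds, if any exist. A query that matches any build still finds the CPU-only builds, which is why `conda install jax` keeps resolving to a CPU-only jax.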
See https://github.com/conda-forge/jaxlib-feedstock/issues/277 and https://github.com/conda-forge/jax-feedstock/issues/149. For more details, see the discussion in https://github.com/conda-forge/jax-feedstock/issues/149#issuecomment-2303924496.
Since I discussed this with @ngam, it would be great to have his feedback before merging.
ping @conda-forge/jaxlib @conda-forge/jax
Checklist: