Dependency mismatch: Outdated cuDNN installation makes CUDA backend fail to initialize #149

Closed: sergiovaneg closed this issue 1 week ago

sergiovaneg commented 1 month ago

Solution to issue cannot be found in the documentation.


Steps to reproduce

  1. Install the latest JAX and Keras versions with GPU support from conda-forge
  2. Instantiate a Keras model in Python making sure to use JAX as the backend
  3. Call the model (or use the fit/predict methods) so the execution graph gets compiled

Expected result

The model runs on GPU.

Actual result

The model falls back to running on CPU and throws the following error:

```
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907
The local installation version must be no lower than 9100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
```
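The log compares packed integers rather than dotted versions. If they use the cuDNN 8 packing (major*1000 + minor*100 + patch), which is an assumption here, 8907 reads as 8.9.7 and 9100 as 9.1.0, consistent with the cuDNN 9.1 minimum suggested in this issue. A minimal decoding sketch (the function names are hypothetical, not part of JAX or cuDNN):

```python
def decode_cudnn_version(packed: int) -> tuple[int, int, int]:
    """Split a packed cuDNN version into (major, minor, patch).

    Assumes the cuDNN 8-style packing major*1000 + minor*100 + patch.
    """
    major, rest = divmod(packed, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


def meets_minimum(installed: int, minimum: int) -> bool:
    # Under this packing (minor < 10, patch < 100), integer order equals
    # version order, so a plain comparison mirrors the check the log describes.
    return installed >= minimum


installed, minimum = 8907, 9100  # values copied from the log above
print(decode_cudnn_version(installed))    # (8, 9, 7)
print(decode_cudnn_version(minimum))      # (9, 1, 0)
print(meets_minimum(installed, minimum))  # False
```

Newer cuDNN headers may pack versions differently, so treat this only as a reading aid for this particular log.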


Severity: medium. The code still runs (on the CPU), and the problem can be worked around temporarily by pinning jax<=0.4.29.

Suggested course of action

Pin cuDNN 9.1 as the minimum version for jax 0.4.31, and wait for the conda-forge cuDNN maintainers to update the package.

Environment info

```
     active environment : keras_jax
    active env location : /home/sergiovaneg/miniforge3/envs/keras_jax
            shell level : 2
       user config file : /home/sergiovaneg/.condarc
 populated config files : /home/sergiovaneg/miniforge3/.condarc
          conda version : 24.7.1
    conda-build version : not installed
         python version :
                 solver : libmamba (default)
       virtual packages : __archspec=1=skylake
       base environment : /home/sergiovaneg/miniforge3  (writable)
      conda av data dir : /home/sergiovaneg/miniforge3/etc/conda
  conda av metadata url : None
           channel URLs :
          package cache : /home/sergiovaneg/miniforge3/pkgs
       envs directories : /home/sergiovaneg/miniforge3/envs
               platform : linux-64
             user-agent : conda/24.7.1 requests/2.32.3 CPython/3.12.3 Linux/ opensuse-tumbleweed/20240806 glibc/2.39 solver/libmamba conda-libmamba-solver/24.7.0 libmambapy/1.5.8
                UID:GID : 1000:1000
             netrc file : None
           offline mode : False
```
traversaro commented 1 month ago

Xref related issue: .

ngam commented 3 weeks ago

@traversaro please feel free to tag the conda-forge/jax and conda-forge/jaxlib teams in the future if you'd like quicker responses :)

I read both this issue and the linked one. My understanding is that we never wanted to add CUDA-related constraints in this jax feedstock, and I likely would've opted to merge even if I had known about these issues ahead of time. I am not sure what the best course of action is. What's your recommendation?

My preference is that we instrument a fix in the jaxlib feedstock (I am maintainer there too). I can also work with the jax team to move the dependency from jax to jaxlib so that it is much clearer this is a backend (jaxlib) dependency and not a wrapper (jax) dependency.


Now that the cudnn v9 PR is in, we should fix this and perhaps issue a repodata patch soon.

traversaro commented 3 weeks ago

> @traversaro please feel free to tag the conda-forge/jax and conda-forge/jaxlib teams in the future if you'd like quicker responses :)

My bad, I kind of silently assumed that maintainers watch their repos, so I typically mention teams only on related issues in other repos, but indeed a "Mentioned" notification typically gets noticed more!

> I read both this issue and the linked one. My understanding is that we never wanted to add cuda-related constraints in this jax feedstock, and I likely would've opted to merge even if I knew ahead of time about these issues. I am not sure what the best course of action is. What's your recommendation?

> My preference is that we instrument a fix in the jaxlib feedstock (I am maintainer there too).

The problem is that there is no strict version constraint between jaxlib and jax (if I am not wrong), so how can you translate a constraint that a given jax version has into the corresponding jaxlib version? At the moment the check is on the jax side, so the most natural way to capture it would be a run_constrained constraint on the cudnn version. What do you think? What is the rationale for not having any CUDA constraint in jax, if there is literally a CUDA constraint check in the jax code?

> I can also work with the jax team to move the dependency from jax to jaxlib so that it is much clearer this is a backend (jaxlib) dependency and not a wrapper (jax) dependency.

I am not familiar enough with jax/jaxlib to provide an opinion on this, but if indeed the part that interacts with cudnn is jaxlib, it could make sense to have the check there.

> Now that the cudnn v9 PR is in, we should fix this and perhaps issue a repodata patch soon

If the linked PR is merged soonish, we can wait for the migration PR; otherwise we could bump the cudnn version manually in a jaxlib PR.

As for the repodata, the only thing I think it would make sense to patch is adding a run_constrained cudnn>=9.1.0 to jax 0.4.30 and 0.4.31. That would act as a proxy for "jaxlib built against at least cudnn>=9.1.0", since jaxlib anyway has a run dependency on the cudnn version it was compiled against.
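The proposed patch amounts to appending one entry to the `constrains` list of the affected jax records. A minimal sketch, assuming records shaped like entries in conda's repodata.json (`name`, `version`, `constrains`); `patch_record` is a hypothetical helper, not conda-forge's actual patching code:

```python
import copy

# Affected versions and constraint string as proposed in this thread
AFFECTED_JAX_VERSIONS = {"0.4.30", "0.4.31"}
CUDNN_CONSTRAINT = "cudnn >=9.1.0"


def patch_record(record: dict) -> dict:
    """Return a jax repodata record with the cudnn run_constrained entry added.

    Records for other packages or versions are returned unchanged; patched
    records are deep-copied so the input is never mutated.
    """
    if record.get("name") != "jax" or record.get("version") not in AFFECTED_JAX_VERSIONS:
        return record
    patched = copy.deepcopy(record)
    constrains = patched.setdefault("constrains", [])
    if CUDNN_CONSTRAINT not in constrains:  # keep the patch idempotent
        constrains.append(CUDNN_CONSTRAINT)
    return patched


example = {"name": "jax", "version": "0.4.31", "constrains": []}
print(patch_record(example)["constrains"])  # ['cudnn >=9.1.0']
```

Because jaxlib already depends on the cudnn it was built against, this jax-side entry effectively rules out cuDNN 8 jaxlib builds whenever cudnn is present in the environment.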

traversaro commented 3 weeks ago

> At the moment the check is on the jax side, so the most natural way to capture it would be to have a run_constrained constraint on the cudnn version, what do you think? What is the rationale for not having any cuda constraint in jax, if there is literally a cuda constraint check in the jax code?

Ah, I think I understand the problem now. If jax had a run_constrained on cudnn, you couldn't install it alongside a different cudnn, even if you were only interested in the CPU functionality.

ngam commented 3 weeks ago

I think run_constrained is potentially a nice workaround. If cudnn is requested to be installed alongside a given jax version, the run_constrained entry ensures that it obeys the condition. It won't impact CPU builds.

Btw, we do have two separate constraints between jaxlib and jax in conda-forge:

traversaro commented 3 weeks ago

> I think run_constrained is potentially a nice workaround. If with a given jax version cudnn is requested to be installed, it will ensure it obeys the run_constrained condition. It won't impact cpu builds.

The scenario I was imagining is that someone could have an environment with jax 0.4.31 and a CPU-only jaxlib, plus, let's say, pytorch with CUDA support and cudnn==8. If we had the run_constrained, this environment would stop working, which may not be intuitive. However, being able to fix jax 0.4.31 + CUDA may be more important, so it is always a tradeoff.
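The tradeoff hinges on run_constrained semantics: the entry never pulls cudnn into an environment, it only restricts cudnn's version when cudnn is already there. A small sketch of that check, assuming plain dotted versions and only `>=` constraints (all names are hypothetical, and this is not how the conda solver is implemented):

```python
from dataclasses import dataclass


def parse_version(version: str) -> tuple[int, ...]:
    """Turn "9.1.0" into (9, 1, 0); enough for the plain versions used here."""
    return tuple(int(part) for part in version.split("."))


@dataclass(frozen=True)
class MinConstraint:
    """A run_constrained entry of the form "<package> >=<min_version>"."""
    package: str
    min_version: str


def violations(env: dict[str, str], constrains: list[MinConstraint]) -> list[str]:
    """Check an environment (package -> version) against run_constrained entries.

    Unlike a run dependency, a run_constrained entry never adds a package:
    it only restricts the version if that package is already installed.
    """
    problems = []
    for c in constrains:
        installed = env.get(c.package)
        if installed is None:
            continue  # package absent: nothing to enforce
        if parse_version(installed) < parse_version(c.min_version):
            problems.append(f"{c.package} {installed} violates >={c.min_version}")
    return problems


# The jax-side constraint proposed in this thread
JAX_CONSTRAINS = [MinConstraint("cudnn", "9.1.0")]

cpu_only = {"jax": "0.4.31", "jaxlib": "0.4.31"}
# CPU-only jaxlib next to a CUDA-enabled PyTorch that pulled in cuDNN 8
mixed = {"jax": "0.4.31", "jaxlib": "0.4.31", "cudnn": "8.9.7"}

print(violations(cpu_only, JAX_CONSTRAINS))  # []
print(violations(mixed, JAX_CONSTRAINS))     # ['cudnn 8.9.7 violates >=9.1.0']
```

The CPU-only environment is unaffected, while the mixed one can no longer be solved, which is exactly the tradeoff being weighed.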

traversaro commented 3 weeks ago

> Btw, we do have two separate constraints between jaxlib and jax in conda-forge:
>
> * in jax, we set a minimum (usually declared upstream) and more recently we started setting up a maximum that is equal to the jax version. These are _run_ dependencies, [link](
>
> * in jaxlib, we set a run constraint on jax (which means, if jaxlib and jax are to be installed together, it must be obeyed), [link](

Cool, this makes the problem clearer. The combinations affected by this that we need to worry about are:

Due to the jax >= {{version}} run_constrained constraint in jaxlib, those jaxlib builds can't be installed with any other jax. So at this point, isn't the easiest thing to mark those jaxlib CUDA builds (0.4.31 and 0.4.30) as broken? They can only be used with jax==0.4.31, but jax requires cudnn >= 9.*, while the run constraint of those jaxlib builds is cudnn >=,<9.0a0.
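The reasoning above can be checked mechanically. This sketch models only two rules from the thread: jaxlib's run_constrained jax >= {{version}}, and the jax-side cuDNN check (cudnn >= 9.1.0 for jax 0.4.30 and 0.4.31). It assumes every CUDA jaxlib build depends on cuDNN 8.9.7, as in the log, and ignores the lower bound jax places on jaxlib; `JAX_MIN_CUDNN` and `working_jax_partners` are hypothetical names.

```python
def parse_version(version: str) -> tuple[int, ...]:
    return tuple(int(part) for part in version.split("."))


# Minimum cuDNN accepted by each jax version's runtime check, per this thread.
# None models "works with the cuDNN 8 builds" (pinning jax<=0.4.29 is the workaround).
JAX_MIN_CUDNN = {
    "0.4.29": None,
    "0.4.30": "9.1.0",
    "0.4.31": "9.1.0",
}


def working_jax_partners(jaxlib_version: str, jaxlib_cudnn: str) -> list[str]:
    """jax versions a CUDA jaxlib build could actually run with on the GPU.

    A partner must satisfy jaxlib's run_constrained (jax >= jaxlib version),
    and its cuDNN check must accept the cudnn this jaxlib build depends on.
    """
    partners = []
    for jax_version, min_cudnn in JAX_MIN_CUDNN.items():
        if parse_version(jax_version) < parse_version(jaxlib_version):
            continue  # excluded by jaxlib's run_constrained on jax
        if min_cudnn is not None and parse_version(jaxlib_cudnn) < parse_version(min_cudnn):
            continue  # jax would reject the CUDA backend and fall back to CPU
        partners.append(jax_version)
    return partners


for jaxlib_version in ("0.4.29", "0.4.30", "0.4.31"):
    print(jaxlib_version, working_jax_partners(jaxlib_version, "8.9.7"))
```

An empty list for 0.4.30 and 0.4.31 is the "mark as broken" case: no jax version both satisfies jaxlib's constraint and accepts the cuDNN those builds depend on.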

We don't even need to wait for the cudnn==9 builds to mark those builds as broken: marking them as broken already fixes the jax+CUDA installation, at least for people who install with the matchspec jaxlib==*=*cuda*. I can open a PR to start doing that, and we can continue the discussion there.

traversaro commented 3 weeks ago

> I can open a PR to start doing that, we can continue the discussion there.

Done in . This ensures that users who install jaxlib==*=*cuda* to get a CUDA-enabled jax actually get a working jax+CUDA. A plain conda install jax will still install a CPU-only jax; to actually fix that, we instead need to rebuild jaxlib with cudnn==9, see .
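Why the build-string matchspec helps once the broken builds are removed can be sketched as a toy selection: filter by build-string glob, drop broken builds, take the newest version. The build strings are made up (only the `cuda` substring matters), and real solvers also weigh build numbers, track_features and virtual packages such as `__cuda`, so this is a simplification, not the solver's logic.

```python
from fnmatch import fnmatch
from typing import Optional


def parse_version(version: str) -> tuple[int, ...]:
    return tuple(int(part) for part in version.split("."))


# Hypothetical channel contents: versions from this thread, build strings invented
CANDIDATES = [
    ("0.4.29", "cuda_build"),
    ("0.4.29", "cpu_build"),
    ("0.4.30", "cuda_build"),
    ("0.4.30", "cpu_build"),
    ("0.4.31", "cuda_build"),
    ("0.4.31", "cpu_build"),
]

# Builds moved to the "broken" label are no longer offered to the solver
BROKEN = {("0.4.30", "cuda_build"), ("0.4.31", "cuda_build")}


def pick_jaxlib(build_glob: str) -> Optional[tuple[str, str]]:
    """Newest non-broken (version, build) whose build string matches build_glob."""
    pool = [c for c in CANDIDATES if fnmatch(c[1], build_glob) and c not in BROKEN]
    return max(pool, key=lambda c: parse_version(c[0]), default=None)


print(pick_jaxlib("*cuda*"))  # ('0.4.29', 'cuda_build'): older, but a working CUDA jaxlib
print(pick_jaxlib("*"))       # ('0.4.31', 'cpu_build'): a plain install ends up CPU-only
```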

ngam commented 1 week ago

@conda-forge-admin, please rerender feedstock

conda-forge-webservices[bot] commented 1 week ago

Hi! This is the friendly automated conda-forge-webservice.

I just wanted to let you know that I started rerendering the recipe in conda-forge/jax-feedstock#150.