Closed sergiovaneg closed 1 week ago
Xref related issue: https://github.com/conda-forge/jaxlib-feedstock/issues/277 .
@traversaro please feel free to tag the conda-forge/jax and conda-forge/jaxlib teams in the future if you'd like quicker responses :)
I read both this issue and https://github.com/conda-forge/jaxlib-feedstock/issues/277. My understanding is that we never wanted to add cuda-related constraints in this jax feedstock, and I likely would've opted to merge https://github.com/conda-forge/jax-feedstock/pull/148 even if I knew ahead of time about these issues. I am not sure what the best course of action is. What's your recommendation?
My preference is that we instrument a fix in the jaxlib feedstock (I am maintainer there too). I can also work with the jax team to move the dependency from jax to jaxlib so that it is much clearer this is a backend (jaxlib) dependency and not a wrapper (jax) dependency.
Thoughts?
Now that the cudnn v9 PR is in, we should fix this and perhaps issue repodata patch soon
> @traversaro please feel free to tag the conda-forge/jax and conda-forge/jaxlib teams in the future if you'd like quicker responses :)
My bad, I kind of silently assumed that maintainers watch their repos, so I typically mention teams only on related issues in different repos, but indeed a "Mentioned" notification typically gets noticed more!
> I read both this issue and https://github.com/conda-forge/jaxlib-feedstock/issues/277. My understanding is that we never wanted to add cuda-related constraints in this jax feedstock, and I likely would've opted to merge https://github.com/conda-forge/jax-feedstock/pull/148 even if I knew ahead of time about these issues. I am not sure what the best course of action is. What's your recommendation?
>
> My preference is that we instrument a fix in the jaxlib feedstock (I am maintainer there too).
The problem is that there is no strict version constraint between jaxlib and jax (if I am not wrong), so how can you translate the constraint that a given jax version has into the corresponding jaxlib version? At the moment the check is on the jax side, so the most natural way to capture it would be to have a `run_constrained` constraint on the cudnn version, what do you think? What is the rationale for not having any cuda constraint in `jax`, if there is literally a cuda constraint check in the jax code?
> I can also work with the jax team to move the dependency from jax to jaxlib so that it is much clearer this is a backend (jaxlib) dependency and not a wrapper (jax) dependency.
I am not familiar enough with jax/jaxlib to provide an opinion on this, but if indeed the part that interacts with cudnn is jaxlib, it could make sense to have the check there.
> Now that the cudnn v9 PR is in, we should fix this and perhaps issue repodata patch soon
If https://github.com/conda-forge/conda-forge-pinning-feedstock/pull/6310 is merged soonish, we can wait for the migration PR, otherwise we could bump the cudnn version manually in a jaxlib PR.
As for the repodata, the only patch that I think could make sense is adding a `run_constrained` entry of `cudnn>=9.1.0` to jax 0.4.30 and 0.4.31. That would act as a proxy for "jaxlib that has been built with at least cudnn>=9.1.0", since jaxlib will anyhow have a run dependency on the version of cudnn against which it has been compiled.
> At the moment the check is on the jax side, so the most natural way to capture it would be to have a `run_constrained` constraint on the cudnn version, what do you think? What is the rationale for not having any cuda constraint in `jax`, if there is literally a cuda constraint check in the jax code?
Ah, I think I understood the problem here. If you had a `run_constrained` in `jax` for `cudnn`, you couldn't install it with a different cudnn, even if you are just interested in the cpu functionality.
I think `run_constrained` is potentially a nice workaround. If cudnn is requested to be installed alongside a given jax version, it will ensure that cudnn obeys the `run_constrained` condition. It won't impact cpu builds.
Btw, we do have two separate constraints between jaxlib and jax in conda-forge:
> I think `run_constrained` is potentially a nice workaround. If cudnn is requested to be installed alongside a given jax version, it will ensure that cudnn obeys the `run_constrained` condition. It won't impact cpu builds.
The scenario I was imagining is that someone could have an environment with jax 0.4.31, a jaxlib with just cpu support, and, let's say, pytorch with cuda support and cudnn==8. If we had the `run_constrained`, this environment would stop working, which may not be intuitive. However, being able to fix jax 0.4.31 + cuda may be more important, so it is always a tradeoff.
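To make the tradeoff concrete, here is a minimal sketch of how a run dependency differs from a `run_constrained` entry. It is a toy model and not conda's solver: the helper names, the `jaxlib` minimum, and the numeric-only version parsing are assumptions made for illustration.

```python
def parse(version):
    """Turn '9.1.0' into (9, 1, 0). Numeric parts only (illustrative assumption)."""
    return tuple(int(part) for part in version.split("."))


def violations(env, run, run_constrained):
    """Return a list of unmet requirements for one package.

    env: {name: version} of everything installed in the environment
    run: {name: minimum version}, packages that MUST be installed
    run_constrained: {name: minimum version}, checked ONLY if the package is installed
    """
    problems = []
    for name, minimum in run.items():
        if name not in env:
            problems.append(f"{name} missing")
        elif parse(env[name]) < parse(minimum):
            problems.append(f"{name} {env[name]} < {minimum}")
    for name, minimum in run_constrained.items():
        # A constrained package that is absent is fine: nothing to check.
        if name in env and parse(env[name]) < parse(minimum):
            problems.append(f"{name} {env[name]} < {minimum}")
    return problems


# Hypothetical metadata for jax 0.4.31 with the proposed constraint.
jax_run = {"jaxlib": "0.4.30"}            # hypothetical minimum
jax_run_constrained = {"cudnn": "9.1.0"}  # the constraint discussed above

cpu_only = {"jaxlib": "0.4.31"}
cpu_plus_old_cudnn = {"jaxlib": "0.4.31", "cudnn": "8.9.7.29"}  # e.g. pulled in by pytorch
gpu_new_cudnn = {"jaxlib": "0.4.31", "cudnn": "9.1.0"}

for env in (cpu_only, cpu_plus_old_cudnn, gpu_new_cudnn):
    print(env, violations(env, jax_run, jax_run_constrained))
```

In this model the cpu-only environment is unaffected, while the environment that has cudnn 8 only because of another package is rejected, which is the unintuitive case described above.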
> Btw, we do have two separate constraints between jaxlib and jax in conda-forge:
>
> * in jax, we set a minimum (usually declared upstream) and more recently we started setting up a maximum that is equal to the jax version. These are _run_ dependencies, [link](https://github.com/conda-forge/jax-feedstock/blob/0228c9aeac76e8fa0193e818255a49c846aa25ba/recipe/meta.yaml#L32)
> * in jaxlib, we set a run constraint on jax (which means, if jaxlib and jax are to be installed together, it must be obeyed), [link](https://github.com/conda-forge/jaxlib-feedstock/blob/3ddae8944376b1f0ececddac6d1298b1e790562d/recipe/meta.yaml#L89)
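As a rough sketch of how these two constraints combine, assuming the jaxlib-side constraint has the `jax >= {{version}}` form mentioned in this thread: the function name and the minimum jaxlib value below are hypothetical, and versions are compared as plain numeric tuples.

```python
def parse(version):
    """Turn '0.4.31' into (0, 4, 31). Numeric parts only (illustrative assumption)."""
    return tuple(int(part) for part in version.split("."))


def pair_allowed(jax, jaxlib, min_jaxlib):
    """Check whether a jax/jaxlib pair satisfies both feedstock constraints."""
    jax_v, jaxlib_v = parse(jax), parse(jaxlib)
    # jax side (run dependency): min_jaxlib <= jaxlib <= jax version
    jax_side = parse(min_jaxlib) <= jaxlib_v <= jax_v
    # jaxlib side (run_constrained): jax >= jaxlib version
    jaxlib_side = jax_v >= jaxlib_v
    return jax_side and jaxlib_side


# "0.4.30" is a hypothetical minimum, not taken from the recipes.
print(pair_allowed("0.4.31", "0.4.31", min_jaxlib="0.4.30"))  # same versions
print(pair_allowed("0.4.29", "0.4.31", min_jaxlib="0.4.30"))  # jax older than jaxlib
print(pair_allowed("0.4.31", "0.4.29", min_jaxlib="0.4.30"))  # jaxlib below the minimum
```

Under these assumptions, a given jaxlib build can never be paired with a jax older than itself.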
Cool, this makes the problem clearer. So the combinations affected by this that we need to worry about are:
Due to the `jax >= {{version}}` `run_constrained` constraint in jaxlib, those jaxlib builds can't be installed with any other jax. So at this point, is the easiest thing to mark those jaxlib cuda builds (0.4.31 and 0.4.30) as broken? They can only be used with jax==0.4.31, but jax requires cudnn >= 9.*, while the run constraint of those jaxlib builds contains `cudnn >=8.9.7.29,<9.0a0`.
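A quick way to see why no cudnn version can satisfy both sides is to intersect the two ranges. This is only a sketch: it treats `<9.0a0` as `<9.0`, uses `9.1.0` as the lower bound jax needs, and the candidate list is purely illustrative.

```python
def parse(version):
    """Turn '8.9.7.29' into (8, 9, 7, 29). Numeric parts only (illustrative assumption)."""
    return tuple(int(part) for part in version.split("."))


def in_range(version, lower=None, upper=None):
    """True if lower <= version < upper; either bound may be omitted."""
    v = parse(version)
    if lower is not None and v < parse(lower):
        return False
    if upper is not None and v >= parse(upper):
        return False
    return True


# Illustrative candidates, not a list of published cudnn builds.
candidates = ["8.9.7.29", "9.0.0", "9.1.0"]

ok_for_jax = [v for v in candidates if in_range(v, lower="9.1.0")]
ok_for_jaxlib = [v for v in candidates if in_range(v, lower="8.9.7.29", upper="9.0")]
ok_for_both = [v for v in ok_for_jax if v in ok_for_jaxlib]

print(ok_for_jax)
print(ok_for_jaxlib)
print(ok_for_both)
```

Because the lower bound required on the jax side is already above the upper bound on the jaxlib side, the intersection is empty for any candidate list.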
We do not even need to wait for the cudnn==9 builds to mark those builds as broken, since marking them as broken already fixes the jax+cuda installation for people that install the matchspec `jaxlib==*=*cuda*`. I can open a PR to start doing that, and we can continue the discussion there.
> I can open a PR to start doing that, we can continue the discussion there.
Done in https://github.com/conda-forge/admin-requests/pull/1050 . This will ensure that users that install `jaxlib==*=*cuda*` to get a cuda-enabled jax actually get a working jax+cuda. A plain `conda install jax` will still install a cpu-only version of jax; to actually fix that, we need to rebuild jaxlib with cudnn==9, see https://github.com/conda-forge/conda-forge-pinning-feedstock/pull/6310 .
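For readers unfamiliar with the `jaxlib==*=*cuda*` matchspec: the last field is a glob on the build string. The sketch below approximates that selection with Python's `fnmatch`, which is not conda's own matcher, and the build strings are made up for illustration.

```python
from fnmatch import fnmatchcase

# Hypothetical build strings; real jaxlib build strings may look different.
builds = [
    "cpu_py311h0000000_0",
    "cuda120py311h0000000_200",
]


def select_builds(build_strings, build_glob):
    """Keep only the build strings that match the glob pattern."""
    return [b for b in build_strings if fnmatchcase(b, build_glob)]


# '*cuda*' is the build-string part of the matchspec jaxlib==*=*cuda*
print(select_builds(builds, "*cuda*"))
```

Only builds whose build string contains `cuda` survive the filter, which is why marking the affected cuda builds as broken changes what this matchspec resolves to.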
@conda-forge-admin, please rerender feedstock
Hi! This is the friendly automated conda-forge-webservice.
I just wanted to let you know that I started rerendering the recipe in conda-forge/jax-feedstock#150.
Solution to issue cannot be found in the documentation.
Issue
Steps to reproduce
Expected result
The model runs on GPU.
Actual result
The model falls back to running on CPU and throws the following error:
Visual Proof
Severity
Medium: the code still runs, and this can be temporarily remediated by pinning the JAX version as `jax<=0.4.29`.
Suggested course of action
Pin cuDNN 9.1 as the minimum version for jax 0.4.31 and wait for the cuDNN conda-forge maintainer to update the module.
Installed packages
Environment info