cemlyn007 closed this issue 1 year ago.
Hey! When I try to remove jaxlib, I get conflicts with jax; see e.g. PR #158. Would you have any clues on how to fix this?
Hi @aar65537! When working on your recent PR (#160), have you tried removing jaxlib?
Good evening, it might be that one of Jumanji's dependencies unfortunately specifies jaxlib.
The issue is that jax only lists jaxlib as a dependency when you include one of the platform extras. So pip install jax[platform] will install jaxlib, but pip install jax won't. In this case, the wrong version of jaxlib is still installed by chex, which causes the error. I think you could fix the error by changing the jax constraint to jax[cpu]>=0.2.26,<=0.4.10.
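To make the extras behaviour concrete, here is a minimal sketch of how pip treats extra-gated requirements: entries without an `extra == "..."` marker are always installed, while extra-gated entries are only installed when that extra is requested. The `SAMPLE_JAX_REQUIRES` list and the `requirements_for` helper are hypothetical and only mimic the shape of jax's metadata; the parser is deliberately simplified and treats any marker that is not a single `extra == "..."` clause as unconditional.

```python
import re

# Finds the extra name in a marker such as: extra == "cpu"
_EXTRA_RE = re.compile(r'extra\s*==\s*["\']([^"\']+)["\']')

# Hypothetical, illustrative Requires-Dist entries (not jax's real metadata).
SAMPLE_JAX_REQUIRES = [
    "numpy>=1.21",
    "opt_einsum",
    "scipy>=1.7",
    'jaxlib==0.4.10; extra == "cpu"',
]


def requirements_for(requires_dist, extras=()):
    """Return the project names pip would pull in for the requested extras.

    Simplification: markers other than a single extra == "..." clause are
    ignored, so such entries are treated as unconditional.
    """
    selected = []
    for entry in requires_dist:
        spec, _, marker = entry.partition(";")
        match = _EXTRA_RE.search(marker)
        # Keep unconditional entries, and extra-gated ones only if requested.
        if match is None or match.group(1) in extras:
            # The project name ends at the first space, bracket, or operator.
            name = re.split(r"[\s\[<>=!~;]", spec.strip(), maxsplit=1)[0]
            selected.append(name)
    return selected
```

With this sample, `requirements_for(SAMPLE_JAX_REQUIRES)` leaves jaxlib out, while `requirements_for(SAMPLE_JAX_REQUIRES, ("cpu",))` includes it. That mirrors why plain `pip install jax` relies on some other package (here chex) to bring in jaxlib. The real entries of an installed jax can be listed with `importlib.metadata.requires("jax")`.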
Thanks @aar65537, so to fix the CI you just need to do this:
pip install .[dev,train] jax[cpu]
I don't think you should change the requirements to specify a jax extra.
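After such an install, one way to check which jaxlib actually ended up in the environment, and whether chex is the package asking for it, is to read the installed package metadata with the standard library's `importlib.metadata`. The two helpers below (`installed_versions`, `requirements_naming`) are a hypothetical diagnostic sketch, not part of Jumanji or pip.

```python
import re
from importlib import metadata


def installed_versions(names=("jax", "jaxlib", "chex")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def requirements_naming(dist, target):
    """Return the Requires-Dist entries of `dist` whose project name is `target`."""
    try:
        requires = metadata.requires(dist) or []
    except metadata.PackageNotFoundError:
        # The distribution is not installed, so it requires nothing here.
        return []
    matches = []
    for req in requires:
        # The project name ends at the first space, bracket, paren, or operator.
        name = re.split(r"[\s\[(<>=!~;]", req.strip(), maxsplit=1)[0]
        if name.lower() == target.lower():
            matches.append(req)
    return matches
```

For example, printing `installed_versions()` alongside `requirements_naming("chex", "jaxlib")` shows the jax/jaxlib pair that was resolved and the exact jaxlib constraint chex declares, if it declares one.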
Resolved in #174
You do not need to explicitly state the version of jaxlib, because pip install jax[<device>] will automatically install the correct jaxlib. This is the culprit line: https://github.com/instadeepai/jumanji/blob/b16cf5dcde88a73b0c8f56f93b97813105cb99ec/requirements/requirements.txt#L5
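To catch an explicit jaxlib requirement like this before CI does, a small scan could flag any line of a requirements file that names jaxlib directly. `find_jaxlib_pins` is a hypothetical helper, and the sample lines in the usage note are made up for illustration, not copied from the linked file. It also does not follow `-r` includes.

```python
import re

# A requirement line whose project name is exactly "jaxlib" (case-insensitive),
# followed by end of line, whitespace, an extras bracket, a version operator, or a marker.
_JAXLIB_RE = re.compile(r"^\s*jaxlib(?=$|[\s\[<>=!~;])", re.IGNORECASE)


def find_jaxlib_pins(lines):
    """Return (line_number, text) for each requirement line that names jaxlib."""
    hits = []
    for number, raw in enumerate(lines, start=1):
        text = raw.split("#", 1)[0].rstrip()  # drop trailing comments
        if _JAXLIB_RE.match(text):
            hits.append((number, text))
    return hits
```

For instance, `find_jaxlib_pins(["chex>=0.1.3", "jax>=0.2.26", "jaxlib>=0.1.74  # pinned"])` reports line 3. Against a checkout, the file could be read with `open("requirements/requirements.txt").read().splitlines()`.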