instadeepai / jumanji

🕹️ A diverse suite of scalable reinforcement learning environments in JAX
https://instadeepai.github.io/jumanji
Apache License 2.0

build: remove `jaxlib` #148

Closed. cemlyn007 closed this issue 1 year ago.

cemlyn007 commented 1 year ago

You do not need to pin the version of `jaxlib` explicitly, because `pip install jax[<device>]` automatically installs the matching `jaxlib`. This is the culprit line: https://github.com/instadeepai/jumanji/blob/b16cf5dcde88a73b0c8f56f93b97813105cb99ec/requirements/requirements.txt#L5

clement-bonnet commented 1 year ago

Hey! When I try to remove `jaxlib`, I get conflicts with `jax`; see, e.g., PR #158. Do you have any idea how to fix this?

clement-bonnet commented 1 year ago

Hi @aar65537! While working on your recent PR (#160), did you try removing `jaxlib`?

cemlyn007 commented 1 year ago

Good evening! It might be that one of Jumanji's dependencies unfortunately specifies `jaxlib` itself.

aar65537 commented 1 year ago

The issue is that `jax` only lists `jaxlib` as a dependency when you include one of the platform extras. So `pip install jax[platform]` will install `jaxlib`, but `pip install jax` won't. In this case, `chex` still installs `jaxlib`, but at the wrong version, which causes the error. I think you could fix it by changing the `jax` constraint to `jax[cpu]>=0.2.26,<=0.4.10`.
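To check which installed package actually pulls in `jaxlib`, one option is to scan the `Requires-Dist` metadata of every installed distribution. The following is a minimal sketch using only the standard library's `importlib.metadata`; the helper names (`canonical`, `requirement_name`, `who_requires`) are hypothetical and not part of Jumanji, jax, or pip. It uses a simplified name regex rather than a full PEP 508 parser.

```python
import re
from importlib.metadata import distributions

# Leading project name of a PEP 508 requirement string (simplified, not a full parser).
_NAME = re.compile(r"[A-Za-z0-9][A-Za-z0-9._-]*")


def canonical(name):
    """Normalise a project name PEP 503 style: runs of -_. become '-', lowercased."""
    return re.sub(r"[-_.]+", "-", name).lower()


def requirement_name(req):
    """Extract the canonical project name from a requirement string, or None."""
    m = _NAME.match(req.strip())
    return canonical(m.group(0)) if m else None


def who_requires(target):
    """Map each installed distribution to the raw requirement strings naming `target`.

    The raw strings keep any environment markers, so a requirement that only
    applies with an extra shows up as e.g. '...; extra == "cpu"'.
    """
    target = canonical(target)
    hits = {}
    for dist in distributions():
        name = dist.metadata.get("Name")
        if name is None:
            continue  # skip distributions with broken metadata
        for req in dist.requires or []:
            if requirement_name(req) == target:
                hits.setdefault(canonical(name), []).append(req)
    return hits


if __name__ == "__main__":
    for dist_name, reqs in sorted(who_requires("jaxlib").items()):
        print(dist_name, reqs)
```

Run in the environment where the conflict appears, jax's entries for `jaxlib` should carry `extra == ...` markers; if the diagnosis above is right, `chex` should also appear in the output.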

cemlyn007 commented 1 year ago

Thanks @aar65537! So to fix the CI, you just need to run `pip install .[dev,train] jax[cpu]`. I don't think you should change the requirements to specify a `jax` extra.
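After an install step like the one above, a CI job could log which versions the resolver actually chose, which makes a mismatched `jaxlib` easy to spot. This is a minimal sketch, not something Jumanji's CI is known to run; `installed_versions` is a hypothetical helper built on the standard library's `importlib.metadata.version`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names=("jax", "jaxlib", "chex")):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # not installed in this environment
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```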

clement-bonnet commented 1 year ago

Resolved in #174