google-deepmind / acme

A library of reinforcement learning components and agents
Apache License 2.0

Unsolvable jax dependency #321

Open adiaconu11 opened 6 months ago

adiaconu11 commented 6 months ago

Following the most recent commit, the jax.tree module is now used instead of the jax.tree_* functions. This change requires jax >= 0.4.25. However, there are still many parts of the repository that rely on old/deprecated APIs. For instance, if you install 0.4.25 you might get something like:

AttributeError: module 'jax.random' has no attribute 'KeyArray'

This is because jax.random.KeyArray was removed in jax 0.4.24, meaning that in order to not run into this problem you need jax <= 0.4.23. Obviously this contradicts the jax >= 0.4.25 requirement above, so no single jax version satisfies both.
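To make the conflict concrete, here is a small sketch in plain Python (no jax needed). The helper names and the list of candidate versions are illustrative assumptions; only the two bounds come from the observations above.

```python
def parse(version: str) -> tuple:
    """Turn a dotted version string like '0.4.25' into a comparable tuple."""
    return tuple(int(part) for part in version.split("."))


# Bounds taken from the two observations above.
MIN_FOR_JAX_TREE = parse("0.4.25")   # jax.tree is needed by the latest commit
MAX_WITH_KEYARRAY = parse("0.4.23")  # jax.random.KeyArray is gone from 0.4.24 on


def satisfies_both(version: str) -> bool:
    """True only if the version meets the lower and the upper bound."""
    return MIN_FOR_JAX_TREE <= parse(version) <= MAX_WITH_KEYARRAY


# Illustrative candidates around the boundary (not an exhaustive release list).
candidates = ["0.4.22", "0.4.23", "0.4.24", "0.4.25", "0.4.26"]
compatible = [v for v in candidates if satisfies_both(v)]
print(compatible)
```

Because the lower bound is above the upper bound, the list stays empty no matter which candidates are tried.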

Lastly, there is still the issue with DeviceArray and ShardedDeviceArray. They have all been changed to simply jax.Array back in jax 0.4.0! In the current state of the repo you basically need to add a line like the following just to be able to import acme:

jax.interpreters.xla.DeviceArray = jax.Array
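A slightly more defensive version of that monkey-patch might look like the sketch below. The helper `alias_if_missing` is hypothetical (it is not part of jax or acme), and the demo uses a stand-in namespace so the snippet runs without jax installed.

```python
import types


def alias_if_missing(namespace, name, replacement):
    """Set namespace.<name> = replacement only when the attribute is absent.

    Returns True if the alias was added and False if the name already existed,
    so an installation that still ships the old name is left untouched.
    """
    if hasattr(namespace, name):
        return False
    setattr(namespace, name, replacement)
    return True


# Stand-in for a module that has lost its old array type (assumption for the demo).
fake_xla = types.SimpleNamespace()
new_array_type = list  # placeholder standing in for jax.Array

added = alias_if_missing(fake_xla, "DeviceArray", new_array_type)
added_again = alias_if_missing(fake_xla, "DeviceArray", tuple)
print(added, added_again)
```

With jax installed, the same helper would presumably be called as `alias_if_missing(jax.interpreters.xla, "DeviceArray", jax.Array)` before importing acme. This only papers over the missing name; it does not fix the underlying version conflict.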

adiaconu11 commented 5 months ago

The jax.Array issue can be solved by installing chex==0.1.7 instead of the default 0.1.6. This should be updated in acme's requirements.

lkoelman commented 5 months ago

Bumping into similar problems. It is very hard to install a functional stack.