Open adiaconu11 opened 6 months ago
The jax.Array issue can be solved by installing chex==0.1.7 instead of the default 0.1.6. The acme requirements should be updated accordingly.
Bumping into similar problems. It is very hard to install a functional stack.
Following the most recent commit, the jax.tree module is now used instead of the jax.tree_* functions. This change requires jax >= 0.4.25. However, many parts of the repository still rely on old or deprecated APIs. For instance, if you install 0.4.25 you might get something like:
AttributeError: module 'jax.random' has no attribute 'KeyArray'
This is because jax.random.KeyArray was removed in jax 0.4.24, so avoiding this error requires jax <= 0.4.23. Obviously this contradicts the requirement above.
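To make the conflict concrete, here is a minimal sketch that checks a few jax versions against both bounds. The bounds are the ones quoted in this comment; the helper names (`parse_version`, `satisfies_both`) and the candidate list are made up for illustration and are not part of acme or jax.

```python
def parse_version(text):
    """Turn a string like '0.4.25' into (0, 4, 25) so versions compare as tuples."""
    return tuple(int(part) for part in text.split("."))


def satisfies_both(version):
    """Check a jax version against the two bounds described in this thread."""
    v = parse_version(version)
    has_jax_tree = v >= (0, 4, 25)   # latest commit uses the jax.tree module
    has_key_array = v <= (0, 4, 23)  # older code still uses jax.random.KeyArray
    return has_jax_tree and has_key_array


# Hypothetical candidate list, chosen only to straddle both bounds.
candidates = ["0.4.23", "0.4.24", "0.4.25", "0.4.26"]
compatible = [v for v in candidates if satisfies_both(v)]
print(compatible)  # prints []
```

No version can pass both checks, so the repository code has to change (or be patched at import time) rather than the pin.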
Lastly, there is still the issue with DeviceArray and ShardedDeviceArray. They have all been changed to simply jax.Array back in jax 0.4.0! In the current state of the repo you basically need to add lines like:
jax.interpreters.xla.DeviceArray = jax.Array
in order to be able to even import acme...
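A slightly more defensive version of that workaround could look like the sketch below. It only adds an alias when the attribute is missing, so it should be harmless on older jax versions. The helper names (`alias_if_missing`, `patch_jax_for_acme`) are hypothetical. Aliasing `jax.random.KeyArray` to `jax.Array` is an assumption that only makes sense where the name is used as a type annotation. I have not verified that this is enough to make every part of acme work.

```python
def alias_if_missing(module, name, replacement):
    """Set module.<name> = replacement only if the attribute is absent.

    Returns True when an alias was added, False when the name already existed.
    """
    if hasattr(module, name):
        return False
    setattr(module, name, replacement)
    return True


def patch_jax_for_acme():
    """Re-add removed jax names as aliases of jax.Array (workaround, not a fix)."""
    try:
        import jax
        from jax.interpreters import xla
    except ImportError:
        return []  # jax (or this submodule) is unavailable; nothing to patch

    array_type = getattr(jax, "Array", None)
    if array_type is None:
        return []  # very old jax without jax.Array; old names should still exist

    patched = []
    # Assumption: these two names are the ones hit at import time.
    for module, name in [(xla, "DeviceArray"), (jax.random, "KeyArray")]:
        if alias_if_missing(module, name, array_type):
            patched.append(f"{module.__name__}.{name}")
    return patched


# Call this before importing acme.
patched_names = patch_jax_for_acme()
print("patched:", patched_names)
```

This has to run before `import acme`, for example at the top of your entry script.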