adityab / CrossQ

Official code release for "CrossQ: Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity"
http://aditya.bhatts.org/CrossQ

JAX version and installation instructions #1

Closed: JankowskiChristopher closed this issue 7 months ago.

JankowskiChristopher commented 7 months ago

Hello, I am trying to run your code, but there seems to be a problem with the installation instructions. The environment file pins JAX to version 0.4.19; however, the last line of the README:

python -m pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

upgrades JAX to 0.4.25, which breaks the code in many places, e.g. module 'jax.random' has no attribute 'KeyArray' (and that is only the tip of the iceberg: removing KeyArray from the type annotations exposes further errors).

Am I correct in assuming that this line should not be in the README and that the code is meant to work with JAX 0.4.19? I was able to run the code with JAX 0.4.19. Could you also check that the README installation instructions work correctly without this last line? I think I ran into further issues and had to install some packages manually because these instructions did not install them (sorry, I cannot recall exactly what the problem was, and I cannot reproduce it now because of the many steps I took while trying to get your code running).

adityab commented 7 months ago

Yes, 0.4.19 is the correct JAX version. I've just now:

Thanks a lot for pointing this out!
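Since the resolution here is to stay on JAX 0.4.19 rather than upgrade, one option is to fail fast when a different JAX release is installed. The snippet below is a minimal sketch and is not part of the CrossQ repository: `EXPECTED_JAX_VERSION`, `parse_version`, and `check_jax_version` are hypothetical names. It only reads installed package metadata through the standard library (`importlib.metadata`), so it runs even when JAX itself is missing.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Assumption: 0.4.19 is the JAX release pinned in the environment file,
# as confirmed in this thread. The helper names below are hypothetical.
EXPECTED_JAX_VERSION = "0.4.19"


def parse_version(text: str) -> Tuple[int, ...]:
    """Convert e.g. '0.4.25' or '0.4.19.dev0' into a tuple of its leading integers."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():  # stop at suffixes like '+cuda12' or 'dev0'
                break
            digits += ch
        if not digits:  # a purely non-numeric component ends the version number
            break
        parts.append(int(digits))
    return tuple(parts)


def check_jax_version(expected: str = EXPECTED_JAX_VERSION) -> Optional[str]:
    """Return a warning message if the installed jax differs from `expected`, else None."""
    try:
        installed = version("jax")
    except PackageNotFoundError:
        return "jax is not installed"
    if parse_version(installed) != parse_version(expected):
        return (
            f"jax {installed} is installed, but the environment pins jax {expected}; "
            "newer releases no longer provide APIs such as jax.random.KeyArray"
        )
    return None


if __name__ == "__main__":
    message = check_jax_version()
    print(message or f"jax {EXPECTED_JAX_VERSION} found")
```

Calling `check_jax_version()` at the top of a training script would report the version mismatch directly, instead of leaving the user to trace it back from an `AttributeError` like the missing `KeyArray`. Comparing exact version tuples is a deliberately strict choice here, because the thread shows that even a patch-level jump (0.4.19 to 0.4.25) breaks the code.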