Hello,
I am trying to run your code, but there seems to be a problem with the installation. The environment file pins JAX to version 0.4.19, but the last line of the README:
upgrades JAX to 0.4.25 and causes many errors in the code, e.g.
module 'jax.random' has no attribute 'KeyArray'
(and this is only the tip of the iceberg: removing KeyArray from the type annotations reveals more errors).
Am I right in assuming that this line should not be in the README and that the code is meant to work with JAX 0.4.19?
I was able to run the code with JAX 0.4.19.
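For reference, a quick way to check which JAX version the README steps actually leave installed is something like the following. This is just a minimal sketch that reads installed package metadata via the standard library; the expected version 0.4.19 is taken from the environment file, and the helper name installed_version is my own.

```python
from importlib.metadata import PackageNotFoundError, version

EXPECTED_JAX = "0.4.19"  # version pinned in the environment file


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("jax")
    status = "OK" if found == EXPECTED_JAX else "MISMATCH"
    print(f"jax: {found} ({status}, expected {EXPECTED_JAX})")
```

Running this right after the last README line should report the 0.4.25 mismatch described above.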
Could you also check that the installation instructions in the README work correctly without this last line? I think I ran into further issues and had to install some packages manually because these steps did not install them (sorry, I cannot recall exactly what the problem was, and I cannot reproduce it now because of the many steps I took while trying to run your code).