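I replaced the `jax.ops.index*` expressions with `ndarray.at[idx]` expressions (`.set()`, `.add()`, etc.) so the code is compatible with the latest JAX, which no longer provides the `jax.ops.index*` helpers.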
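For reviewers unfamiliar with the change, here is a minimal sketch of the kind of rewrite involved. The arrays and values are illustrative only; the commented "old" calls show the removed `jax.ops` style for comparison and are not run.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Old (removed): jax.ops.index_update(x, jax.ops.index[1], 10.0)
y = x.at[1].set(10.0)

# Old (removed): jax.ops.index_add(x, jax.ops.index[2:4], 1.0)
z = x.at[2:4].add(1.0)

# Old (removed): jax.ops.index_mul(w, jax.ops.index[0], 3.0)
w = jnp.ones(3).at[0].multiply(3.0)

# JAX arrays are immutable: .at[...] returns a new array and leaves x unchanged.
print(y, z, w, x)
```

Because `.at[...]` is functional, any code that relied on the old helpers returning a new array keeps the same semantics.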
There are some minor modifications, like reordering module imports.
Also, I added `jax.config.update("jax_platform_name", "cpu")`, which makes the code run on the CPU(s) by default.
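A minimal sketch of how this setting is typically applied, assuming it is placed at the top of the entry script before any arrays are created (the check with `jax.default_backend()` is added here only to confirm the effect and is not part of the change):

```python
import jax

# Select the CPU backend; set this before creating any arrays or running computations.
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

x = jnp.arange(3)
backend = jax.default_backend()  # name of the backend JAX will use by default
print(backend, x)
```

Newer JAX releases also offer the `jax_platforms` option (or the `JAX_PLATFORMS` environment variable) for the same purpose.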