hanrach / p2d_fast_solver


Replace jax.ops.index* with ndarray.at[idx]* #13

Closed: jaeminoh closed this 8 months ago

jaeminoh commented 11 months ago

I replaced the jax.ops.index* expressions with ndarray.at[idx]* expressions to make the code compatible with the latest JAX. There are also some minor modifications, such as reordering module imports. I also added the flag jax.config.update("jax_platform_name", "cpu"), which makes the code run on the CPU automatically.
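The following is a minimal sketch of this kind of migration, assuming a recent JAX install. It is not taken from the PR's diff: the array x and its values are purely illustrative. The removed jax.ops.index_* calls are shown as comments next to their .at[...] equivalents. Both forms are functional, so each call returns a new array and leaves the input unchanged. The CPU flag is set before any array is created so that it takes effect.

```python
import jax

# Select the CPU backend; do this before creating any arrays.
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

# Illustrative array (hypothetical, not from the solver code).
x = jnp.zeros(5)

# Old, removed API: x = jax.ops.index_update(x, jax.ops.index[1], 3.0)
x = x.at[1].set(3.0)

# Old, removed API: x = jax.ops.index_add(x, jax.ops.index[2:4], 1.0)
x = x.at[2:4].add(1.0)

# Old, removed API: x = jax.ops.index_mul(x, jax.ops.index[1], 2.0)
x = x.at[1].multiply(2.0)

# x now holds [0., 6., 1., 1., 0.]
print(x)
```

Because .at[...] returns a new array, every update must be reassigned (x = x.at[...]...), exactly as with the old jax.ops functions. Inside jit-compiled code, XLA can usually perform these updates in place.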