probml / JSL

Jax SSM Library
MIT License

replace jax.ops.index_update with jax.numpy.at in all JSL code #25

Closed: murphyk closed this 2 years ago

murphyk commented 2 years ago

`jax.ops.index_update` has been removed from JAX (see https://github.com/google/jax/commit/f51a05a889f2fcb19946352b9d65f2b6c49fec4a), which breaks some JSL code. Please use the indexed-update syntax on the array itself, e.g. `x.at[idx].set(y)`, instead.
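A minimal sketch of the migration, assuming a JAX release in which the `jax.ops.index_*` functions are gone. The array values and the helper `set_row` are hypothetical and only for illustration. Each `.at[...]` method returns a new array and leaves the original untouched, just as `index_update` did. The other removed helpers map the same way: `index_add` becomes `.add` and `index_mul` becomes `.multiply`.

```python
import jax
import jax.numpy as jnp

# Old style (no longer available in JAX), shown for comparison only:
#   x = jax.ops.index_update(x, jax.ops.index[1], 10.0)
#   x = jax.ops.index_add(x, jax.ops.index[2], 5.0)

x = jnp.zeros(5)

# Replacement for jax.ops.index_update: returns a new array; x is unchanged.
y = x.at[1].set(10.0)

# Replacement for jax.ops.index_add.
z = y.at[2].add(5.0)

# Slices and integer-array indices follow NumPy indexing rules.
w = z.at[3:].set(-1.0)
v = w.at[jnp.array([0, 4])].multiply(2.0)  # position 0: 0*2, position 4: -1*2


@jax.jit
def set_row(mat, i, row):
    # Hypothetical helper: works under jit, including with a traced index i.
    # Old style: jax.ops.index_update(mat, jax.ops.index[i], row)
    return mat.at[i].set(row)


m = set_row(jnp.zeros((3, 2)), 1, jnp.ones(2))
```

Here `v` is `[0., 10., 5., -1., -2.]`, and `m` has its middle row set to ones.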

murphyk commented 2 years ago

Example of the problem:

[Screenshot, 2022-03-23 12:17:38 PM: example of the problem]

karalleyna commented 2 years ago

Done.