[JAX] Replace uses of deprecated jax.ops.index_update(x, idx, y)-style APIs with their up-to-date, more succinct equivalents such as x.at[idx].set(y).
The JAX operators:
jax.ops.index_update(x, jax.ops.index[idx], y)
jax.ops.index_add(x, jax.ops.index[idx], y)
...
have long been deprecated in favor of their more succinct counterparts:
x.at[idx].set(y)
x.at[idx].add(y)
...
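As a minimal sketch of the mapping above, the snippet below calls only the new forms. It assumes a recent JAX release in which jax.ops.index_update and friends may already be gone, so the old calls appear only as comments. The array and values are arbitrary. Note that .at[...] updates are functional: they return a new array and leave x unchanged.

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = 2

# Formerly: jax.ops.index_update(x, jax.ops.index[idx], 10.0)
y_set = x.at[idx].set(10.0)

# Formerly: jax.ops.index_add(y_set, jax.ops.index[idx], 3.0)
y_add = y_set.at[idx].add(3.0)

# Slices are accepted as indices too.
# Formerly: jax.ops.index_update(x, jax.ops.index[1:3], 7.0)
y_slice = x.at[1:3].set(7.0)

# x itself is never modified; each .at[...] call returns a fresh array.
print(x, y_set, y_add, y_slice)
```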
This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX.
The main subtlety is that if x is not a JAX array, we must cast it to one using jnp.asarray(x) before using the new form, since .at[...] is only defined on JAX arrays.
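A minimal sketch of that subtlety, assuming x starts out as a NumPy array (the variable names are illustrative only): NumPy arrays have no .at attribute, so x.at[1].set(5.0) would raise AttributeError, and the value is converted with jnp.asarray first. Under JAX's default 32-bit mode the float64 NumPy input becomes a float32 JAX array.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(4.0)  # a NumPy array, not a JAX array

# x_np.at[1].set(5.0) would fail: NumPy arrays do not define .at
has_at = hasattr(x_np, "at")

# Convert to a JAX array first, then use the new form.
out = jnp.asarray(x_np).at[1].set(5.0)

# The original NumPy array is left untouched.
print(has_at, out, x_np)
```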