google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

[JAX] Replace uses of deprecated `jax.ops.index_update(x, idx, y)` APIs with their up-to-date, more succinct equivalent `x.at[idx].set(y)`. #1701

Closed copybara-service[bot] closed 3 years ago

copybara-service[bot] commented 3 years ago


The JAX operators:

    jax.ops.index_update(x, jax.ops.index[idx], y)
    jax.ops.index_add(x, jax.ops.index[idx], y)
    ...

have long been deprecated in favor of their more succinct counterparts:

    x.at[idx].set(y)
    x.at[idx].add(y)
    ...

This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX.
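As a rough illustration of the migration, the sketch below applies the current `.at[...]` forms to a small array. The values of `x`, `idx`, and `y` are made up for the example, and the deprecated calls appear only in comments because they may not exist in the installed JAX version.

```python
import jax.numpy as jnp

# Illustrative values; not taken from the Trax code base.
x = jnp.zeros(5)
idx = 2
y = 7.0

# Deprecated forms, shown for comparison only:
#   x_set = jax.ops.index_update(x, jax.ops.index[idx], y)
#   x_add = jax.ops.index_add(x, jax.ops.index[idx], y)

# Current forms. Both return a new array; x itself is not modified.
x_set = x.at[idx].set(y)    # element idx replaced by y
x_add = x.at[idx].add(y)    # y added to element idx

print(x_set)
print(x_add)
```

Because JAX arrays are immutable, both forms are functional updates: the result has to be bound to a name, and the original `x` keeps its old contents.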

The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays.
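A minimal sketch of that subtlety, assuming the input arrives as a NumPy array (the name `x_np` and its values are hypothetical): a plain `numpy.ndarray` has no `.at` property, so it is converted with `jnp.asarray` first.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical input that is a NumPy array rather than a JAX array.
x_np = np.arange(4.0)

# NumPy arrays do not expose the .at[...] indexed-update property.
has_at = hasattr(x_np, "at")

# Convert to a JAX array first, then use the current update form.
x = jnp.asarray(x_np)
updated = x.at[1].set(10.0)

print(has_at)
print(updated)
```

The conversion leaves `x_np` untouched; only the returned JAX array carries the update.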