cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

Update select_topk to not use deprecated function #43

Closed: ptigwe closed this 2 years ago

ptigwe commented 2 years ago

This replaces the deprecated jax.ops.index_update with the suggested alternative, arr.at[idx].set(val). Another alternative, which uses masking tricks and produces the same result, is as follows:

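```python
import jax.numpy as jnp
# Compare each class index (axis 1) with the selected top-k indices, then sum over k
class_idx = jnp.arange(prob_tensor.shape[1])
topk_tensor = jnp.sum(class_idx == jnp.expand_dims(idx_axis1, -1), 1)
```
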
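For comparison, here is a minimal sketch of both approaches side by side. It assumes a 2-D prob_tensor of shape (N, C) with top-k taken along the last axis via jax.lax.top_k. The function names select_topk_at_set and select_topk_mask are hypothetical and for illustration only; this is not the actual treex implementation of select_topk.

```python
import jax
import jax.numpy as jnp


def select_topk_at_set(prob_tensor: jnp.ndarray, topk: int = 1) -> jnp.ndarray:
    """Hypothetical helper: binary mask of the top-k entries in each row,
    built with the functional .at[idx].set(val) update."""
    zeros = jnp.zeros_like(prob_tensor, dtype=jnp.int32)
    # Column indices of the k largest entries per row, shape (N, k)
    _, idx_axis1 = jax.lax.top_k(prob_tensor, topk)
    # Row indices, shape (N, 1), broadcast against idx_axis1
    idx_axis0 = jnp.arange(prob_tensor.shape[0])[:, None]
    # Out-of-place update in place of the removed jax.ops.index_update
    return zeros.at[idx_axis0, idx_axis1].set(1)


def select_topk_mask(prob_tensor: jnp.ndarray, topk: int = 1) -> jnp.ndarray:
    """Hypothetical helper: the same mask built by broadcasting comparisons."""
    _, idx_axis1 = jax.lax.top_k(prob_tensor, topk)
    class_idx = jnp.arange(prob_tensor.shape[1])
    # (N, k, 1) == (C,) broadcasts to (N, k, C); summing over k gives (N, C)
    return jnp.sum(class_idx == jnp.expand_dims(idx_axis1, -1), 1)


probs = jnp.array([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
print(select_topk_at_set(probs, 2))
print(select_topk_mask(probs, 2))
```

Both calls should print the same mask, [[0 1 1] [1 1 0]]. The .at[...].set version scatters directly into a zero array, while the masking version avoids the scatter but materializes an intermediate (N, k, C) boolean array.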
cgarciae commented 2 years ago

LGTM! Thanks a lot @ptigwe!