stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

In-place Array Modifications #36

Closed: bmac3 closed this issue 1 year ago

bmac3 commented 1 year ago

Hello! Thank you so much for this library, it has quickly become my favorite way to use jax.

I was wondering what the recommended way to do array modifications is, similar to the jax x.at[...].set(...) syntax (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html). I had a look through the documentation but couldn't find anything that exactly fits this use case.
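For reference, here is a minimal sketch of the plain JAX .at syntax the question refers to (jax.numpy only, no haliax). It shows that .at[...] is functional: the original array is left untouched and each call returns an updated copy.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# .at[...] never mutates x: each call returns a new array with the update applied.
y = x.at[2].set(7.0)      # index 2 set to 7
z = y.at[1:3].add(1.0)    # indices 1 and 2 incremented by 1
```

Besides set and add, .at also offers other update methods such as mul, min, and max.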

bmac3 commented 1 year ago

Oops, I just realized that there is an updated_slice function that pretty much suits my use case. The one advantage of the .at syntax is that the JAX docs state:

inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.

But after a few experiments, it looks like jax.lax.dynamic_update_slice might behave the same way. I'll close this issue now.
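A rough sketch of the kind of comparison described above, written in plain JAX. The helper names write_row_at and write_row_dus are hypothetical, and this does not claim anything about how haliax implements updated_slice internally. Both jitted functions write one row into a 2-D buffer and produce the same values. Whether XLA actually performs the write in place is not visible from the values. One way to check is to inspect the compiled HLO, for example with write_row_dus.lower(buf, row, 1).compile().as_text(). Note also that buf is an input to the jitted function, so its caller-side buffer can only be reused if it is donated via jax.jit's donate_argnums.

```python
import jax
import jax.numpy as jnp

@jax.jit
def write_row_at(buf, row, i):
    # Scatter-style update with a (traced) integer index.
    return buf.at[i].set(row)

@jax.jit
def write_row_dus(buf, row, i):
    # dynamic_update_slice takes the whole update block plus one start index
    # per dimension, so the row is reshaped to (1, ncols) and written at (i, 0).
    return jax.lax.dynamic_update_slice(buf, row[None, :], (i, 0))

buf = jnp.zeros((3, 4))
row = jnp.arange(4.0)

a = write_row_at(buf, row, 1)
b = write_row_dus(buf, row, 1)

# dynamic_update_slice clamps start indices so the update stays in bounds:
# a start row of 5 on a 3-row buffer is clamped to 2 (the last row).
c = write_row_dus(buf, row, 5)
```

The clamping behavior is worth keeping in mind when swapping one API for the other, since the two are not interchangeable for out-of-range indices.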