QuantEcon / lecture-jax

Lectures on Quantitative Economics Using JAX
https://jax.quantecon.org/

In place operations in JAX #159

Open jstac opened 3 months ago

jstac commented 3 months ago

@mmcky

Can you please search for this line:

In line with immutability, JAX does not support inplace operations:

It looks like JAX is sorting in place now. Could you please evaluate and propose a fix?
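For context, here is a minimal sketch of what the immutability claim means in practice. It assumes a recent JAX release, and the variable names are illustrative. Item assignment on a JAX array raises a `TypeError`. The functional `.at[...].set(...)` update returns a new array and leaves the original unchanged.

```python
import jax.numpy as jnp

a = jnp.array([2, 1])

# JAX arrays are immutable: item assignment raises a TypeError
try:
    a[0] = 5
    assignment_failed = False
except TypeError:
    assignment_failed = True

# The functional update returns a new array; `a` itself is not modified
b = a.at[0].set(5)

print(assignment_failed)  # True
print(a.tolist())         # [2, 1]
print(b.tolist())         # [5, 1]
```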

kp992 commented 3 months ago

It's here: https://github.com/QuantEcon/lecture-jax/blob/1a1be32fd03b11a4bfd960565818cddfa2070176/lectures/jax_intro.md?plain=1#L161

kp992 commented 3 months ago

For me, it's still not in place:

>>> import jax
>>> jax.__version__
'0.4.23'
>>> import jax.numpy as jnp
>>> a = jnp.array((2, 1))
>>> a.sort()
Array([1, 2], dtype=int32)
>>> a
Array([2, 1], dtype=int32)
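The impression that JAX sorts in place may come from the contrast with NumPy, where `ndarray.sort()` mutates the array and returns `None`. That explanation is a guess. The sketch below assumes NumPy and a recent JAX are installed and puts the two side by side. JAX's `Array.sort()` and `jnp.sort` both return a new sorted array and leave the original untouched, which matches the 0.4.23 session above.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: ndarray.sort() sorts in place and returns None
x = np.array([2, 1])
result = x.sort()
print(result)        # None
print(x.tolist())    # [1, 2]

# JAX: Array.sort() returns a new sorted array; the original is unchanged
a = jnp.array([2, 1])
a_sorted = a.sort()
print(a_sorted.tolist())  # [1, 2]
print(a.tolist())         # [2, 1]

# The function form jnp.sort behaves the same way
print(jnp.sort(a).tolist())  # [1, 2]
```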