jstac opened this issue 3 months ago
@mmcky
Can you please search for this line:
> In line with immutability, JAX does not support inplace operations:
It looks like JAX is sorting in place now. Could you please evaluate and propose a fix?
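A minimal sketch of the behavior in question, assuming NumPy and a recent JAX are installed (variable names are illustrative). NumPy's `ndarray.sort()` sorts the array in place and returns `None`. JAX's `Array.sort()`, like `jnp.sort`, returns a new sorted array and leaves the original unchanged.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: ndarray.sort() sorts the buffer in place and returns None
x_np = np.array((2, 1))
result_np = x_np.sort()
print(result_np)   # None
print(x_np)        # [1 2]  (the original array was modified)

# JAX: Array.sort() returns a new sorted array; the original is untouched
x_jax = jnp.array((2, 1))
sorted_jax = x_jax.sort()
print(sorted_jax)  # [1 2]
print(x_jax)       # [2 1]
```

Because the sorted JAX array is only visible through the return value, an interactive session that displays `a.sort()` can look like an in-place sort. Inspecting `a` afterwards shows that it was not modified.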
It's here: https://github.com/QuantEcon/lecture-jax/blob/1a1be32fd03b11a4bfd960565818cddfa2070176/lectures/jax_intro.md?plain=1#L161
For me, it's still not in place:
```python
>>> import jax
>>> jax.__version__
'0.4.23'
>>> import jax.numpy as jnp
>>> a = jnp.array((2, 1))
>>> a.sort()
Array([1, 2], dtype=int32)
>>> a
Array([2, 1], dtype=int32)
```
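If the lecture needs a less ambiguous illustration of immutability, item assignment is one option. The following is a sketch that assumes a recent JAX release: assigning directly to an element raises `TypeError`, while the functional `.at[...].set(...)` update returns a modified copy. To get sorted values, rebind the name to the array that `sort()` returns.

```python
import jax.numpy as jnp

a = jnp.array((2, 1))

# Direct element assignment is rejected because JAX arrays are immutable
rejected = False
try:
    a[0] = 10
except TypeError:
    rejected = True
print("in-place assignment rejected:", rejected)

# Functional update: .at[...].set(...) returns a new array, a is unchanged
b = a.at[0].set(10)
print(a)  # [2 1]
print(b)  # first element replaced by 10

# "Sorting in place" in JAX means rebinding the name to the returned array
a = a.sort()
print(a)  # [1 2]
```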