Because id_tap (imported via from jax.experimental.host_callback import id_tap) is deprecated, replace jax.experimental.host_callback with jax.pure_callback.
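A minimal sketch of the replacement pattern follows. It assumes the old code used id_tap to pass an array through to a host-side function. The names tap, _host_tap and tapped are hypothetical and exist only for illustration. jax.pure_callback needs the result shape and dtype declared up front, and it assumes the callback is pure. JAX may therefore drop a call whose result is unused, so the sketch feeds the returned value into later computation.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical host-side record of values seen by the callback.
tapped = []

def _host_tap(x):
    # Runs on the host; must return data matching the declared shape/dtype.
    tapped.append(np.asarray(x).copy())
    return np.asarray(x)

def tap(x):
    # Hypothetical stand-in for id_tap: route x through a host callback
    # and return it unchanged.
    return jax.pure_callback(_host_tap, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@jax.jit
def f(x):
    y = tap(x * 2.0)  # the result is used below, so the callback is kept
    return y + 1.0

out = f(jnp.arange(3.0))
```

When the host function exists only for its side effects, jax.debug.callback or jax.experimental.io_callback may fit better than pure_callback.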
Also support jax==0.4.28, in which the kind argument to jax.numpy.sort and jax.numpy.argsort has been removed. Use stable=True or stable=False instead.
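A small sketch of migrating a call site follows. The array values are illustrative, and the commented-out line shows the removed kind form. With stable=True, equal elements keep their original relative order.

```python
import jax.numpy as jnp

a = jnp.array([3, 1, 2, 1])

# Before (no longer accepted): jnp.argsort(a, kind="stable")
idx = jnp.argsort(a, stable=True)  # equal keys keep input order: [1, 3, 2, 0]
vals = jnp.sort(a, stable=True)    # [1, 1, 2, 3]
```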