JAX has now released 0.4, which drops support for Python 3.8, so we need to bump the Python version in CI to test against the latest release. JAX arrays also need extra handling: since JAX 0.4, DeviceArray has been replaced by jax.Array (i.e. jax.numpy.ndarray), and tracers are a subclass of jax.Array.
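To illustrate why tracers need extra handling, here is a minimal sketch assuming JAX 0.4 or later. An isinstance(x, jax.Array) check alone no longer separates materialized arrays from tracers, so this example adds an explicit jax.core.Tracer check. The helper is_concrete_array is hypothetical and not part of this PR or of JAX; it only shows one possible way to tell the two apart.

```python
import jax
import jax.numpy as jnp


def is_concrete_array(x):
    """Hypothetical helper: True for a materialized JAX array, False for a tracer or non-JAX value."""
    # Since JAX 0.4, tracers also satisfy isinstance(x, jax.Array),
    # so an explicit Tracer check is needed to tell the two apart.
    return isinstance(x, jax.Array) and not isinstance(x, jax.core.Tracer)


# Record what the function sees while jax.jit traces it.
observed = []


@jax.jit
def double(x):
    # During tracing, x is a tracer rather than a concrete array.
    observed.append((isinstance(x, jax.Array), is_concrete_array(x)))
    return 2 * x


double(jnp.ones(3))
print(is_concrete_array(jnp.ones(3)))
print(observed)
```

Inside the jitted function, the argument passes the jax.Array check but fails is_concrete_array, which is exactly the case that code written against the old DeviceArray type would need to handle.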