pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Add fixes for jax 0.4 #600

Closed: fehiepsi closed this 1 year ago

fehiepsi commented 1 year ago

JAX has now released 0.4, which drops support for Python 3.8, so we need to bump the Python version in CI to be able to test against the latest JAX. In addition, JAX arrays need extra handling: starting with JAX 0.4, DeviceArray is replaced by jax.Array (i.e. jax.numpy.ndarray), and tracers are treated as a subclass of jax.Array.
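A minimal sketch of what this change means for type checks, assuming a JAX 0.4.x or newer install where jax.Array is the default array type. The helper `is_jax_array` and the `results` dict are hypothetical names for illustration, not funsor's actual code. It shows that a single `isinstance(x, jax.Array)` check now matches both concrete arrays and the tracers seen inside `jax.jit`, while plain NumPy arrays do not match.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array(x):
    """Hypothetical helper: True for concrete JAX arrays and for tracers."""
    # Since JAX 0.4 there is no separate DeviceArray type: concrete arrays are
    # jax.Array instances, and tracers also pass isinstance(x, jax.Array).
    return isinstance(x, jax.Array)


results = {}


@jax.jit
def double(x):
    # This Python code runs once at trace time, when x is a tracer.
    results["traced"] = is_jax_array(x)
    return x * 2


x = jnp.ones(3)
results["concrete"] = is_jax_array(x)          # concrete jax.Array
results["numpy"] = is_jax_array(np.ones(3))    # plain NumPy array
y = double(x)
```

Because tracers now satisfy the same check as concrete arrays, code that previously special-cased DeviceArray and tracer types separately can, under this assumption, collapse them into one jax.Array branch.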