pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Update jax.tree_util.tree_map to jax.tree.map #1821

Closed: fehiepsi closed this 3 months ago

fehiepsi commented 3 months ago

Since JAX 0.4.25 (released Feb 24), jax.tree_util.tree_map is deprecated in favor of jax.tree.map. This PR updates numpyro to use the new pattern.
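For illustration, here is a minimal sketch (not code from this PR) of the new pattern. It assumes a dict of arrays as the pytree, and it adds a hypothetical import fallback for JAX versions older than 0.4.25, which lack the `jax.tree` module. Both spellings apply a function to every leaf and preserve the container structure.

```python
import jax.numpy as jnp

# Prefer the jax.tree namespace (JAX >= 0.4.25); fall back to
# jax.tree_util.tree_map on older JAX versions (compat shim is an assumption,
# not something this PR necessarily does).
try:
    from jax.tree import map as tree_map
except ImportError:
    from jax.tree_util import tree_map

# A small pytree, e.g. parameter values keyed by site name (hypothetical).
params = {"loc": jnp.zeros(3), "scale": jnp.ones(3)}

# Apply the same function to every leaf; the dict structure is preserved.
doubled = tree_map(lambda x: 2 * x, params)
print(doubled)
```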

In addition, jnp.clip(..., a_min=..., a_max=...) is deprecated. I changed the calls to pass the bounds positionally, jnp.clip(..., ..., ...), which removes the deprecation warnings in the tests.
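A minimal sketch of the positional form (illustrative, not taken from the PR diff; the input values are made up). Passing the bounds positionally sidesteps the keyword names entirely, so the same call shape works with `jnp.clip` and `np.clip` whether a given version names them `a_min`/`a_max` or `min`/`max`.

```python
import numpy as np
import jax.numpy as jnp

x = np.array([-2.0, 0.5, 3.0])

# Bounds passed positionally: no deprecated a_min/a_max keywords in JAX,
# and the identical call works with NumPy.
clipped_jax = jnp.clip(jnp.asarray(x), 0.0, 1.0)
clipped_np = np.clip(x, 0.0, 1.0)

print(clipped_jax, clipped_np)
```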

fehiepsi commented 3 months ago

The current NumPy releases still use a_min and a_max. There is a NumPy PR to deprecate those arguments, but it has not been merged yet.