junpenglao opened this issue 1 month ago
Related: https://github.com/pyro-ppl/numpyro/issues/1867. For the likely root cause and a workaround, see https://github.com/jax-ml/jax/discussions/23822.
Also see https://github.com/jax-ml/jax/discussions/24501.
(Email reply to Junpeng Lao's comment of Mon, Oct 14, 2024, 8:23 AM.)
Describe the issue as clearly as possible:
Our benchmark runtime more than doubled after upgrading JAX to 0.4.34.
Reproduced locally:
On JAX 0.4.30:
On JAX 0.4.34:
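To compare runtimes like this across two JAX installations, it helps to time the first call (tracing plus compilation) separately from steady-state calls, and to wait for results with `jax.block_until_ready`, because JAX dispatches work asynchronously and an unsynchronized timer measures only dispatch. The sketch below assumes you run the same script once per environment (e.g. one with 0.4.30, one with 0.4.34). The `workload` function is a hypothetical stand-in, a small `fori_loop` of cheap ops loosely resembling a sampler's inner loop, and is not the benchmark from this issue.

```python
import statistics
import time

import jax
import jax.numpy as jnp


# Hypothetical workload: many iterations of cheap elementwise ops in a loop.
@jax.jit
def workload(x):
    def body(i, carry):
        return jnp.sin(carry) * 0.99 + 0.01 * i

    return jax.lax.fori_loop(0, 1000, body, x)


def benchmark(fn, *args, n_runs=10):
    """Return (first_call_seconds, median_steady_state_seconds)."""
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))  # first call: trace + compile + run
    first = time.perf_counter() - start

    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # block until the async result is ready
        times.append(time.perf_counter() - start)
    return first, statistics.median(times)


if __name__ == "__main__":
    x = jnp.ones((4,))
    first, steady = benchmark(workload, x)
    print(f"jax {jax.__version__}: first call {first:.4f}s, median {steady:.6f}s")
```

Comparing the steady-state median between versions isolates runtime regressions from changes in compile time.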
Steps/code to reproduce the bug:
Error message:
Blackjax/JAX/jaxlib/Python version information:
Context for the issue:
No response