blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Potential performance regression due to JAX version #746

Open junpenglao opened 1 month ago

junpenglao commented 1 month ago

Describe the issue as clearly as possible:

Our benchmark runtime more than doubled after upgrading JAX to 0.4.34.

Reproduced locally. On JAX 0.4.30:

-------------------------------------------------------------------------------- benchmark: 2 tests -------------------------------------------------------------------------------
Name (time in s)            Min               Max              Mean            StdDev            Median               IQR            Outliers     OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_regression_nuts     4.2739 (1.0)      5.0262 (1.0)      4.7685 (1.0)      0.3398 (1.0)      4.9671 (1.0)      0.5408 (1.0)           1;0  0.2097 (1.0)           5           1
test_regression_hmc      7.2055 (1.69)     8.1514 (1.62)     7.6479 (1.60)     0.4128 (1.22)     7.5257 (1.52)     0.7291 (1.35)          2;0  0.1308 (0.62)          5           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

On JAX 0.4.34:

---------------------------------------------------------------------------------- benchmark: 2 tests ---------------------------------------------------------------------------------
Name (time in s)             Min                Max               Mean            StdDev             Median               IQR            Outliers     OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_regression_nuts      9.2754 (1.0)      10.2643 (1.0)       9.6681 (1.0)      0.3660 (1.0)       9.6078 (1.0)      0.3647 (1.0)           2;0  0.1034 (1.0)           5           1
test_regression_hmc      19.7752 (2.13)     21.4303 (2.09)     20.6382 (2.13)     0.7185 (1.96)     20.4793 (2.13)     1.2633 (3.46)          2;0  0.0485 (0.47)          5           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
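The parenthesized figures in each table compare the two tests within a single run (relative to the best value in that column), not the two JAX versions. The cross-version slowdown comes from dividing the Mean columns; a quick check using the mean times copied from the two tables:

```python
# Mean wall-clock times (seconds) copied from the two pytest-benchmark tables.
mean_s = {
    "test_regression_nuts": {"0.4.30": 4.7685, "0.4.34": 9.6681},
    "test_regression_hmc": {"0.4.30": 7.6479, "0.4.34": 20.6382},
}

# Ratio > 1 means the test got slower on the newer JAX release.
slowdown = {name: t["0.4.34"] / t["0.4.30"] for name, t in mean_s.items()}

for name, ratio in slowdown.items():
    print(f"{name}: {ratio:.2f}x")
# test_regression_nuts: 2.03x
# test_regression_hmc: 2.70x
```

Both tests exceed 2x, and the HMC test regresses more than the NUTS test.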

Steps/code to reproduce the bug:

Pin JAX to a specific version and run:

pytest --benchmark-only

Expected result:

n.a

Error message:

n.a

Blackjax/JAX/jaxlib/Python version information:

n.a
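The version field above was left as n.a. A minimal standard-library sketch that collects the requested versions without importing JAX itself (the function name `version_info` is illustrative, not part of any library). If JAX imports cleanly, `jax.print_environment_info()` prints a similar summary.

```python
import importlib.metadata
import platform


def version_info(packages=("blackjax", "jax", "jaxlib")) -> dict:
    """Return the Python version plus installed package versions.

    Packages that are not installed are reported as "not installed"
    instead of raising, so the report always completes.
    """
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for name, version in version_info().items():
        print(f"{name}: {version}")
```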

Context for the issue:

No response

junpenglao commented 3 weeks ago

Related: https://github.com/pyro-ppl/numpyro/issues/1867. For the likely root cause and a workaround, see https://github.com/jax-ml/jax/discussions/23822.
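A minimal sketch of one possible workaround, assuming the slowdown comes from the new XLA CPU "thunk" runtime that jaxlib 0.4.32 enabled by default (a release between the two benchmarked versions). The JAX changelog for that release describes the XLA flag `--xla_cpu_use_thunk_runtime=false` for falling back to the old CPU runtime; whether this is exactly the workaround in the linked discussion is an assumption, and the helper `use_legacy_cpu_runtime` is hypothetical.

```python
import os

# Assumption: the regression is caused by the XLA CPU thunk runtime.
# This flag asks XLA to use the previous CPU runtime (CPU backend only).
LEGACY_CPU_RUNTIME_FLAG = "--xla_cpu_use_thunk_runtime=false"


def use_legacy_cpu_runtime() -> str:
    """Append the flag to XLA_FLAGS (at most once) and return the new value.

    XLA reads XLA_FLAGS only once, so call this before JAX starts using
    its backend; calling it before `import jax` is the safe choice.
    Existing flags in XLA_FLAGS are preserved.
    """
    flags = os.environ.get("XLA_FLAGS", "").split()
    if LEGACY_CPU_RUNTIME_FLAG not in flags:
        flags.append(LEGACY_CPU_RUNTIME_FLAG)
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


if __name__ == "__main__":
    print(use_legacy_cpu_runtime())
    # import jax  # import only after the flag has been set
```

From the shell, the equivalent is to prefix the benchmark command: `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false pytest --benchmark-only`. Rerunning the benchmarks on JAX 0.4.34 with and without the flag would show whether it accounts for the regression.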

ColCarroll commented 1 week ago

Also https://github.com/jax-ml/jax/discussions/24501
