blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

test_chees_adaptation fails with jax 0.4.26 #662

Closed: GaetanLepage closed this issue 5 months ago

GaetanLepage commented 5 months ago

Describe the issue as clearly as possible:

When updating jax to the latest version (0.4.26), the following test starts to fail:

FAILED tests/adaptation/test_adaptation.py::test_chees_adaptation - AssertionError: 

Steps/code to reproduce the bug:

pytest

Expected result:

All tests pass

Error message:

warmup = blackjax.chees_adaptation(
            logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
        )

        initial_positions = jax.random.normal(init_key, (num_chains, 2))
        (last_states, parameters), warmup_info = warmup.run(
            warmup_key,
            initial_positions,
            step_size=step_size,
            optim=optax.adamw(learning_rate=0.5),
            num_steps=num_burnin_steps,
        )
        algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)

        chain_keys = jax.random.split(inference_key, num_chains)
        _, _, infos = jax.vmap(
            lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
        )(chain_keys, last_states)

        harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
        np.testing.assert_allclose(harmonic_mean, 0.75, rtol=1e-1)
        np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1)
>       np.testing.assert_allclose(infos.num_integration_steps.mean(), 15.0, rtol=3e-1)

tests/adaptation/test_adaptation.py:71: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0x7ffcd9de7ec0>, array(4.4220004, dtype=float32), array(15.))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0.3, atol=0', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=0.3, atol=0
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 10.57799959
E           Max relative difference: 0.70519997
E            x: array(4.422, dtype=float32)
E            y: array(15.)

/nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/lib/python3.11/contextlib.py:81: AssertionError
=========================== short test summary info ============================
FAILED tests/adaptation/test_adaptation.py::test_chees_adaptation - AssertionError: 
============= 1 failed, 533 passed, 1 skipped in 131.42s (0:02:11) =============
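For readers parsing the numbers in the traceback: `np.testing.assert_allclose` passes when `|actual - desired| <= atol + rtol * |desired|`. The sketch below redoes that arithmetic in plain Python for the values reported above (4.4220004 observed, 15.0 expected, `rtol=0.3`). The helper name `within_rtol` is made up for illustration and is not part of NumPy or BlackJAX.

```python
def within_rtol(actual, desired, rtol, atol=0.0):
    """Hypothetical helper mirroring the documented assert_allclose criterion."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


# Values copied from the traceback above.
observed_steps = 4.4220004  # infos.num_integration_steps.mean()
expected_steps = 15.0
rtol = 0.3

abs_diff = abs(observed_steps - expected_steps)  # "Max absolute difference"
rel_diff = abs_diff / abs(expected_steps)        # "Max relative difference"
allowed = rtol * abs(expected_steps)             # largest deviation the test tolerates

# Interval of observed values that would satisfy the assertion.
lower, upper = expected_steps - allowed, expected_steps + allowed

passes = within_rtol(observed_steps, expected_steps, rtol)

print(f"absolute difference: {abs_diff:.4f}")
print(f"relative difference: {rel_diff:.4f}")
print(f"allowed deviation:   {allowed:.4f}")
print(f"accepted interval:   [{lower:.1f}, {upper:.1f}]")
print(f"assertion passes:    {passes}")
```

So the mean number of integration steps would have to land roughly between 10.5 and 19.5 for the assertion to pass; the observed value is well outside that interval, which indicates a real change in adaptation behaviour rather than a borderline tolerance miss.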

Blackjax/JAX/jaxlib/Python version information:

BlackJAX 1.1.1
Python 3.11.9 (main, Apr  2 2024, 08:25:04) [GCC 13.2.0]
Jax 0.4.26
Jaxlib 0.4.26

Context for the issue:

Updating jax in nixpkgs: https://github.com/NixOS/nixpkgs/pull/291705

junpenglao commented 5 months ago

Released https://github.com/blackjax-devs/blackjax/releases/tag/1.2.0, it should work now.

GaetanLepage commented 5 months ago

> Released https://github.com/blackjax-devs/blackjax/releases/tag/1.2.0, it should work now.

Indeed it does! Thanks for the very quick response.
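For anyone landing here later, a quick way to check whether the locally installed BlackJAX is at least the 1.2.0 release mentioned above is sketched below. It uses only the standard library; `release_tuple` and `installed_at_least` are hypothetical helper names, and the version parsing is deliberately naive (it ignores pre-release and local tags), so treat it as a rough check rather than a full version comparison.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(version_string):
    """Naively extract up to three leading numeric components, e.g. '1.2.0' -> (1, 2, 0)."""
    return tuple(int(n) for n in re.findall(r"\d+", version_string)[:3])


def installed_at_least(package, minimum):
    """Return True/False for the comparison, or None if the package is not installed."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return None
    return release_tuple(installed) >= release_tuple(minimum)


# None means blackjax is not installed in this environment.
print(installed_at_least("blackjax", "1.2.0"))
```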