blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Numerical test `test_chees_adaptation` fails on `aarch64-linux` #668

Open GaetanLepage opened 4 months ago

GaetanLepage commented 4 months ago

Describe the issue as clearly as possible:

Using jax/jaxlib 0.4.28 and jaxopt 0.8.3, the following test fails on aarch64-linux:

FAILED tests/adaptation/test_adaptation.py::test_chees_adaptation - AssertionError:

Steps/code to reproduce the bug:

pytest

Expected result:

Tests pass.

Error message:

=================================== FAILURES ===================================
____________________________ test_chees_adaptation _____________________________
[gw1] linux -- Python 3.11.9 /nix/store/33752yykc8r75jxvpcvpcynm22il4ch7-python3-3.11.9/bin/python3.11

    def test_chees_adaptation():
        logprob_fn = lambda x: jax.scipy.stats.norm.logpdf(
            x, loc=0.0, scale=jnp.array([1.0, 10.0])
        ).sum()

        num_burnin_steps = 1000
        num_results = 500
        num_chains = 16
        step_size = 0.1

        init_key, warmup_key, inference_key = jax.random.split(jax.random.key(346), 3)

        warmup = blackjax.chees_adaptation(
            logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
        )

        initial_positions = jax.random.normal(init_key, (num_chains, 2))
        (last_states, parameters), warmup_info = warmup.run(
            warmup_key,
            initial_positions,
            step_size=step_size,
            optim=optax.adamw(learning_rate=0.5),
            num_steps=num_burnin_steps,
        )
        algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)

        chain_keys = jax.random.split(inference_key, num_chains)
        _, _, infos = jax.vmap(
            lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
        )(chain_keys, last_states)

        harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
>       np.testing.assert_allclose(harmonic_mean, 0.75, rtol=1e-1)

tests/adaptation/test_adaptation.py:69: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0xfffee0976ca0>, array(0.6619941, dtype=float32), array(0.75))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0.1, atol=0', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=0.1, atol=0
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 0.0880059
E           Max relative difference: 0.1173412
E            x: array(0.661994, dtype=float32)
E            y: array(0.75)

/nix/store/33752yykc8r75jxvpcvpcynm22il4ch7-python3-3.11.9/lib/python3.11/contextlib.py:81: AssertionError
=========================== short test summary info ============================
FAILED tests/adaptation/test_adaptation.py::test_chees_adaptation - AssertionError: 
============= 1 failed, 442 passed, 1 skipped in 139.54s (0:02:19) =============
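For readers unfamiliar with the statistic being asserted on: the test computes the harmonic mean of the acceptance rates, `1 / mean(1 / rate)`. The plain-Python sketch below uses made-up acceptance rates (they are not taken from this run) to show that a harmonic mean is pulled down by a few low values much more strongly than an arithmetic mean is. This may be part of why the statistic varies between platforms.

```python
# Hypothetical acceptance rates, for illustration only (not from the failing run).
rates = [0.9, 0.8, 0.75, 0.2]

# Arithmetic mean: every value contributes equally.
arithmetic_mean = sum(rates) / len(rates)

# Harmonic mean, same formula as in the test: 1 / mean(1 / rate).
harmonic_mean = 1.0 / (sum(1.0 / r for r in rates) / len(rates))

print(f"arithmetic mean: {arithmetic_mean:.4f}")
print(f"harmonic mean:   {harmonic_mean:.4f}")
```

With these numbers, the single low rate of 0.2 dominates the harmonic mean, which ends up well below the arithmetic mean.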

Blackjax/JAX/jaxlib/Python version information:

BlackJAX 1.2.1
Python 3.11.9 (main, Apr  2 2024, 08:25:04) [GCC 13.2.0]
Jax 0.4.28
Jaxlib 0.4.28

Context for the issue:

No response

junpenglao commented 4 months ago

@albcab could you take a look to make this test more robust?

albcab commented 4 months ago

I can't reproduce the failing test on my machine, so it's hard to debug. Setting atol=1e-2 and rtol=0 would obviously fix it. Since the harmonic mean is random, increasing the number of chains, the number of steps, or the learning rate would not necessarily prevent the failure.

Just for the sake of passing the test on aarch64-linux, I would change rtol to atol. @junpenglao, thoughts?
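To make the rtol/atol trade-off concrete, here is a minimal sketch of the documented `np.testing.assert_allclose` criterion, `|actual - desired| <= atol + rtol * |desired|`, applied to the values from the trace above. The helper `within_tolerance` is a hypothetical name used only for illustration and is not part of NumPy or BlackJAX.

```python
def within_tolerance(actual, desired, rtol=0.0, atol=0.0):
    """Mirror the assert_allclose criterion for scalars (illustrative helper)."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


# Values reported in the failing trace.
actual, desired = 0.6619941, 0.75
gap = abs(actual - desired)  # about 0.088

# Current assertion: rtol=1e-1 allows a gap of 0.1 * 0.75 = 0.075.
ok_rtol_1e1 = within_tolerance(actual, desired, rtol=1e-1)

# Absolute tolerances: the bound is atol itself when rtol=0.
ok_atol_1e1 = within_tolerance(actual, desired, atol=1e-1)
ok_atol_1e2 = within_tolerance(actual, desired, atol=1e-2)

print(f"gap={gap:.4f} rtol=1e-1: {ok_rtol_1e1}, "
      f"atol=1e-1: {ok_atol_1e1}, atol=1e-2: {ok_atol_1e2}")
```

For this particular observed value, an absolute tolerance only helps if it exceeds the gap of roughly 0.088, so atol=1e-1 would accept it while a tighter atol such as 1e-2 would not.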

junpenglao commented 4 months ago

Yeah sure.

GaetanLepage commented 4 months ago

I can already confirm that the test is fixed thanks to @albcab's patch. Maybe we can wait until the next release to mark this issue as closed.

GaetanLepage commented 4 days ago

Things might have changed since then, because several tests fail on 1.2.4. Skipping test_chees_adaptation is enough to make the test suite succeed.

junpenglao commented 3 days ago

Thanks for the feedback. This is a bit difficult for us to replicate; could you paste the full trace for all the failing tests? Not much has changed on the BlackJAX side regarding ChEES, so I suspect something upstream in JAX is causing this.

GaetanLepage commented 3 days ago

Sure, here it is: https://paste.glepage.com/upload/spider-bee-bison