blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0
806 stars, 105 forks

Resolve: Add progress bar to run_inference_algorithm #614

Closed PaulScemama closed 10 months ago

PaulScemama commented 10 months ago

Closes #610. Adds progress bar to run_inference_algorithm.
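
For readers unfamiliar with the problem: a loop compiled with `jax.lax.scan` cannot update a Python-side progress bar directly, so progress has to be reported through a callback. The block below is a minimal sketch of that idea using `jax.debug.print`. It is not the code in this PR, and `scan_with_progress`, `step_fn`, and `print_every` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def scan_with_progress(step_fn, init, num_steps, print_every=1):
    """Run `step_fn` for `num_steps` iterations, printing progress as it goes.

    Hypothetical helper for illustration only; not BlackJAX's implementation.
    """

    def body(carry, i):
        carry = step_fn(carry, i)
        # Report progress only every `print_every` iterations. Both branches
        # return None, so the cond has no effect on the computation itself.
        jax.lax.cond(
            (i + 1) % print_every == 0,
            lambda: jax.debug.print(f"step {{i}}/{num_steps}", i=i + 1),
            lambda: None,
        )
        return carry, carry

    return jax.lax.scan(body, init, jnp.arange(num_steps))


# Toy usage: add 1.0 to the carry at every step.
final, history = scan_with_progress(
    lambda c, i: c + 1.0, jnp.zeros(()), num_steps=5, print_every=2
)
```

The printing happens as a side effect on the host, so the values returned by the scan are the same as they would be without it.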

codecov[bot] commented 10 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Comparison is base (540db41) 99.18% compared to head (9b4daf7) 99.22%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #614      +/-   ##
==========================================
+ Coverage   99.18%   99.22%   +0.04%
==========================================
  Files          57       57
  Lines        2576     2581       +5
==========================================
+ Hits         2555     2561       +6
+ Misses         21       20       -1
```


PaulScemama commented 10 months ago

@junpenglao I am trying to figure out why some of the tests are failing -- weird that the details say the test is failing in a file that does not use run_inference_algorithm.

Also, should I adapt some tests to use the progress_bar?

junpenglao commented 10 months ago

Looks like there is a flaky test. Could you change the rng_key of that test until it passes, and add a comment noting that it is a flaky test?
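
As a rough illustration of what that looks like, here is a sketch using only the Python standard library rather than JAX. With a pinned seed, a stochastic test becomes deterministic: a given seed either always passes or always fails, and the comment records that the assertion is statistical. The names `sample_mean` and `test_sample_mean_close_to_zero` and the tolerance are assumptions made for this example, not taken from the BlackJAX test suite.

```python
import random
import statistics


def sample_mean(seed, num_samples=1_000):
    """Mean of `num_samples` standard normal draws from a seeded generator."""
    rng = random.Random(seed)
    return statistics.fmean(rng.gauss(0.0, 1.0) for _ in range(num_samples))


def test_sample_mean_close_to_zero():
    # NOTE: statistical (potentially flaky) test. The result depends on the
    # seed, so the seed is pinned; if the sampler changes, the seed or the
    # tolerance may need to be revisited.
    seed = 0
    assert abs(sample_mean(seed)) < 0.15
```

The same pattern applies to a JAX test, with the pinned value being the integer passed to `jax.random.PRNGKey`.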

junpenglao commented 10 months ago

@reubenharry found a bug in the function, caused by this line: https://github.com/blackjax-devs/blackjax/blob/540db419f0ccddb0368443049492b6c0448fe273/blackjax/util.py#L174

Could you remove this line and add a test? Something like:

```python
import jax
import jax.numpy as jnp
from blackjax.mcmc.hmc import hmc
from blackjax.util import run_inference_algorithm


def logdensity_fn(x):
    return -0.5 * jnp.sum(jnp.square(x))


alg = hmc(
    logdensity_fn=logdensity_fn,
    inverse_mass_matrix=jnp.eye(2),
    step_size=1.0,
    num_integration_steps=1000,
)

_ = run_inference_algorithm(
    rng_key=jax.random.PRNGKey(0),
    initial_state_or_position=jnp.array([1.0, 1.0]),
    inference_algorithm=alg,
    num_steps=10,
    progress_bar=True,
)
```

PaulScemama commented 10 months ago

@junpenglao I think you meant to tag me instead of @reubenharry? But yes, definitely! Sorry about that. Should I start a new test file, maybe test_util.py? Or where would you like me to put such a test?

junpenglao commented 10 months ago

test_util.py sounds good.

PaulScemama commented 10 months ago

@junpenglao I have to step away for the rest of the day. I just added a test for run_inference_algorithm but haven't done anything with the flaky tests yet. I'll be back tomorrow. Thanks for all the help!

junpenglao commented 10 months ago

The flaky test is not failing now, so I am going to go ahead and merge it. Thanks!