blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Resolve 'Functions to run kernels' #598

Closed PaulScemama closed 10 months ago

PaulScemama commented 10 months ago


Closes #591. Implements a run_inference_algorithm wrapper function in blackjax/util.py for convenience. See discussion in #591 for more details.


PaulScemama commented 10 months ago

@junpenglao @albcab I ran into an issue while converting the tests' inference_loops to the new run_inference_algorithm. It pertains to the discussion in #591.

This is what is happening: there is a univariate normal test for the samplers. This particular case uses elliptical_slice and passes in a dummy logdensity_fn, lambda _: 1.0. Now consider this self-contained code:

import jax
import jax.numpy as jnp

import blackjax

algo = blackjax.elliptical_slice(lambda _: 1.0, **{"cov": jnp.array([2.0**2]), "mean": 1.0})
rng_key = jax.random.PRNGKey(123)

initial_position = 5.0
initial_state = blackjax.elliptical_slice.init(initial_position, lambda _: 1.0)
print(f"Initial State before try|except block {initial_state}")

# THIS IS BASICALLY WHAT `run_inference_algorithm` is doing. 
try:
    initial_state = blackjax.elliptical_slice.init(initial_state, lambda _: 1.0)
except TypeError:
    pass

print(f"Initial State after try|except block: {initial_state}")

'Initial State before try|except block EllipSliceState(position=5.0, logdensity=1.0)'
'Initial State after try|except block: EllipSliceState(position=EllipSliceState(position=5.0, logdensity=1.0), logdensity=1.0)'

The comment indicates that run_inference_algorithm uses a try|except block to determine whether the initial_position_or_state argument is an initial position or an initial state. If it is an initial position, the try block succeeds. If it is an initial state, the try block fails with a TypeError, on the assumption that an initial state object is structurally different from an initial position object (which is a good assumption: every algorithm's init returns a custom NamedTuple State class, while a position is a PyTree).

The problem is that the init for elliptical_slice applies logdensity_fn to the position argument, but the test's dummy logdensity_fn ignores its input and returns 1.0 regardless: it is lambda _: 1.0. This produces the unexpected and unwanted behavior above: init can be called on any argument, including its own output, indefinitely, and it will never error out.
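To make this concrete without depending on blackjax at all, here is a minimal stand-in (the `EllipSliceState` and `init` below are simplified sketches of the real ones, not blackjax code):

```python
from typing import NamedTuple


class EllipSliceState(NamedTuple):
    position: float
    logdensity: float


def init(position, logdensity_fn):
    # Simplified stand-in for blackjax.elliptical_slice.init: it evaluates
    # logdensity_fn at the given position and wraps both in a State.
    return EllipSliceState(position, logdensity_fn(position))


dummy_logdensity = lambda _: 1.0  # ignores its input, so it can never raise

state = init(5.0, dummy_logdensity)
# Calling init again on a *state* should raise a TypeError, but the dummy
# logdensity_fn happily accepts anything, so we silently nest states instead.
nested = init(state, dummy_logdensity)
print(nested)
# EllipSliceState(position=EllipSliceState(position=5.0, logdensity=1.0), logdensity=1.0)
```

With any logdensity_fn that actually touches its input (e.g. does array arithmetic on it), the second call would fail as the try|except heuristic expects.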

Sorry if this is a bit disorganized 😅 but basically I want to know what you both @junpenglao @albcab would like to do about this. Should the test not include the dummy logdensity fn? This would be the quickest fix I think. Or do we need to rethink the run_inference_algorithm?

If the latter, I still think this could be a useful option:

def run_inference_algorithm(
    rng_key: PRNGKey,
    inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm],
    num_steps: int,
    initial_position: ArrayLikeTree = None,
    initial_state: ArrayLikeTree = None,
) -> Tuple[State, State, Info]:
    """
    Wrapper to run an inference algorithm.

    Parameters
    ----------
    rng_key : PRNGKey
        The random state used by JAX's random numbers generator.
    inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm]
        One of blackjax's sampling algorithms or variational inference algorithms.
    num_steps : int
        Number of steps to run the inference algorithm for.
    initial_position : ArrayLikeTree, optional
        The initial position to initialize the state of an inference algorithm,
        by default None. Note that either `initial_position` or `initial_state`
        must be passed in, but not both.
    initial_state: ArrayLikeTree, optional
        The initial state of the inference algorithm, by default None.
        Note that either `initial_position` or `initial_state` must be passed in,
        but not both.

    Returns
    -------
    Tuple[State, State, Info]
        1. The final state of the inference algorithm.
        2. The history of states of the inference algorithm.
        3. The history of the info of the inference algorithm.
    """
    if (initial_position is None) == (initial_state is None):
        raise ValueError(
            "Either `initial_position` or `initial_state` must be specified, but not both."
        )
    if initial_state is None:
        initial_state = inference_algorithm.init(initial_position)

    keys = jax.random.split(rng_key, num_steps)

    @jax.jit
    def one_step(state, rng_key):
        state, info = inference_algorithm.step(rng_key, state)
        return state, (state, info)

    final_state, (state_history, info_history) = jax.lax.scan(
        one_step, initial_state, keys
    )
    return final_state, state_history, info_history
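As an aside, the `one_step` trick above (returning `state, (state, info)`) is what makes the wrapper record the full history: `jax.lax.scan` threads the first element of the pair as the carry and stacks the second across steps. A blackjax-free sketch of just that pattern, with a toy random-walk step standing in for `inference_algorithm.step`:

```python
import jax
import jax.numpy as jnp


def one_step(state, rng_key):
    # Toy stand-in for inference_algorithm.step: a Gaussian random-walk move.
    new_state = state + jax.random.normal(rng_key)
    info = jnp.abs(new_state - state)  # stand-in for an Info record
    # carry the new state forward, and record (state, info) at every step
    return new_state, (new_state, info)


keys = jax.random.split(jax.random.PRNGKey(0), 5)
final_state, (state_history, info_history) = jax.lax.scan(
    one_step, jnp.array(0.0), keys
)

print(state_history.shape, info_history.shape)  # (5,) (5,)
```

The final carry is always identical to the last recorded state, which is why the wrapper can return both cheaply.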

@junpenglao you mentioned that having initial_position in the second arg makes vmapping easier, but if we wanted to vmap over initial positions couldn't we just:

jax.vmap(lambda rng_key, initial_position: run_inference_algorithm(rng_key, inference_algorithm, num_steps, initial_position=initial_position))(rng_keys, initial_positions)

I am also a beginner so I could be wrong; I am trying to learn so apologies in advance! Thanks for all the help :)
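For what it's worth, here is a runnable sketch of that vmap pattern, with a hypothetical `run_loop` (a toy random-walk loop, not a blackjax function) standing in for `run_inference_algorithm`:

```python
import jax
import jax.numpy as jnp


def run_loop(rng_key, num_steps, initial_position):
    # Toy stand-in for run_inference_algorithm with the same calling
    # convention: (rng_key, ..., initial_position=...).
    def one_step(position, key):
        new_position = position + 0.1 * jax.random.normal(key)
        return new_position, new_position

    keys = jax.random.split(rng_key, num_steps)
    return jax.lax.scan(one_step, initial_position, keys)


num_steps = 100
initial_positions = jnp.array([0.0, 1.0, 2.0])
rng_keys = jax.random.split(jax.random.PRNGKey(0), 3)

# vmap over (rng_key, initial_position) pairs, closing over the fixed
# arguments, exactly as in the one-liner above.
finals, histories = jax.vmap(
    lambda key, pos: run_loop(key, num_steps, initial_position=pos)
)(rng_keys, initial_positions)

print(finals.shape, histories.shape)  # (3,) (3, 100)
```

So keeping initial_position as a keyword argument does not seem to get in the way of vmapping, as far as I can tell.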

junpenglao commented 10 months ago

We should modify the test instead; try: `logdensity_fn: lambda x: jnp.ones_like(x)`

PaulScemama commented 10 months ago

@junpenglao I think it is ready for review

junpenglao commented 10 months ago

Could you also update this one: https://github.com/blackjax-devs/blackjax/blob/3845635fc059a74f5742b96898692580fe750172/tests/adaptation/test_adaptation.py#L68

junpenglao commented 10 months ago

Nice work, thank you!

codecov[bot] commented 10 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Comparison is base (0a84b22) 99.16% compared to head (f0ca07e) 99.16%. Report is 1 commit behind head on main.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             main     #598   +/-  ##
=======================================
  Coverage   99.16%   99.16%
=======================================
  Files          54       54
  Lines        2513     2527    +14
=======================================
+ Hits         2492     2506    +14
  Misses         21       21
```


junpenglao commented 10 months ago

Thank you for your contribution @PaulScemama! And congrats on your first PR to Blackjax :-)

Could you also open a PR to change the inference loop usage in sampling-book?

PaulScemama commented 10 months ago

@junpenglao thank you! :) Thanks for all the guidance. The library is wonderful and I'm excited to continue contributing. I will open a PR to change the inference loop usage in the sampling book.