PaulScemama closed this 10 months ago.
@junpenglao @albcab I ran into an issue while going through the tests and converting `inference_loops` to the new `run_inference_algorithm`. This pertains to the discussion in the issue here.
This is what is happening: here is a univariate normal test for the samplers. This particular case uses `elliptical_slice` and passes in a dummy `logdensity_fn`, `lambda _: 1.0`. Now consider this self-contained code:
import jax
import jax.numpy as jnp
import blackjax

algo = blackjax.elliptical_slice(lambda _: 1.0, cov=jnp.array([2.0**2]), mean=1.0)
rng_key = jax.random.PRNGKey(123)
initial_position = 5.0
initial_state = blackjax.elliptical_slice.init(initial_position, lambda _: 1.0)
print(f"Initial State before try|except block: {initial_state}")

# THIS IS BASICALLY WHAT `run_inference_algorithm` is doing.
try:
    initial_state = blackjax.elliptical_slice.init(initial_state, lambda _: 1.0)
except TypeError:
    pass
print(f"Initial State after try|except block: {initial_state}")

# Output:
# Initial State before try|except block: EllipSliceState(position=5.0, logdensity=1.0)
# Initial State after try|except block: EllipSliceState(position=EllipSliceState(position=5.0, logdensity=1.0), logdensity=1.0)
As the comment indicates, `run_inference_algorithm` uses a try|except block to determine whether the `initial_position_or_state` argument is an initial position or an initial state. If it is an initial position, the `try` block succeeds. If it is an initial state, the `try` block is expected to fail with a `TypeError`, on the assumption that a state object is different from a position object. That is a good assumption: every algorithm's `init` returns a custom `NamedTuple` state class, while a position is a PyTree.
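The dispatch described above can be sketched in plain Python. This is a minimal sketch, not BlackJAX code: `ToyState`, `toy_init` and `init_if_position` are hypothetical stand-ins for a state class, an `init` function and the try|except logic inside `run_inference_algorithm`.

```python
from typing import Any, Callable, NamedTuple


class ToyState(NamedTuple):
    # Hypothetical stand-in for a BlackJAX state such as EllipSliceState.
    position: Any
    logdensity: Any


def toy_init(position: Any, logdensity_fn: Callable) -> ToyState:
    # Like elliptical_slice.init: evaluate the log-density at the given position.
    return ToyState(position=position, logdensity=logdensity_fn(position))


def init_if_position(position_or_state: Any, logdensity_fn: Callable) -> ToyState:
    # Hypothetical helper mirroring the try|except dispatch: if init succeeds the
    # argument is treated as a position; if it raises TypeError, the argument is
    # assumed to already be a state and is returned unchanged.
    try:
        return toy_init(position_or_state, logdensity_fn)
    except TypeError:
        return position_or_state


def gaussian_logdensity(x):
    # Actually uses its input, so x ** 2 raises TypeError when x is a NamedTuple.
    return -0.5 * x**2


state = init_if_position(5.0, gaussian_logdensity)
print(state)  # ToyState(position=5.0, logdensity=-12.5)
print(init_if_position(state, gaussian_logdensity) is state)  # True

# With the test's dummy log-density nothing raises, so the state gets wrapped again.
nested = init_if_position(state, lambda _: 1.0)
print(nested)  # ToyState(position=ToyState(position=5.0, logdensity=-12.5), logdensity=1.0)
```

With a log-density that genuinely uses its argument, a state passes through untouched; with `lambda _: 1.0` it is silently wrapped, which is the same nesting seen in the `EllipSliceState` output.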
The problem is that `init` for `elliptical_slice` applies `logdensity_fn` to the `position` argument. But in the test, the dummy `logdensity_fn`, `lambda _: 1.0`, ignores its input and always returns 1.0. As a result nothing ever raises: we can call `init` on any argument, including a state, any number of times without an error, and each call just wraps the argument in another `EllipSliceState`.
Sorry if this is a bit disorganized 😅, but basically I want to know what you both (@junpenglao @albcab) would like to do about this. Should the test stop using the dummy logdensity function? I think that would be the quickest fix. Or do we need to rethink `run_inference_algorithm`?
If the latter, I still think this could be a useful option:
from typing import Optional, Tuple, Union

import jax

# Assumed import locations, matching blackjax/util.py at the time of this PR.
from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm
from blackjax.types import ArrayLikeTree, PRNGKey


def run_inference_algorithm(
    rng_key: PRNGKey,
    inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm],
    num_steps: int,
    initial_position: Optional[ArrayLikeTree] = None,
    initial_state: Optional[ArrayLikeTree] = None,
) -> Tuple[State, State, Info]:
    """
    Wrapper to run an inference algorithm.

    Parameters
    ----------
    rng_key : PRNGKey
        The random state used by JAX's random number generator.
    inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm]
        One of blackjax's sampling algorithms or variational inference algorithms.
    num_steps : int
        Number of steps to run the inference algorithm for.
    initial_position : ArrayLikeTree, optional
        The initial position to initialize the state of an inference algorithm,
        by default None. Note that either `initial_position` or `initial_state`
        must be passed in, but not both.
    initial_state : ArrayLikeTree, optional
        The initial state of the inference algorithm, by default None.
        Note that either `initial_position` or `initial_state` must be passed in,
        but not both.

    Returns
    -------
    Tuple[State, State, Info]
        1. The final state of the inference algorithm.
        2. The history of states of the inference algorithm.
        3. The history of the info of the inference algorithm.
    """
    # Exactly one of the two keyword arguments must be provided.
    if (initial_position is None) == (initial_state is None):
        raise ValueError(
            "Either `initial_position` or `initial_state` must be specified, but not both."
        )

    # Only call `init` when given a position, so no try|except guessing is needed.
    if initial_state is None:
        initial_state = inference_algorithm.init(initial_position)

    keys = jax.random.split(rng_key, num_steps)

    @jax.jit
    def one_step(state, rng_key):
        state, info = inference_algorithm.step(rng_key, state)
        return state, (state, info)

    final_state, (state_history, info_history) = jax.lax.scan(
        one_step, initial_state, keys
    )
    return final_state, state_history, info_history
@junpenglao you mentioned that having `initial_position` as the second argument makes vmapping easier, but if we wanted to vmap over initial positions, couldn't we just do:
jax.vmap(
    lambda rng_key, initial_position: run_inference_algorithm(
        rng_key, inference_algorithm, num_steps, initial_position=initial_position
    )
)(rng_keys, initial_positions)
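A minimal sketch suggests the lambda approach works, at least in a toy setting. Assumptions: `run_sketch` is a hypothetical, unannotated copy of the keyword-based proposal above, and `ToyAlgorithm`/`toy_step` are made-up stand-ins for a BlackJAX algorithm whose state is just the position and whose step adds Gaussian noise. Only `jax.vmap`, `jax.lax.scan` and `jax.random` are real APIs here.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyAlgorithm(NamedTuple):
    # Hypothetical stand-in for a BlackJAX algorithm: an (init, step) pair.
    init: Callable
    step: Callable


def run_sketch(rng_key, inference_algorithm, num_steps, initial_position=None, initial_state=None):
    # Keyword-based dispatch from the proposal: exactly one of the two must be given.
    if (initial_position is None) == (initial_state is None):
        raise ValueError("Either `initial_position` or `initial_state` must be specified, but not both.")
    if initial_state is None:
        initial_state = inference_algorithm.init(initial_position)
    keys = jax.random.split(rng_key, num_steps)

    def one_step(state, key):
        state, info = inference_algorithm.step(key, state)
        return state, (state, info)

    final_state, (state_history, info_history) = jax.lax.scan(one_step, initial_state, keys)
    return final_state, state_history, info_history


def toy_step(key, state):
    # Toy transition: add standard Gaussian noise and report the noise as the "info".
    noise = jax.random.normal(key, dtype=jnp.float32)
    return state + noise, noise


toy = ToyAlgorithm(init=lambda position: jnp.asarray(position, dtype=jnp.float32), step=toy_step)
num_steps = 10

# The keyword argument is bound inside the lambda, so vmap only sees two positional
# arguments and maps both over their leading axis (one chain per row).
run_many = jax.vmap(
    lambda rng_key, initial_position: run_sketch(
        rng_key, toy, num_steps, initial_position=initial_position
    )
)

rng_keys = jax.random.split(jax.random.PRNGKey(0), 4)
initial_positions = jnp.array([-2.0, 0.0, 1.0, 5.0], dtype=jnp.float32)
final_states, state_history, info_history = run_many(rng_keys, initial_positions)
print(final_states.shape, state_history.shape)  # (4,) (4, 10)
```

Because `initial_position` is passed by keyword inside the lambda, the keyword-only signature does not get in the way of `jax.vmap`.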
I am also a beginner so I could be wrong; I am trying to learn so apologies in advance! Thanks for all the help :)
We should modify the test instead. Try `logdensity_fn: lambda x: jnp.ones_like(x)`.
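A minimal sketch of why this suggestion fixes the dispatch, using hypothetical `ToyState`/`toy_init` in place of `EllipSliceState` and `elliptical_slice.init`: `jnp.ones_like` still returns a constant, but in current JAX it rejects arguments that are not array-like, and a `NamedTuple` state is not array-like, so calling `init` on a state raises `TypeError` as the try|except block expects.

```python
from typing import Any, NamedTuple

import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for EllipSliceState.
    position: Any
    logdensity: Any


def logdensity_fn(x):
    # The suggested test log-density: still constant, but it inspects its input.
    return jnp.ones_like(x)


def toy_init(position):
    # Evaluate the log-density at the position, as elliptical_slice.init does.
    return ToyState(position=position, logdensity=logdensity_fn(position))


state = toy_init(5.0)  # a scalar position is array-like, so this succeeds

try:
    toy_init(state)  # a NamedTuple state is not array-like: jnp.ones_like raises
    raised = False
except TypeError:
    raised = True

print(raised)  # True
```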
@junpenglao I think it is ready for review
Could you also update this one: https://github.com/blackjax-devs/blackjax/blob/3845635fc059a74f5742b96898692580fe750172/tests/adaptation/test_adaptation.py#L68
Nice work, thank you!
All modified and coverable lines are covered by tests ✅
Comparison is base (0a84b22) 99.16% compared to head (f0ca07e) 99.16%. Report is 1 commit behind head on main.
Thank you for your contribution @PaulScemama! And congrats on your first PR to Blackjax :-)
Could you also open a PR to change the inference loop usage in sampling-book?
@junpenglao thank you! :) Thanks for all the guidance. The library is wonderful and I'm excited to continue contributing. I will open a PR to change the inference loop usage in the sampling book.
Closes #591. Implements a `run_inference_algorithm` wrapper function in `blackjax/util.py` for convenience. See the discussion in #591 for more details.