blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

window_adaptation excessive memory usage #667

Closed: andrewdipper closed this 4 months ago

andrewdipper commented 4 months ago

Describe the issue as clearly as possible:

By default, the scan in window_adaptation saves an AdaptationInfo record (the new state, the kernel info and the new adaptation state) for every sample along the way. This results in memory usage many times greater than (num_samples)*(num_variables) and leads to out-of-memory errors. However, it looks like the final states carried through the scan are the only information needed to perform the window adaptation.
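To illustrate why memory grows this way, here is a minimal sketch that is not BlackJAX code: `jax.lax.scan` stacks every per-step output along a new leading axis, so keeping one record per step costs num_steps times the size of a single record. `StepRecord` and `toy_step` are hypothetical stand-ins, and the d×d field only mimics what any per-step matrix-sized field would cost.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class StepRecord(NamedTuple):
    # Hypothetical stand-in for a per-step AdaptationInfo record.
    position: jax.Array             # shape (d,)
    gradient: jax.Array             # shape (d,)
    inverse_mass_matrix: jax.Array  # shape (d, d), mimics a matrix-sized field


def toy_step(position, _):
    # A trivial "transition": shrink the position slightly.
    new_position = 0.9 * position
    record = StepRecord(
        position=new_position,
        gradient=-new_position,
        inverse_mass_matrix=jnp.eye(position.shape[0]),
    )
    return new_position, record


num_steps, d = 1000, 50
final_position, records = jax.lax.scan(toy_step, jnp.ones(d), None, length=num_steps)

# scan stacks every per-step output along a new leading axis.
print(records.position.shape)             # (1000, 50)
print(records.inverse_mass_matrix.shape)  # (1000, 50, 50)
total_bytes = sum(x.nbytes for x in jax.tree_util.tree_leaves(records))
print(total_bytes)  # 10400000 with JAX's default float32 dtype
```

Only `final_position` is needed to continue, yet `records` holds num_steps copies of every field.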

It would be helpful to be able to disable, or select, which info is stored along the way, so that the auxiliary info doesn't cause out-of-memory errors. Removing it altogether doesn't seem ideal either. I'd be happy to open a PR if you have an idea of how best to include/exclude the extra info.

I believe https://github.com/blackjax-devs/blackjax/issues/529 has the same cause: the extra buffers are likely used to store the per-sample info, and I get similar output.

Thanks

Steps/code to reproduce the bug:

# A rough alternative one_step function (in window_adaptation.py) that completely removes the
# extra info. Returning None is not the ideal solution here.

    def one_step(carry, xs):
        _, rng_key, adaptation_stage = xs
        state, adaptation_state = carry

        new_state, info = mcmc_kernel(
            rng_key,
            state,
            logdensity_fn,
            adaptation_state.step_size,
            adaptation_state.inverse_mass_matrix,
            **extra_parameters,
        )
        new_adaptation_state = adapt_step(
            adaptation_state,
            adaptation_stage,
            new_state.position,
            info.acceptance_rate,
        )

        return (
            (new_state, new_adaptation_state),
            None,  # removed: AdaptationInfo(new_state, info, new_adaptation_state)
        )

Expected result:

...

Error message:

...

Blackjax/JAX/jaxlib/Python version information:

BlackJAX 0.1.dev494+g40efb6c.d20240511  # this is 1.2.1 with the jnp.clip PR removed
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
Jax 0.4.25
Jaxlib 0.4.25

Context for the issue:

I found this while trying to reduce memory consumption for a PyMC model sampled with BlackJAX; there is a similar issue with storing extra info during the actual sampling process. With both fixes, memory consumption and performance initially look more than comparable to PyMC's NumPyro sampler.

junpenglao commented 4 months ago

> As such it'd be helpful to disable / select what info to store along the way such that the auxiliary info doesn't cause out of memory issues. Removing it altogether also doesn't seem ideal. I'd be happy to give a PR if you have an idea of what/how to best include/exclude the extra info.

Agree. I think NumPyro also has a flag to control what gets exposed. I think we will need to add a kwarg to https://github.com/blackjax-devs/blackjax/blob/40efb6cbc92a8074f8ac3a5d9710071f8d987537/blackjax/adaptation/window_adaptation.py#L244-L252:

def return_all_adapt_info(state, info, adaptation_state):
    return AdaptationInfo(state, info, adaptation_state)

def window_adaptation(
    algorithm,
    logdensity_fn: Callable,
    is_mass_matrix_diagonal: bool = True,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    progress_bar: bool = False,
    adaptation_info_fn: Callable = return_all_adapt_info,
    **extra_parameters,
):
    ...  # rest of the body unchanged
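A minimal sketch of how such a hook could be threaded through the scan body, using toy types rather than BlackJAX's: `ToyAdaptationInfo`, `return_no_adapt_info` and `run_warmup` are hypothetical names, and the transition and step-size update are placeholders. Whatever the hook returns is what the scan stacks per step, so returning `None` stores nothing.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyAdaptationInfo(NamedTuple):
    # Toy stand-in for AdaptationInfo(state, info, adaptation_state).
    state: jax.Array
    info: jax.Array
    adaptation_state: jax.Array


def return_all_adapt_info(state, info, adaptation_state):
    # Mirrors the proposed default: keep everything.
    return ToyAdaptationInfo(state, info, adaptation_state)


def return_no_adapt_info(state, info, adaptation_state):
    # Hypothetical alternative: store nothing per step.
    return None


def run_warmup(num_steps: int, adaptation_info_fn: Callable = return_all_adapt_info):
    def one_step(carry, _):
        state, step_size = carry
        new_state = state * 0.5                             # toy transition
        acceptance_rate = jnp.exp(-new_state.sum() ** 2)    # toy kernel "info"
        new_step_size = step_size * (1.0 + 0.1 * (acceptance_rate - 0.8))  # toy adaptation
        # The hook decides what (if anything) scan stacks for this step.
        return (new_state, new_step_size), adaptation_info_fn(
            new_state, acceptance_rate, new_step_size
        )

    init = (jnp.ones(4), jnp.array(1.0, dtype=jnp.float32))
    return jax.lax.scan(one_step, init, None, length=num_steps)


_, all_info = run_warmup(100)
_, no_info = run_warmup(100, adaptation_info_fn=return_no_adapt_info)
print(all_info.state.shape)  # (100, 4)
print(no_info)               # None
```

Because the hook runs inside the traced scan body, the choice of what to keep is made at trace time and costs nothing at run time.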

Then add some utility functions for filtering commonly used info (e.g., return_warmup_sample=True to return the state).
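One possible shape for such a filtering utility, sketched with toy NamedTuples: the hypothetical `make_filter_adapt_info_fn` builds a hook that keeps only the requested fields and replaces the rest with `None`, which JAX treats as an empty pytree, so the scan has nothing to stack for them. Keeping only `position` roughly corresponds to the `return_warmup_sample=True` idea. `ToyState`, `ToyInfo` and the helper names are assumptions for illustration, and the hook returns a plain tuple rather than an AdaptationInfo.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    position: jax.Array
    logdensity: jax.Array
    logdensity_grad: jax.Array


class ToyInfo(NamedTuple):
    acceptance_rate: jax.Array
    is_divergent: jax.Array


def _keep_fields(nt, keys):
    # Replace every field not listed in `keys` with None (an empty pytree),
    # so lax.scan has nothing to stack for it.
    return nt._replace(**{k: None for k in nt._fields if k not in keys})


def make_filter_adapt_info_fn(state_keys=(), info_keys=()):
    # Hypothetical helper: returns a hook that keeps only the requested
    # state/info fields and drops the adaptation state entirely.
    def filter_fn(state, info, adaptation_state):
        return _keep_fields(state, state_keys), _keep_fields(info, info_keys), None

    return filter_fn


# Roughly what the proposed return_warmup_sample=True would do: keep positions only.
keep_positions = make_filter_adapt_info_fn(state_keys=("position",))


def one_step(position, _):
    new_position = 0.9 * position
    state = ToyState(new_position, -0.5 * jnp.sum(new_position**2), -new_position)
    info = ToyInfo(jnp.asarray(0.8), jnp.asarray(False))
    return new_position, keep_positions(state, info, adaptation_state=None)


_, (states, infos, _) = jax.lax.scan(one_step, jnp.ones(3), None, length=200)
print(states.position.shape)  # (200, 3)
print(states.logdensity)      # None
```

Replacing fields with `None` rather than dropping them keeps the output's structure stable, so downstream code can still write `states.position` regardless of which fields were kept.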

junpenglao commented 4 months ago

Really good point and thank you for the deep dive!! Feel free to send a PR.