Closed andrewdipper closed 4 months ago
As such it'd be helpful to disable / select what info to store along the way such that the auxiliary info doesn't cause out of memory issues. Removing it altogether also doesn't seem ideal. I'd be happy to give a PR if you have an idea of what/how to best include/exclude the extra info.
Agree. I think numpyro also has a flag to control what gets exposed. I think we will need to add a kwarg to https://github.com/blackjax-devs/blackjax/blob/40efb6cbc92a8074f8ac3a5d9710071f8d987537/blackjax/adaptation/window_adaptation.py#L244-L252, something like:
def return_all_adapt_info(state, info, adaptation_state):
    return AdaptationInfo(state, info, adaptation_state)

def window_adaptation(
    algorithm,
    logdensity_fn: Callable,
    is_mass_matrix_diagonal: bool = True,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    progress_bar: bool = False,
    adaptation_info_fn: Callable = return_all_adapt_info,  # proposed new kwarg
    **extra_parameters,
):
    ...
And then add some utility functions for filtering common info (e.g., return_warmup_sample=True to return state).
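The proposal above could look roughly like the sketch below. It is only an illustration, not blackjax's implementation: `AdaptationInfo` is a local stand-in NamedTuple, and `make_info_filter` and `record_step` are hypothetical names. Only `return_all_adapt_info`, `adaptation_info_fn`, and `return_warmup_sample` come from the comment.

```python
from typing import Any, Callable, NamedTuple


class AdaptationInfo(NamedTuple):
    # Local stand-in mirroring the three fields used in the comment above.
    state: Any
    info: Any
    adaptation_state: Any


def return_all_adapt_info(state, info, adaptation_state):
    # Default behaviour: keep everything for every warmup step.
    return AdaptationInfo(state, info, adaptation_state)


def make_info_filter(return_warmup_sample: bool = True) -> Callable:
    # Hypothetical helper: build an adaptation_info_fn that keeps only the
    # chain state (the warmup sample) and drops the other two fields.
    def filtered(state, info, adaptation_state):
        kept_state = state if return_warmup_sample else None
        return AdaptationInfo(kept_state, None, None)

    return filtered


def record_step(state, info, adaptation_state,
                adaptation_info_fn: Callable = return_all_adapt_info):
    # Stand-in for the point in the warmup loop where per-step output is built.
    return adaptation_info_fn(state, info, adaptation_state)


full = record_step({"x": 1.0}, {"energy": 3.2}, {"step_size": 0.1})
slim = record_step({"x": 1.0}, {"energy": 3.2}, {"step_size": 0.1},
                   adaptation_info_fn=make_info_filter())
```

Returning `None` for the dropped fields keeps the `AdaptationInfo` shape stable for downstream code. In JAX, `None` is an empty pytree, so a scan has nothing to stack for those fields.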
Really good point and thank you for the deep dive!! Feel free to send a PR.
Describe the issue as clearly as possible:
The scan in window_adaptation by default saves the AdaptationInfo for every sample along the way. This results in memory usage many times in excess of (num_samples)*(num_variables) and leads to out of memory issues. However, it looks like the last states are the only information necessary to perform the window adaptation. As such it'd be helpful to disable / select what info to store along the way such that the auxiliary info doesn't cause out of memory issues. Removing it altogether also doesn't seem ideal. I'd be happy to give a PR if you have an idea of what/how to best include/exclude the extra info.
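To illustrate why the per-step output dominates memory, here is a toy sketch. It uses a plain-Python `scan` as a stand-in for the carry/output contract of `jax.lax.scan`. The step functions and the size of `info` are made up for illustration and are not taken from blackjax.

```python
from typing import Callable, List, Tuple


def scan(step: Callable, carry, xs) -> Tuple[object, List]:
    # Plain-Python stand-in for the contract of jax.lax.scan:
    # thread `carry` through the loop and collect one output per step.
    outputs = []
    for x in xs:
        carry, out = step(carry, x)
        outputs.append(out)
    return carry, outputs


def warmup_step_keep_all(carry, x):
    state, adaptation_state = carry
    new_state = state + x                        # toy "sampler" update
    new_adaptation_state = adaptation_state + 1  # toy adaptation update
    info = [new_state] * 100                     # stand-in for bulky diagnostics
    new_carry = (new_state, new_adaptation_state)
    return new_carry, (new_state, info, new_adaptation_state)


def warmup_step_keep_none(carry, x):
    new_carry, _ = warmup_step_keep_all(carry, x)
    return new_carry, None  # nothing is accumulated per step


num_steps = 1000
last_all, history = scan(warmup_step_keep_all, (0, 0), range(num_steps))
last_none, _ = scan(warmup_step_keep_none, (0, 0), range(num_steps))

# Number of diagnostic values retained across the whole warmup run.
stored_values = sum(len(info) for _, info, _ in history)
```

Both runs end with the same final carry, so the adaptation result does not depend on keeping the history. Only the first run holds `num_steps` copies of the diagnostics.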
I believe https://github.com/blackjax-devs/blackjax/issues/529 is the result of the same thing: the extra buffers are likely for storing the sample-by-sample info, and I get similar outputs.
Thanks
Steps/code to reproduce the bug:
Expected result:
Error message:
Blackjax/JAX/jaxlib/Python version information:
Context for the issue:
I found this while trying to reduce memory consumption for a pymc model sampled with blackjax. There's a similar issue there with storing extra info during the actual sampling process. With both fixes, memory consumption and performance are initially looking more than comparable with pymc's numpyro sampler.