blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

smc inner_kernel_tuning forces stateless mcmc_parameters / vmap over shared parameters #690

Closed: andrewdipper closed this issue 3 months ago

andrewdipper commented 4 months ago

Describe the issue as clearly as possible:

When computing the new_parameter_override, only the new_state and info are available as input, so the update cannot depend on the previous parameter_override. As a workaround, mcmc_step_fn can be hacked to copy the mcmc parameters stored in StateWithParameterOverride into the corresponding info output. However, this duplicates the parameters when num_mcmc_steps > 1, and it threads additional data through functions that don't care about it. It would also be inconsistent with how tempered handles its lmbda parameter, and I think it interferes with the second issue below. https://github.com/blackjax-devs/blackjax/blob/83bc3a04a64ac3379cd220932e88b44578f2e8e5/blackjax/smc/inner_kernel_tuning.py#L73
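To make the limitation concrete, here is a minimal sketch. The first function matches what the current hook supports per the linked source; the second signature, with a previous_override argument, is hypothetical:

```python
import jax.numpy as jnp

# What the current hook supports: a function of only the new state and info.
def parameter_update(new_state, info):
    # e.g. derive a proposal covariance from the current particles
    return {"cov": jnp.cov(new_state.particles.T)}

# What the issue asks for (hypothetical signature): also receiving the
# previous override, so the new value can depend on the old one.
def parameter_update_with_previous(previous_override, new_state, info):
    new_cov = jnp.cov(new_state.particles.T)
    return {"cov": 0.9 * previous_override["cov"] + 0.1 * new_cov}
```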

Secondly, state.parameter_override is threaded through as mcmc_parameters all the way down to the point in tempered.py where the mcmc_kernel is vmapped.

https://github.com/blackjax-devs/blackjax/blob/83bc3a04a64ac3379cd220932e88b44578f2e8e5/blackjax/smc/tempered.py#L143-L149

Given that the vmap is done within tempered.py, this is necessary for parameters that differ across particles. However, there appears to be no way to avoid it for parameters that are shared among the particles; see https://blackjax-devs.github.io/sampling-book/algorithms/TemperedSMCWithOptimizedInnerKernel.html, where a potentially full covariance matrix is explicitly duplicated for each particle. Since the storage is quadratic in the number of variables and linear in the number of particles, this blows up quickly. Allowing the shared parameters to be closed over before the vmap would prevent this.
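To put rough numbers on that growth (the figures below are illustrative, not from the notebook):

```python
# Cost of a full d x d covariance in float64, stored once vs. per particle.
n_particles = 2000
d = 500  # number of variables

shared_bytes = d * d * 8                     # stored once: 2.0 MB
duplicated_bytes = n_particles * d * d * 8   # one copy per particle: 4.0 GB

print(f"{shared_bytes / 1e6:.1f} MB shared, "
      f"{duplicated_bytes / 1e9:.1f} GB duplicated")
```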

I might be missing something, but I couldn't find a way to work around these issues without modifying the library. Essentially, I wanted an RMH kernel whose proposal distribution I could scale based on the acceptance rate. For the moment I'm splitting the shared and unshared mcmc parameters and passing state.parameter_override to mcmc_parameter_update_fn:

https://github.com/blackjax-devs/blackjax/blob/83bc3a04a64ac3379cd220932e88b44578f2e8e5/blackjax/smc/inner_kernel_tuning.py#L59-L74
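For reference, the kind of update rule this is meant to enable, as a hedged sketch: the target rate of 0.234 and the exponential update are illustrative choices, and it assumes the inner kernel's info carries a per-particle is_accepted field, as BlackJAX's RMH info does:

```python
import jax.numpy as jnp

def scale_from_acceptance(previous_scale, info, target=0.234):
    # Grow the proposal scale when acceptance is above target, shrink otherwise.
    acceptance_rate = jnp.mean(info.is_accepted)
    return previous_scale * jnp.exp(acceptance_rate - target)
```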

Steps/code to reproduce the bug:

#not this kind of bug

Expected result:

#not this kind of bug

Error message:

#not this kind of bug

Blackjax/JAX/jaxlib/Python version information:

>>> blackjax.__version__
'1.2.1'
>>> sys.version
'3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]'
>>> jax.__version__
'0.4.28'
>>> jaxlib.__version__
'0.4.28'

Context for the issue:

Unless I'm missing something, inner_kernel_tuning for SMC methods is quite restrictive without modifying the library.

junpenglao commented 4 months ago

cc @ciguaran

ciguaran commented 4 months ago

Hi! Regarding the first feature (depending on the previous parameter value): this is something we want to support, since more complex tuning strategies like https://arxiv.org/abs/1005.1193 require updating a probability distribution over parameters that is passed from iteration to iteration. I have a draft for this, but I haven't yet been able to reproduce the results of that paper, so I haven't sent a PR.

Regarding shared vs. unshared parameters, what you are asking for seems like something we want to support. I wonder if there's a way to play with vmap's broadcasting, e.g. passing some parameters with shape (n_particles, param_dimension) and others with shape (1, param_dimension), and having vmap reuse the value when the leading dimension doesn't match. If that can't be done, we will need to modify the code to support both types of parameters.
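A small experiment along these lines (a sketch, not BlackJAX code): vmap won't silently reuse a size-1 axis, since all mapped axes must have the same size, but choosing in_axes per argument, with None for shared parameters, achieves the same effect:

```python
import jax
import jax.numpy as jnp

n_particles = 4

def kernel_step(particle, step_size, cov):
    # Stand-in for one inner-kernel step using both parameter kinds.
    return particle + step_size * jnp.diag(cov)

particles = jnp.zeros((n_particles, 3))
step_size = jnp.arange(n_particles, dtype=jnp.float32)  # per-particle
cov = jnp.eye(3)                                        # shared, stored once

# Map over the particle axis for per-particle arguments; broadcast the rest.
stepped = jax.vmap(kernel_step, in_axes=(0, 0, None))(particles, step_size, cov)
print(stepped.shape)  # (4, 3)
```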

andrewdipper commented 4 months ago

Ah, that makes sense. Thanks for linking the paper; it looks like a nice direction. I was trying to replicate PyMC's MH SMC_KERNEL to get a better understanding of things.

That's a nice non-breaking idea, and it seems like it should work since it's based only on the shapes. I'll run some tests and see if I can get it working.

ciguaran commented 3 months ago

@andrewdipper Awesome! Please let me know how that goes; I'm looking forward to improving BlackJAX's SMC usability.

andrewdipper commented 3 months ago

It looks like filtering the parameters by their first dimension works well. I should have a draft PR for it soon.
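A sketch of what that filtering could look like (illustrative, not the draft PR): treat a parameter as per-particle only when its leading dimension equals n_particles, and broadcast everything else with in_axes=None:

```python
import jax.numpy as jnp

def filter_in_axes(mcmc_parameters, n_particles):
    # Map over parameters whose first dimension matches the particle count;
    # broadcast (effectively close over) all others.
    return {
        name: 0 if value.ndim > 0 and value.shape[0] == n_particles else None
        for name, value in mcmc_parameters.items()
    }

params = {"step_size": jnp.ones(100), "cov": jnp.eye(50)}
print(filter_in_axes(params, n_particles=100))  # {'step_size': 0, 'cov': None}
```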

andrewdipper commented 3 months ago

Closing this, as the second part is addressed and the first is already in progress.