Closed andrewdipper closed 3 months ago
cc @ciguaran
Hi! Regarding the first feature (depending on the previous parameter value): this is something we want to do, since more complex tuning strategies like https://arxiv.org/abs/1005.1193 require updating a probability distribution over parameters that is passed from iteration to iteration. I have a draft for this, but I haven't been able to reproduce the paper's results yet, so I haven't sent a PR.
Regarding the shared vs. unshared parameters, what you are asking for seems like something we want to be able to support. I wonder if there's a way to play with vmap's broadcasting, like passing some parameters of shape (n_particles, param_dimension) and others of shape (1, param_dimension), and forcing vmap to reuse the value when the leading dimension doesn't match. If this can't be done, then we will need to modify the code to support both kinds of parameters.
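A minimal sketch of that shape-based dispatch idea (not BlackJAX's actual API; the parameter names, shapes, and `kernel_step` stand-in are hypothetical): leaves whose leading dimension equals `n_particles` get `in_axes=0`, while leaves with a leading dimension of 1 are stripped and passed unmapped via `in_axes=None`, so vmap closes over a single shared copy.

```python
import jax
import jax.numpy as jnp

n_particles, d = 1000, 50

# Hypothetical parameter pytree: a per-particle step size and a shared covariance.
params = {
    "step_size": jnp.ones((n_particles,)),  # distinct value per particle
    "cov": jnp.eye(d)[None, ...],           # shape (1, d, d): shared across particles
}

# Map axis 0 for per-particle leaves, None (unmapped) for shared leaves.
in_axes = {k: 0 if v.shape[0] == n_particles else None for k, v in params.items()}
# Strip the dummy leading dimension from shared leaves so vmap can close over them.
stripped = {k: v if v.shape[0] == n_particles else v[0] for k, v in params.items()}

def kernel_step(particle, p):
    # Stand-in for the inner MCMC kernel; uses both kinds of parameters.
    return particle * p["step_size"] + jnp.trace(p["cov"])

particles = jnp.zeros((n_particles,))
stepped = jax.vmap(kernel_step, in_axes=(0, in_axes))(particles, stripped)
```

With this split, the shared covariance is materialized once rather than `n_particles` times; one ambiguity to resolve in a real implementation is the degenerate case `n_particles == 1`.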
Ah, makes sense. Thanks for linking the paper, it looks like a nice direction. I was trying to replicate pymc's MH `SMC_KERNEL` to try to get a better understanding of things.
That's a nice non-breaking idea, and it seems like it should work since the dispatch is based only on the shapes. I'll run some tests and see if I can get it working.
@andrewdipper awesome! please let me know how that goes, looking forward to improving Blackjax's SMC usability.
It looks like filtering the parameters by the first dimension works well. I should have a draft PR for it soon.
Closing this, as the second part is addressed and the first is already in progress.
Describe the issue as clearly as possible:
When computing the `new_parameter_override`, only the `new_state` and `info` are available as inputs. This means there can't be a dependency on the previous `parameter_override`. The caveat is that `mcmc_step_fn` can be hacked to transfer the mcmc parameters stored in `StateWithParameterOverride` to the corresponding `info` output. However, this results in duplication if `num_mcmc_steps > 1`, and it means threading additional data through functions that don't care about it. In addition, it would be inconsistent with how tempered.py handles its `lmbda` parameter. I think there's also interference with the second issue below.

https://github.com/blackjax-devs/blackjax/blob/83bc3a04a64ac3379cd220932e88b44578f2e8e5/blackjax/smc/inner_kernel_tuning.py#L73

Secondly, `state.parameter_override` is threaded through as `mcmc_parameters` until it gets to the point where the `mcmc_kernel` is vmapped in tempered.py:

https://github.com/blackjax-devs/blackjax/blob/83bc3a04a64ac3379cd220932e88b44578f2e8e5/blackjax/smc/tempered.py#L143-L149

Given that the vmap is done within tempered.py, this is necessary for parameters that differ across particles. However, there appears to be no way to avoid it for parameters that are shared among the particles. See https://blackjax-devs.github.io/sampling-book/algorithms/TemperedSMCWithOptimizedInnerKernel.html, where a potentially full covariance matrix is explicitly duplicated for each particle. Since the covariance is quadratic in the number of variables and the duplication is linear in the number of particles, memory use explodes quite quickly. Allowing the shared parameters to be closed over prior to the vmap would prevent this.
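To make the blow-up concrete, here is a sketch of the duplication pattern (the sizes are illustrative, not taken from the linked tutorial):

```python
import jax.numpy as jnp

n_particles, d = 1000, 500
cov = jnp.eye(d)  # one shared (d, d) covariance: d * d * 4 bytes ≈ 1 MB in float32

# What the current per-particle threading forces: one copy per particle,
# n_particles * d * d * 4 bytes ≈ 1 GB in float32.
duplicated = jnp.repeat(cov[None, ...], n_particles, axis=0)
```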
I might be missing something, but I couldn't find a way to work around these issues without modifying the library. Essentially, I wanted an RMH kernel where I could scale the proposal distribution based on the acceptance rate. For the moment I'm splitting the shared / unshared mcmc parameters and passing `state.parameter_override` to `mcmc_parameter_update_fn`:

https://github.com/blackjax-devs/blackjax/blob/83bc3a04a64ac3379cd220932e88b44578f2e8e5/blackjax/smc/inner_kernel_tuning.py#L59-L74
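For illustration, a minimal sketch of the kind of update this would enable, assuming a modified signature that also receives the previous `parameter_override`, and assuming the inner-kernel info exposes an `is_accepted` field (both are assumptions, not BlackJAX's current API):

```python
import jax.numpy as jnp

TARGET_ACCEPTANCE = 0.234  # classic random-walk target; illustrative only

def mcmc_parameter_update_fn(state, info, previous_override):
    # `previous_override` is the extra argument requested in this issue;
    # the current API only passes (state, info).
    acceptance_rate = jnp.mean(info.update_info.is_accepted)  # assumed field
    # Multiplicatively grow/shrink the shared proposal scale toward the target.
    new_scale = previous_override["scale"] * jnp.exp(
        acceptance_rate - TARGET_ACCEPTANCE
    )
    return {"scale": new_scale}
```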
Steps/code to reproduce the bug:
Expected result:
Error message:
Blackjax/JAX/jaxlib/Python version information:
Context for the issue:
Unless I'm missing something, it looks like `inner_kernel_tuning` for SMC methods is pretty restrictive without modification.