blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Enable shared mcmc parameters with tempered smc #694

andrewdipper closed 3 months ago

andrewdipper commented 3 months ago

Allows mcmc_parameters to be passed to the MCMC kernel as shared parameters before vmap is applied, so shared parameters no longer need to be duplicated for each particle.

This change partitions mcmc_parameters by the length of each parameter's first dimension. Parameters whose first dimension has length 1 are treated as shared (this is also correct when there is only a single particle), and all others are treated as unshared. Shared parameters are closed over before vmap is applied, so they don't need to be duplicated. Shared parameters that are already duplicated per particle behave exactly as before, since they are simply treated as unshared. This seems like the most reasonable way to handle shared parameters, but let me know if you would prefer a different approach.
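For illustration, the partitioning described above can be sketched roughly as follows. This is a minimal standalone sketch, not the code in this PR: `split_parameters`, `toy_kernel`, and `step_all_particles` are hypothetical names, and the toy kernel is only a stand-in for a real MCMC kernel. It assumes mcmc_parameters is a dict of arrays.

```python
import jax
import jax.numpy as jnp


def split_parameters(mcmc_parameters):
    """Split a dict of arrays into shared (first dim == 1) and unshared entries."""
    shared, unshared = {}, {}
    for name, value in mcmc_parameters.items():
        if value.shape[0] == 1:
            # Drop the leading axis so the kernel sees a single value.
            shared[name] = value[0]
        else:
            unshared[name] = value
    return shared, unshared


def toy_kernel(rng_key, position, step_size, scale):
    """Hypothetical stand-in for an MCMC kernel: one random-walk move."""
    noise = jax.random.normal(rng_key, position.shape)
    return position + step_size * scale * noise


def step_all_particles(rng_key, particles, mcmc_parameters):
    shared, unshared = split_parameters(mcmc_parameters)
    keys = jax.random.split(rng_key, particles.shape[0])

    def one_particle(key, position, per_particle):
        # Shared parameters are closed over, so vmap never maps over them.
        return toy_kernel(key, position, **shared, **per_particle)

    # vmap maps over the keys, the particles, and the unshared parameters only.
    return jax.vmap(one_particle)(keys, particles, unshared)


# Usage: step_size is shared (shape (1,)), scale is per-particle (shape (4,)).
particles = jnp.zeros((4, 2))
params = {"step_size": jnp.array([0.5]), "scale": jnp.ones((4,))}
new_particles = step_all_particles(jax.random.PRNGKey(0), particles, params)
```

In this sketch, a parameter that has already been duplicated to shape `(num_particles, ...)` simply lands in the unshared dict and is mapped over by vmap, which matches the backward-compatible behavior described above.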

Related to https://github.com/blackjax-devs/blackjax/issues/690 cc @ciguaran

Some notes:

andrewdipper commented 3 months ago

I had marked this as a draft mostly because I wasn't sure whether the tests should be modified. I've removed the draft status.