Allows `mcmc_parameters` to be passed to the MCMC kernel as shared parameters before `vmap` is applied, so shared parameters no longer need to be duplicated for each particle.
This change filters `mcmc_parameters` by the length of their first dimension. Any parameter whose first dimension has length 1 is considered shared (this also works when there is only a single particle); the rest are unshared. Shared parameters are closed over before applying `vmap`, so they don't need duplication. Shared parameters that are already duplicated behave as before, since they are simply treated as unshared. This seems like the most reasonable way to handle shared parameters, but let me know if you think otherwise.
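The filtering described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: it assumes `mcmc_parameters` is a flat dict of arrays, and `split_shared`, `toy_step`, and `apply_to_particles` are hypothetical names made up for the example.

```python
import jax
import jax.numpy as jnp


def split_shared(mcmc_parameters):
    """Split a flat dict of arrays by the length of the first dimension."""
    shared, unshared = {}, {}
    for name, value in mcmc_parameters.items():
        if value.shape[0] == 1:
            shared[name] = value[0]  # drop the singleton leading axis
        else:
            unshared[name] = value
    return shared, unshared


def toy_step(position, step_size, scale):
    # Hypothetical stand-in for one MCMC kernel step.
    return position + step_size * scale


def apply_to_particles(particles, mcmc_parameters):
    shared, unshared = split_shared(mcmc_parameters)

    def step_one(position, unshared_one):
        # Shared parameters are closed over, so vmap never maps over them.
        return toy_step(position, **shared, **unshared_one)

    # Only the particles and the unshared parameters are vmapped.
    return jax.vmap(step_one)(particles, unshared)


particles = jnp.zeros((4, 2))
params = {
    "step_size": jnp.array([0.5]),  # shape (1,): shared
    "scale": jnp.arange(1.0, 5.0),  # shape (4,): one value per particle
}
new_particles = apply_to_particles(particles, params)
```

With a single particle every parameter has leading length 1, so everything ends up in `shared` and `vmap` maps over the particles only.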
Some of the tests in `test_kernel_compatibility.py` could use 1 instead of `self.n_particles`, though the correct choice arguably depends on the meaning of the parameter. I changed the ones in `test_inner_kernel_tuning.py` that were apparently shared.
I couldn't find a way to filter shared/unshared parameters with `jax.tree_util` functions: `None` always remains in the structure, but maybe I missed something.
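One way to see the `None` behaviour is the small sketch below. It assumes a flat dict of parameters with made-up names, and it may not match the exact attempt made for this PR.

```python
import jax
import jax.numpy as jnp

params = {"step_size": jnp.array([0.5]), "scale": jnp.arange(1.0, 5.0)}

# Keep leaves whose leading dimension is 1; replace the others with None.
masked = jax.tree_util.tree_map(
    lambda x: x if x.shape[0] == 1 else None, params
)

# tree_map preserves the container, so the unshared key is still present.
kept_keys = sorted(masked.keys())

# None is treated as an empty subtree: it disappears from the leaves,
# but its key stays in the dict.
n_leaves = len(jax.tree_util.tree_leaves(masked))

# For a flat dict, a plain comprehension drops the entries outright.
shared = {k: v for k, v in params.items() if v.shape[0] == 1}
```

Here `masked` still has both keys, which is why a plain dict comprehension (or loop) is the simpler option when the parameters are a flat dict.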
Related to https://github.com/blackjax-devs/blackjax/issues/690 cc @ciguaran
Some notes: