rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Run warmup for each chain #21

Closed: rlouf closed this issue 3 years ago

rlouf commented 4 years ago

As discussed in issue #19, chains that start at very different positions in the state space are unlikely to sample efficiently with the same kernel parameters. It would thus make sense to run a separate warmup for each chain.

The sampler would only require a very light modification:

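```python
from functools import partial
import jax

def move_chain(kernel_factory, rng_key, parameters, state):
    # Build this chain's kernel from its own (post-warmup) parameters
    kernel = kernel_factory(parameters)
    return kernel(rng_key, state)

# Functions cannot be vmapped over: close over the (hypothetical) kernel
# factory and map keys, parameters and states along their first axis.
move_chains = jax.jit(jax.vmap(partial(move_chain, kernel_factory)))
new_states = move_chains(keys, parameters, states)
```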
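To make the "one kernel per chain" idea concrete, here is a minimal runnable sketch. It assumes a toy random-walk Metropolis kernel on a hypothetical 2-D standard normal target; `make_rwm_kernel`, `logdensity` and the step sizes are illustrative assumptions, not mcx's API. Each chain's kernel is built from its own step size inside the vmapped function, so only arrays cross the `vmap` boundary.

```python
import jax
import jax.numpy as jnp


def logdensity(x):
    # Hypothetical target: standard normal in 2 dimensions
    return -0.5 * jnp.sum(x**2)


def make_rwm_kernel(step_size):
    """Build a random-walk Metropolis kernel for one chain's step size."""

    def kernel(rng_key, position):
        key_proposal, key_accept = jax.random.split(rng_key)
        noise = jax.random.normal(key_proposal, position.shape)
        proposal = position + step_size * noise
        log_ratio = logdensity(proposal) - logdensity(position)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        return jnp.where(accept, proposal, position)

    return kernel


def move_chain(rng_key, step_size, position):
    # One kernel per chain, built from that chain's own parameters
    kernel = make_rwm_kernel(step_size)
    return kernel(rng_key, position)


move_chains = jax.jit(jax.vmap(move_chain))

num_chains = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
step_sizes = jnp.array([0.1, 0.5, 1.0, 2.0])  # illustrative per-chain values
positions = jnp.zeros((num_chains, 2))
new_positions = move_chains(keys, step_sizes, positions)
```

Because the kernel is created inside `move_chain`, JAX traces it once and the per-chain step sizes are just a batched array argument.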

We will also need to vmap the warmup function over the initial states and parameters.
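A minimal sketch of vmapping a warmup over per-chain initial positions follows. It swaps in a toy random-walk Metropolis sampler and a simple Robbins-Monro update of `log(step_size)` toward a target acceptance rate; this is an illustrative assumption, not mcx's actual warmup, and `warmup`, `rwm_step` and `target_accept=0.4` are hypothetical names and values. The adaptation state lives in each chain's `lax.scan` carry, so `vmap` yields one adapted step size per chain.

```python
import jax
import jax.numpy as jnp


def logdensity(x):
    # Hypothetical target: standard normal in 2 dimensions
    return -0.5 * jnp.sum(x**2)


def rwm_step(rng_key, position, step_size):
    """One random-walk Metropolis step; also returns the acceptance probability."""
    key_proposal, key_accept = jax.random.split(rng_key)
    proposal = position + step_size * jax.random.normal(key_proposal, position.shape)
    accept_prob = jnp.minimum(1.0, jnp.exp(logdensity(proposal) - logdensity(position)))
    accept = jax.random.uniform(key_accept) < accept_prob
    return jnp.where(accept, proposal, position), accept_prob


def warmup(rng_key, initial_position, num_steps=500, target_accept=0.4):
    """Adapt one chain's step size with a Robbins-Monro update of log(step_size)."""

    def one_step(carry, inputs):
        position, log_step_size = carry
        key, t = inputs
        position, accept_prob = rwm_step(key, position, jnp.exp(log_step_size))
        # Grow the step when accepting more than the target, shrink it otherwise
        log_step_size = log_step_size + (accept_prob - target_accept) / (t + 1.0) ** 0.6
        return (position, log_step_size), None

    keys = jax.random.split(rng_key, num_steps)
    steps = jnp.arange(num_steps, dtype=jnp.float32)
    init = (initial_position, jnp.zeros(()))
    (position, log_step_size), _ = jax.lax.scan(one_step, init, (keys, steps))
    return position, jnp.exp(log_step_size)


num_chains = 4
keys = jax.random.split(jax.random.PRNGKey(1), num_chains)
# Chains start at very different positions in the state space
initial_positions = jnp.array([[0.0, 0.0], [5.0, -5.0], [-20.0, 10.0], [50.0, 50.0]])
final_positions, step_sizes = jax.jit(jax.vmap(warmup))(keys, initial_positions)
```

`step_sizes` then holds one value per chain, which is what a per-chain kernel factory would consume after warmup.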

I am opening this issue to keep track of the idea; discussion is welcome until it is implemented.

rlouf commented 3 years ago

We have a working prototype in #29. All chains now have independent parameters after warmup; they only share the number of integration steps, unless the user specifies an array with one value per chain. We build one kernel per chain.
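One way to support "a shared number of integration steps unless the user passes one value per chain" is sketched below. It assumes an HMC-style leapfrog integrator with an identity mass matrix on a hypothetical 2-D standard normal target; `leapfrog` and `per_chain` are hypothetical helpers, not code from #29. A scalar is broadcast to one value per chain, and a traced loop bound in `lax.fori_loop` lowers to a `while_loop`, which `vmap` can batch.

```python
import jax
import jax.numpy as jnp


def logdensity(x):
    # Hypothetical target: standard normal in 2 dimensions
    return -0.5 * jnp.sum(x**2)


grad_logdensity = jax.grad(logdensity)


def leapfrog(position, momentum, step_size, num_integration_steps):
    """Integrate Hamiltonian dynamics for one chain with its own parameters."""

    def body(_, state):
        q, p = state
        p = p + 0.5 * step_size * grad_logdensity(q)  # half kick
        q = q + step_size * p                         # drift (identity mass matrix)
        p = p + 0.5 * step_size * grad_logdensity(q)  # half kick
        return q, p

    # A traced loop bound lowers to a while_loop, which vmap can batch
    return jax.lax.fori_loop(0, num_integration_steps, body, (position, momentum))


def per_chain(value, num_chains):
    """Broadcast a shared scalar setting to one value per chain (arrays pass through)."""
    return jnp.broadcast_to(jnp.asarray(value), (num_chains,))


num_chains = 3
step_sizes = jnp.array([0.05, 0.1, 0.2])  # e.g. adapted independently per chain
shared_steps = per_chain(10, num_chains)  # same number of steps for every chain
custom_steps = per_chain(jnp.array([5, 10, 20]), num_chains)  # one value per chain

positions = jnp.ones((num_chains, 2))
momenta = jnp.zeros((num_chains, 2))
q_shared, p_shared = jax.vmap(leapfrog)(positions, momenta, step_sizes, shared_steps)
q_custom, p_custom = jax.vmap(leapfrog)(positions, momenta, step_sizes, custom_steps)
```

With per-chain loop bounds, the batched loop keeps running until the longest chain finishes and freezes the others, so chains with many more steps than the rest cost extra work for the whole batch.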