blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Add non-reversible parallel tempering #740

Open pawel-czyz opened 1 month ago

pawel-czyz commented 1 month ago

Presentation of the new sampler

Parallel tempering, also known as replica exchange MCMC, maintains $K$ Markov chains at different temperatures, ranging from $\pi_0$ (the reference distribution, for example the prior) to the target distribution $\pi =: \pi_{K-1}$. Apart from local exploration kernels, which target each distribution individually, it includes swap kernels, which try to exchange states between different chains, hence targeting the distribution $$(x_0, \dotsc, x_{K-1})\mapsto \prod_{i=0}^{K-1} \pi_i(x_i),$$ defined on the space $\mathcal X^K = \mathcal X\times \cdots \times \mathcal X$. By retaining the samples from only the last coordinate, it allows one to sample from $\pi = \pi_{K-1}$.
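
For concreteness, here is how the swap move can be written down if one assumes the geometric path between the reference and the target (the path the prototype later in this thread uses). The inverse temperatures $0 = \beta_0 < \dots < \beta_{K-1} = 1$ and the log-ratio $\ell$ are notation introduced here for illustration, not part of the original description:

$$\pi_i(x) \propto \pi_0(x)^{1-\beta_i}\,\pi(x)^{\beta_i}, \qquad \ell(x) := \log \pi(x) - \log \pi_0(x),$$

$$\alpha_{i,i+1} = \min\left\{1,\ \frac{\pi_i(x_{i+1})\,\pi_{i+1}(x_i)}{\pi_i(x_i)\,\pi_{i+1}(x_{i+1})}\right\} = \min\left\{1,\ \exp\big[(\beta_{i+1}-\beta_i)\,\big(\ell(x_i)-\ell(x_{i+1})\big)\big]\right\}.$$

The reference terms and all normalising constants cancel in the ratio, so a swap between neighbouring chains only needs the values of $\ell$ at the two current states.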

Similarly to sequential Monte Carlo (SMC) samplers, this strategy can be highly efficient for sampling from multimodal posteriors. A modern variant of parallel tempering, called non-reversible parallel tempering (NRPT), achieves state-of-the-art performance in sampling from complex high-dimensional distributions. NRPT works with both discrete and continuous spaces, and allows one to leverage preliminary runs to tune the tempering schedule.

Resources

How does it compare to other algorithms in blackjax?

Compared with a single-chain MCMC sampler:

Compared with a tempered SMC sampler:

Where does it fit in blackjax

BlackJAX offers a large collection of MCMC kernels. They can be leveraged to build non-reversible parallel tempering MCMC samplers, exploring different temperatures simultaneously, which leads to faster mixing and allows one to sample from multimodal posteriors. NRPT ac

Are you willing to open a PR?

Yes. I have a prototype implemented in a blog post, which I would be willing to refactor and contribute.

I am however unsure about two design choices:

  1. Non-reversible parallel tempering requires applying local kernels, $K_i$, to individual coordinate chains $x_i \in \mathcal X$ (corresponding to tempered distributions $\pi_i$). What would be the best practice for storing and applying different kernels? For example, one may want to use HMC with a large step size for sampling from $\pi_0$ (or even sample from the prior directly, e.g., if using a probabilistic programming language) and HMC with a small step size to sample from $\pi=\pi_{K-1}$. In my prototype I employ kernels from the same family, so that I can use jax.vmap rather than a for loop. I guess this problem is less apparent in tempered SMC samplers, where at temperature $T$ all particles are moved by the same kernel $K_T$. (Even though kernels $K_T$ may differ for different temperatures.)
  2. How to parallelize the computation? Pigeons.jl uses MPI communication between different machines to study very high-dimensional problems. Should a BlackJAX version use some version of sharding or could I keep it simple and rely on built-in parallelism?
junpenglao commented 1 month ago

Thank you for the detailed write-up, much appreciated. And yes, a contribution will be very welcome!

Regarding the design choices, jax.vmap would be the answer to both of your questions.

I have not read your blog post in detail, but I am wondering whether you have compared your implementation with the TFP one, which I am a bit more familiar with.

junpenglao commented 1 month ago

BTW, I am a huge fan of parallel tempering - very excited about this! Looking forward to your PR!

AdrienCorenflos commented 1 month ago

Same here, I had planned to do it at some point but have not been able to commit the time :D Very happy someone is doing it!

Design-wise, I actually do not think it would be a good idea to vmap everything at the lower level, in particular with a view to being able to do proper sharding. IMO there are two components to the method: 1) the swap kernel, 2) its reversible vs non-reversible application (this can probably be vmapped, as it is only a chain on the indices, conditional on the log-likelihood values).

Once that's done, the choice of parallelism for the state chains is very much user driven and it's hard to enforce a coherent interface supporting all JAX models.
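
To make the two components concrete, here is a minimal sketch of a swap sweep using the deterministic even-odd (DEO) scheme, in which the parity of the iteration counter decides which neighbouring pairs are proposed. It assumes the geometric path from the prototype discussed in this thread; `deo_swap`, its arguments and its return values are hypothetical names for illustration, not an existing BlackJAX API.

```python
import jax
import jax.numpy as jnp


def deo_swap(rng_key, positions, log_ratios, inverse_temperatures, step):
    """One deterministic even-odd (DEO) swap sweep (hypothetical helper).

    positions: pytree whose leaves have a leading axis of length K.
    log_ratios: shape (K,), log_target(x_i) - log_reference(x_i) per chain.
    step: iteration counter; its parity selects the proposed pairs.
    """
    n_chains = inverse_temperatures.shape[0]
    pair = jnp.arange(n_chains - 1)  # pair i couples chains i and i + 1

    # Metropolis log-ratio for exchanging x_i and x_{i+1} on the geometric path.
    log_accept = (inverse_temperatures[1:] - inverse_temperatures[:-1]) * (
        log_ratios[:-1] - log_ratios[1:]
    )
    log_u = jnp.log(jax.random.uniform(rng_key, (n_chains - 1,)))

    # Only pairs whose parity matches the step are proposed, so they are disjoint.
    proposed = (pair % 2) == (step % 2)
    do_swap = proposed & (log_u < log_accept)

    # Turn the accepted pairs into a permutation of the chain indices.
    pad = jnp.zeros((1,), dtype=bool)
    swap_up = jnp.concatenate([do_swap, pad])    # chain i takes the state of i + 1
    swap_down = jnp.concatenate([pad, do_swap])  # chain i takes the state of i - 1
    perm = (
        jnp.arange(n_chains)
        + swap_up.astype(jnp.int32)
        - swap_down.astype(jnp.int32)
    )

    new_positions = jax.tree_util.tree_map(lambda leaf: leaf[perm], positions)
    # Rejection probability of each proposed pair (zero for pairs not proposed).
    reject_prob = jnp.where(
        proposed, 1.0 - jnp.exp(jnp.minimum(log_accept, 0.0)), 0.0
    )
    return new_positions, perm, reject_prob
```

The returned `perm` would also have to be applied to `log_ratios` (and to any other per-chain cached quantity) by the caller. Replacing the parity rule with a random choice of parity at each step would give the reversible variant, so the scheme could be a single argument of the swap kernel.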

junpenglao commented 1 month ago

IIUC, the swap kernel basically swaps parameters (e.g., step size), which means you update the input parameters with some advanced indexing. The base kernel would remain the same, like step = jax.vmap(kernel.step)
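
A minimal sketch of this index-based view, under the assumption that each chain keeps its state in place and only an integer assignment of chains to temperature levels is updated. The names `swap_levels` and `assignment` are hypothetical, introduced only for illustration.

```python
import jax.numpy as jnp


def swap_levels(assignment, level, accept):
    """Exchange temperature levels `level` and `level + 1` if `accept` is True.

    assignment[c] is the temperature level currently held by chain c.
    """
    is_lower = assignment == level
    is_upper = assignment == level + 1
    swapped = jnp.where(is_lower, level + 1, jnp.where(is_upper, level, assignment))
    return jnp.where(accept, swapped, assignment)


inverse_temperatures = jnp.array([0.0, 0.5, 1.0])
step_sizes = jnp.array([1.0, 0.3, 0.1])

assignment = jnp.arange(3)                     # chain c starts at level c
assignment = swap_levels(assignment, 1, True)  # accepted swap of levels 1 and 2

# Advanced indexing gathers the per-chain parameters for the vmapped kernel.
betas_per_chain = inverse_temperatures[assignment]
step_sizes_per_chain = step_sizes[assignment]
```

With this layout the states never move between chains, and the samples from the target are those of whichever chain currently holds the last level.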

pawel-czyz commented 1 month ago

Thank you for your kind feedback and all the suggestions!

> I have not read your blog post in detail, but I am wondering whether you have compared your implementation with the TFP one, which I am a bit more familiar with.

Thanks, I have not been aware that there exists a TFP implementation! I like it, the major differences seem to be:

  1. TFP allows different step sizes by using batching. (If I understand the BlackJAX philosophy, jax.vmap is preferred over batching, isn't it? I.e., it's not possible to have log_p return a batch of log-PDFs compatible with a "batched" kernel?)
  2. TFP also allows other kinds of parallel tempering, with different swapping schemes. I think Adrien's suggestion will be useful here, making this essentially a variable argument resulting in different kernels.
  3. I don't think TFP records rejection statistics and uses them to tune the tempering schedule.

> Once that's done, the choice of parallelism for the state chains is very much user driven and it's hard to enforce a coherent interface supporting all JAX models.

This is a very good point. I think it'd be convenient to have a utility function that allows the end user to quickly build a reasonable (even if not optimally sharded) parallel tempering kernel out of an existing one, using jax.vmap. I've been thinking about something along these lines:


```python
import jax
import jax.random as jrandom
import jax.numpy as jnp
import blackjax


def init(
    init_fn,
    positions,
    log_target,
    log_reference,
    inverse_temperatures,
):
    # Tempered log-density on the geometric path between reference and target.
    def create_tempered_log_p(inverse_temperature):
        def log_p(x):
            return inverse_temperature * log_target(x) + (1.0 - inverse_temperature) * log_reference(x)
        return log_p

    def init_fn_temp(position, inverse_temperature):
        return init_fn(position, create_tempered_log_p(inverse_temperature))

    # One initial state per chain, built in parallel.
    return jax.vmap(init_fn_temp)(positions, inverse_temperatures)


def build_kernel(
    base_kernel_fn,
    log_target,
    log_reference,
    inverse_temperatures,
    parameters,
):
    def create_tempered_log_p(inverse_temperature):
        def log_p(x):
            return inverse_temperature * log_target(x) + (1.0 - inverse_temperature) * log_reference(x)
        return log_p

    # Base kernel applied to a single chain at a single temperature.
    def kernel(rng_key, state, inverse_temperature, parameter):
        return base_kernel_fn(rng_key, state, create_tempered_log_p(inverse_temperature), parameter)

    n_chains = inverse_temperatures.shape[0]

    def step_fn(
        rng_key,
        state,
    ):
        # Independent key, temperature and parameter for each chain.
        keys = jrandom.split(rng_key, n_chains)
        return jax.vmap(kernel)(
            keys,
            state,
            inverse_temperatures,
            parameters,
        )

    return step_fn


def log_p(x):
    return -(jnp.sum(jnp.square(x)) + jnp.sum(jnp.power(x, 4)))


def log_ref(x):
    return -jnp.sum(jnp.square(x))


n_chains = 10
inverse_temperatures = 0.9 ** jnp.linspace(0, 1, n_chains)
initial_positions = jnp.ones((n_chains, 2))
parameters = jnp.linspace(0.1, 1, n_chains)  # one MALA step size per chain

init_state = init(
    blackjax.mcmc.mala.init,
    initial_positions,
    log_p,
    log_ref,
    inverse_temperatures,
)

kernel = build_kernel(
    blackjax.mcmc.mala.build_kernel(),
    log_p,
    log_ref,
    inverse_temperatures,
    parameters,
)

rng_key = jrandom.PRNGKey(42)
new_state, info = kernel(rng_key, init_state)
```

Please, let me know what you think! Also:

Question: how to optimally pass the parameters to the individual kernels? This solution works only if each kernel has a single parameter. This parameter could be dictionary-valued, though, allowing the users to write wrappers around kernel initialisers. E.g., one could create a wrapper around the HMC kernel builder function, which passes the step size, mass matrix and number of steps as a dictionary, but I'm not sure how convenient this is for the end users.
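
As an illustration of the dictionary-valued option, the sketch below relies on the fact that jax.vmap maps over the leading axis of every leaf of a pytree argument. To keep it self-contained it uses a toy random-walk Metropolis step instead of a BlackJAX kernel; `rw_kernel`, `tempered_step` and the parameter names `step_size` and `scale` are hypothetical.

```python
import jax
import jax.numpy as jnp


def log_target(x):
    return -(jnp.sum(jnp.square(x)) + jnp.sum(jnp.power(x, 4)))


def log_reference(x):
    return -jnp.sum(jnp.square(x))


def rw_kernel(rng_key, position, log_p, params):
    """Toy random-walk Metropolis step (stand-in for a real kernel)."""
    key_prop, key_acc = jax.random.split(rng_key)
    noise = jax.random.normal(key_prop, position.shape)
    proposal = position + params["step_size"] * params["scale"] * noise
    log_alpha = log_p(proposal) - log_p(position)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_alpha
    return jnp.where(accept, proposal, position), accept


def tempered_step(rng_key, position, inverse_temperature, params):
    def log_p(x):
        return inverse_temperature * log_target(x) + (1.0 - inverse_temperature) * log_reference(x)

    return rw_kernel(rng_key, position, log_p, params)


n_chains = 4
# Every leaf carries a leading axis of length n_chains.
params = {
    "step_size": jnp.linspace(1.0, 0.1, n_chains),
    "scale": jnp.ones((n_chains, 2)),
}
inverse_temperatures = jnp.linspace(0.0, 1.0, n_chains)
positions = jnp.ones((n_chains, 2))
keys = jax.random.split(jax.random.PRNGKey(0), n_chains)

new_positions, accepted = jax.vmap(tempered_step)(
    keys, positions, inverse_temperatures, params
)
```

One caveat with this approach: parameters that must be static Python values (for example an integer number of integration steps used as a loop bound) cannot be batched this way and would have to be shared across chains.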

> IIUC, the swap kernel basically swaps parameters (e.g., step size), which means you update the input parameters with some advanced indexing. The base kernel would remain the same, like step = jax.vmap(kernel.step)

I see! This can potentially lead to better parallelism, but I think it'd be easier for me to swap the states $x_i$, rather than swapping parameters and temperatures. One of the reasons is that I then do not have to build the explicit index process, and can easily record the rejection rates, which allows one to tune the tempering schedule. Would such a solution still be fine? :slightly_smiling_face:
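
To illustrate how recorded rejection rates could feed back into the schedule, here is a rough sketch of the idea as I understand it from the NRPT literature: accumulate the average rejection rates into an estimate of the communication barrier and place the new inverse temperatures so that they are equally spaced in that barrier. The function name `update_schedule` is hypothetical, and plain linear interpolation is used here as a simplification rather than a monotone spline.

```python
import jax.numpy as jnp


def update_schedule(inverse_temperatures, rejection_rates):
    """Re-space the schedule so that estimated rejection is equalised.

    inverse_temperatures: shape (K,), increasing.
    rejection_rates: shape (K - 1,), average swap rejection rate between
        neighbouring chains, estimated from a preliminary run.
    """
    # Cumulative rejection, evaluated at the current inverse temperatures.
    barrier = jnp.concatenate([jnp.zeros(1), jnp.cumsum(rejection_rates)])
    # New grid: equally spaced in the barrier, mapped back by linear interpolation.
    targets = jnp.linspace(0.0, barrier[-1], inverse_temperatures.shape[0])
    return jnp.interp(targets, barrier, inverse_temperatures)


betas = jnp.array([0.0, 0.5, 1.0])
rates = jnp.array([0.3, 0.1])
new_betas = update_schedule(betas, rates)  # the middle level moves towards 0
```

The end points stay fixed, and levels move towards the region where swaps were rejected more often. This assumes all rejection rates are strictly positive, so that the barrier estimate is strictly increasing.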

Question: I'm also not sure about the best design choice regarding the composed kernels. Currently each kernel $K$ records some information (e.g., acceptance rates, divergences, ...). In this case, we have a kernel $K_\text{ind}$ (applying kernels $K_i$ independently to individual components $x_i$) and a kernel $K_\text{swap}$, applied to the joint state $(x_0, \dotsc, x_{K-1})$. Should I define the information object to be a named tuple constructed out of the information of kernel $K_\text{ind}$, which builds upon $K_i$, and the auxiliary information from $K_\text{swap}$? In this case, how should one handle the auxiliary information coming from an application of $K_\text{ind}$, e.g., 3 times for every swap attempt? (In other words, the joint kernel can be $K_\text{ind}^3 K_\text{swap}$, resulting in 3 times longer information about the independent moves...)
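
One possible shape for such an information object, sketched under the assumption that both steps have the signature `(rng_key, state) -> (state, info)`: the local moves are run with jax.lax.scan, which stacks their info along a new leading axis of length `n_local`, and the swap info is stored next to it. `PTInfo`, `compose_local_and_swap` and the toy steps are hypothetical names used only for illustration.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class PTInfo(NamedTuple):
    local_info: Any  # info of the local moves, stacked along axis 0
    swap_info: Any   # info of the single swap attempt


def compose_local_and_swap(local_step, swap_step, n_local):
    """Build a step applying `local_step` n_local times, then `swap_step` once."""

    def step(rng_key, state):
        key_local, key_swap = jax.random.split(rng_key)
        local_keys = jax.random.split(key_local, n_local)

        def body(carry, key):
            return local_step(key, carry)  # (new_state, info)

        state, local_info = jax.lax.scan(body, state, local_keys)
        state, swap_info = swap_step(key_swap, state)
        return state, PTInfo(local_info, swap_info)

    return step


def toy_local_step(rng_key, state):
    # Placeholder for the vmapped exploration kernel.
    noise = 0.1 * jax.random.normal(rng_key, state.shape)
    return state + noise, jnp.abs(noise).mean(axis=-1)


def toy_swap_step(rng_key, state):
    # Placeholder for the swap kernel: reverses the chains, reports K - 1 flags.
    return state[::-1], jnp.ones(state.shape[0] - 1, dtype=bool)


step = compose_local_and_swap(toy_local_step, toy_swap_step, n_local=3)
state, info = step(jax.random.PRNGKey(0), jnp.zeros((4, 2)))
```

Here `info.local_info` has shape `(3, 4)` (three local moves, four chains), so the user can decide afterwards whether to average over the first axis or keep every entry.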

More generally, Question: Are there existing utilities for combining kernels? I know that the Metropolis-within-Gibbs tutorial explicitly constructs a kernel, but in theory, one could imagine an operation of composing the kernels corresponding to the updates of different variables. Such utilities could be useful not only in the context of parallel tempering or Metropolis-within-Gibbs, but also for building non-reversible kernels employing some information about the given target. For example, when sampling phylogenetic trees one has several kernels (e.g., kernels changing the tree topology and kernels permuting the taxa between different nodes) and they are combined either by composition $K = K_1 K_2 K_3$ or by a mixture $K = \frac{1}{3}(K_1 + K_2 + K_3)$.
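
To show what I have in mind, here is a minimal sketch of the two combinators, assuming for simplicity that every kernel has the signature `(rng_key, state) -> state` and that all kernels return states of the same structure (the info objects are dropped). `compose` and `mixture` are hypothetical helpers, not existing BlackJAX utilities.

```python
import jax
import jax.numpy as jnp


def compose(*kernels):
    """Composition: apply the kernels one after another, in the given order."""

    def kernel(rng_key, state):
        keys = jax.random.split(rng_key, len(kernels))
        for key, k in zip(keys, kernels):
            state = k(key, state)
        return state

    return kernel


def mixture(kernels, weights):
    """Mixture: pick one kernel at random with the given probabilities."""
    weights = jnp.asarray(weights)

    def kernel(rng_key, state):
        key_choice, key_move = jax.random.split(rng_key)
        index = jax.random.choice(key_choice, len(kernels), p=weights)
        # lax.switch keeps the random choice compatible with jit.
        return jax.lax.switch(index, kernels, key_move, state)

    return kernel


# Deterministic toy "kernels", only to show the plumbing.
def add_one(rng_key, x):
    return x + 1.0


def double(rng_key, x):
    return x * 2.0


key = jax.random.PRNGKey(0)
composed = compose(add_one, double)(key, jnp.array(1.0))            # (1 + 1) * 2
mixed = mixture([add_one, double], [1.0, 0.0])(key, jnp.array(1.0))  # always add_one
```

Handling the info objects is the harder part: under a mixture only one kernel runs per step, so the info would need a common structure or a tagged union.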

junpenglao commented 1 month ago

I think, for simplicity, let's start by building the functionality assuming we are using the same base kernel (e.g., HMC) with different parameters (e.g., step_size).