pawel-czyz opened 1 month ago
Thank you for the detailed write-up, much appreciated. And yes, a contribution will be very welcome!
Regarding the design choice, jax.vmap would be the answer to both of your questions.
I have not read your blog post in detail, but I am wondering if you have compared your implementation with the one in TFP, which I am a bit more familiar with.
BTW, I am a huge fan of parallel tempering - very excited about this! Looking forward to your PR!
Same here, I had planned to do it at some point but I've not been able to commit the time to it :D Very happy someone is doing it!
Design-choice wise, I actually do not think it would be a good idea to vmap everything at the lower level, in particular with a view to being able to do proper sharding. IMO there are two components to the method: 1) the swap kernel, 2) reversible vs non-reversible application (this can probably be vmapped, as it's only a chain on the indices conditionally on the log-likelihood values).
Once that's done, the choice of parallelism for the state chains is very much user driven and it's hard to enforce a coherent interface supporting all JAX models.
IIUC, the swapping kernel basically swaps parameters (e.g. step size), which means you update the input parameters with some advanced indexing. The base kernel would remain the same, like step = jax.vmap(kernel.step).
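For concreteness, a minimal sketch of what this advanced-indexing permutation of per-chain parameters could look like (the helper name is hypothetical, not existing BlackJAX code):

```python
import jax.numpy as jnp

# Hypothetical helper: after a swap is accepted, permute the per-chain parameters
# (e.g. step sizes) and inverse temperatures with advanced indexing; the vmapped
# base kernel itself stays unchanged (step = jax.vmap(kernel.step)).
def permute_chain_parameters(parameters, inverse_temperatures, permutation):
    # `permutation` has shape (n_chains,) and encodes the accepted swaps.
    return parameters[permutation], inverse_temperatures[permutation]

# Example: swap chains 0 and 1, leave chains 2 and 3 untouched.
step_sizes = jnp.array([0.1, 0.2, 0.3, 0.4])
betas = jnp.array([0.1, 0.3, 0.6, 1.0])
new_step_sizes, new_betas = permute_chain_parameters(step_sizes, betas, jnp.array([1, 0, 2, 3]))
```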
Thank you for your kind feedback and all the suggestions!
I have not read your blog post in detail, but I am wondering if you have compared your implementation with the one in TFP, which I am a bit more familiar with.
Thanks, I had not been aware that there exists a TFP implementation! I like it; the major difference seems to be that in BlackJAX jax.vmap is preferred over batching, isn't it? (I.e., it's not possible to have log_p return a batch of log-PDFs compatible with a "batched" kernel?)

Once that's done, the choice of parallelism for the state chains is very much user driven and it's hard to enforce a coherent interface supporting all JAX models.
I think this is a very good point. It would be convenient to have a utility function allowing the end user to quickly build a reasonable (even if not optimally sharded) parallel tempering kernel out of an existing one, using jax.vmap. I've been thinking about something along these lines:
```python
import jax
import jax.random as jrandom
import jax.numpy as jnp
import blackjax


def init(
    init_fn,
    positions,
    log_target,
    log_reference,
    inverse_temperatures,
):
    def create_tempered_log_p(inverse_temperature):
        # Tempered log-density: beta * log_target + (1 - beta) * log_reference.
        def log_p(x):
            return inverse_temperature * log_target(x) + (1.0 - inverse_temperature) * log_reference(x)

        return log_p

    def init_fn_temp(position, inverse_temperature):
        return init_fn(position, create_tempered_log_p(inverse_temperature))

    # Initialise one chain per inverse temperature.
    return jax.vmap(init_fn_temp)(positions, inverse_temperatures)


def build_kernel(
    base_kernel_fn,
    log_target,
    log_reference,
    inverse_temperatures,
    parameters,
):
    def create_tempered_log_p(inverse_temperature):
        def log_p(x):
            return inverse_temperature * log_target(x) + (1.0 - inverse_temperature) * log_reference(x)

        return log_p

    def kernel(rng_key, state, inverse_temperature, parameter):
        return base_kernel_fn(rng_key, state, create_tempered_log_p(inverse_temperature), parameter)

    n_chains = inverse_temperatures.shape[0]

    def step_fn(
        rng_key,
        state,
    ):
        # Apply the per-chain kernels in parallel.
        keys = jrandom.split(rng_key, n_chains)
        return jax.vmap(kernel)(
            keys,
            state,
            inverse_temperatures,
            parameters,
        )

    return step_fn


def log_p(x):
    return -(jnp.sum(jnp.square(x)) + jnp.sum(jnp.power(x, 4)))


def log_ref(x):
    return -jnp.sum(jnp.square(x))


n_chains = 10
inverse_temperatures = 0.9 ** jnp.linspace(0, 1, n_chains)
initial_positions = jnp.ones((n_chains, 2))
parameters = jnp.linspace(0.1, 1, n_chains)

init_state = init(
    blackjax.mcmc.mala.init,
    initial_positions,
    log_p,
    log_ref,
    inverse_temperatures,
)
kernel = build_kernel(
    blackjax.mcmc.mala.build_kernel(),
    log_p,
    log_ref,
    inverse_temperatures,
    parameters,
)

rng_key = jrandom.PRNGKey(42)
new_state, info = kernel(rng_key, init_state)
```
Please, let me know what you think! Also:
Question: how should one optimally pass the parameters to the individual kernels? This solution works only if each kernel has a single parameter. This parameter could be dictionary-valued, though, allowing users to write wrappers around kernel initialisers. E.g., one could create a wrapper around the HMC kernel builder function which passes the step size, mass matrix and number of integration steps as a dictionary, but I'm not sure how convenient that would be for end users.
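As a rough illustration of the dictionary-valued option, a minimal sketch of such a wrapper; it assumes the HMC kernel accepts step_size, inverse_mass_matrix and num_integration_steps as keyword arguments, which should be double-checked against the current BlackJAX signature:

```python
import blackjax

def build_hmc_kernel_with_dict_parameter():
    # Hypothetical wrapper: the base HMC kernel is assumed to take step_size,
    # inverse_mass_matrix and num_integration_steps as keyword arguments.
    base_kernel = blackjax.mcmc.hmc.build_kernel()

    def kernel(rng_key, state, logdensity_fn, parameter):
        # `parameter` is a dictionary, e.g.
        # {"step_size": 0.1, "inverse_mass_matrix": jnp.ones(dim), "num_integration_steps": 10}
        return base_kernel(rng_key, state, logdensity_fn, **parameter)

    return kernel
```

When vmapping such a kernel over chains, entries that must stay static (e.g. num_integration_steps) could be shared across chains by passing a per-leaf in_axes specification to jax.vmap (0 for batched entries, None for shared ones).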
IIUC, the swapping kernel basically swaps parameters (e.g. step size), which means you update the input parameters with some advanced indexing. The base kernel would remain the same, like step = jax.vmap(kernel.step).
I see! This could potentially lead to better parallelism, but I think it would be easier for me to swap the states $x_i$, rather than swapping parameters and temperatures. One of the reasons is that I then do not have to build the explicit index process, and I can easily record the rejection rates, which allows one to tune the tempering schedule. Would such a solution still be fine? :slightly_smiling_face:
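For reference, a minimal sketch (a hypothetical function, not part of the prototype above) of such a state-swapping move between two chains, exposing the acceptance probability needed to tune the schedule; it assumes the two tempered log-densities are available as callables:

```python
import jax
import jax.numpy as jnp

def swap_states(rng_key, x_i, x_j, log_p_i, log_p_j):
    # Metropolis ratio for exchanging the states of chains i and j.
    log_alpha = (log_p_i(x_j) + log_p_j(x_i)) - (log_p_i(x_i) + log_p_j(x_j))
    accept_prob = jnp.minimum(1.0, jnp.exp(log_alpha))
    do_swap = jax.random.uniform(rng_key) < accept_prob

    def select(a, b):
        # Keep `a` if the swap is rejected, take `b` if it is accepted.
        return jax.tree_util.tree_map(lambda u, v: jnp.where(do_swap, v, u), a, b)

    # Recording accept_prob (or the rejection rate 1 - accept_prob) per pair of
    # adjacent chains is what allows tuning the tempering schedule afterwards.
    return select(x_i, x_j), select(x_j, x_i), accept_prob
```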
Question: I'm also not sure about the best design choice regarding the composed kernels. Currently each kernel $K$ records some information (e.g., acceptance rates, divergences, ...). In this case, we have a kernel $K_\text{ind}$ (applying kernels $K_i$ independently to the individual components $x_i$) and a kernel $K_\text{swap}$, applied to the joint state $(x_1, \dotsc, x_{T-1})$. Should I define the information object to be a named tuple constructed out of the information of kernel $K_\text{ind}$, which builds upon the $K_i$, and the auxiliary information from $K_\text{swap}$? In this case, how should one handle the auxiliary information coming from an application of $K_\text{ind}$, e.g., 3 times for every swap attempt? (In other words, the joint kernel can be $K_\text{ind}^3 K_\text{swap}$, resulting in 3 times longer information about the independent moves...)
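One possible (purely hypothetical) layout for the combined information object, with the infos of the inner exploration steps stacked along a leading axis:

```python
from typing import Any, NamedTuple

import jax.numpy as jnp

class ParallelTemperingInfo(NamedTuple):
    # Info returned by the K_ind applications, stacked along a leading axis of
    # size n_inner (the number of exploration steps between swap attempts).
    exploration_info: Any
    # Per-pair swap statistics for the n_chains - 1 adjacent pairs.
    swap_acceptance_probability: jnp.ndarray
    swapped: jnp.ndarray
```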
More generally, Question: are there existing utilities for combining kernels? I know that the Metropolis-within-Gibbs tutorial explicitly constructs a kernel, but in theory one could imagine an operation composing the kernels corresponding to the updates of different variables. Such utilities could be useful not only in the context of parallel tempering or Metropolis-within-Gibbs, but also for building non-reversible kernels employing some information about the given target. For example, when sampling phylogenetic trees one has several kernels (e.g., kernels changing the tree topology and kernels permuting the taxa between different nodes), and they are combined either by composition $K = K_1 K_2 K_3$ or by a mixture $K = \frac{1}{3}(K_1 + K_2 + K_3)$.
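As an illustration of what such a utility could look like, a hypothetical helper (not an existing BlackJAX function) composing step functions sequentially, $K = K_1 K_2 K_3$; a mixture kernel could instead draw an index with jax.random.choice and dispatch with jax.lax.switch:

```python
import jax.random as jrandom

def compose_kernels(*step_fns):
    # Hypothetical helper: each step_fn maps (rng_key, state) -> (new_state, info);
    # the composed kernel applies them one after the other with independent keys.
    def step(rng_key, state):
        keys = jrandom.split(rng_key, len(step_fns))
        infos = []
        for key, step_fn in zip(keys, step_fns):
            state, info = step_fn(key, state)
            infos.append(info)
        return state, tuple(infos)

    return step
```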
I think for simplicity, let's start with building the functionality assuming we are using the same base kernel (e.g., HMC) with different parameters (e.g., step_size).
Presentation of the new sampler
Parallel tempering, also known as replica exchange MCMC, maintains $K$ Markov chains at different temperatures, ranging from $\pi_0$ (the reference distribution, for example the prior) to the target distribution $\pi =: \pi_{K-1}$. Apart from local exploration kernels, which target each distribution individually, it includes swap kernels, which try to exchange states between different chains, so that the sampler targets the distribution $$(x_0, \dotsc, x_{K-1}) \mapsto \prod_{i=0}^{K-1} \pi_i(x_i),$$ defined on the product space $\mathcal X^K = \mathcal X \times \cdots \times \mathcal X$. By retaining the samples from only the last coordinate, it allows one to sample from $\pi = \pi_{K-1}$.
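A common choice of the intermediate distributions interpolates geometrically between the reference and the target (this is also the tempered log-density used in the prototype above, up to the particular schedule):
$$\pi_i(x) \propto \pi_0(x)^{1 - \beta_i}\, \pi(x)^{\beta_i}, \qquad 0 = \beta_0 < \beta_1 < \cdots < \beta_{K-1} = 1,$$
so that each tempered log-density is $\beta_i \log \pi(x) + (1 - \beta_i) \log \pi_0(x)$.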
Similarly to sequential Monte Carlo (SMC) samplers, this strategy can be highly efficient for sampling from multimodal posteriors. A modern variant of parallel tempering, called non-reversible parallel tempering (NRPT), achieves state-of-the-art performance in sampling from complex high-dimensional distributions. NRPT works on both discrete and continuous spaces, and allows one to leverage preliminary runs to tune the tempering schedule.
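To illustrate the non-reversible part, a small sketch of the deterministic even-odd (DEO) swap schedule used by NRPT: even rounds attempt swaps between pairs $(0,1), (2,3), \dotsc$ and odd rounds between $(1,2), (3,4), \dotsc$, which makes the index process non-reversible and improves round trips between reference and target:

```python
def deo_swap_pairs(round_index: int, n_chains: int) -> list[tuple[int, int]]:
    # Deterministic even-odd schedule: alternate between "even" pairs (0,1), (2,3), ...
    # and "odd" pairs (1,2), (3,4), ... instead of picking a random pair each round.
    offset = round_index % 2
    return [(i, i + 1) for i in range(offset, n_chains - 1, 2)]

# Example: deo_swap_pairs(0, 5) == [(0, 1), (2, 3)]
#          deo_swap_pairs(1, 5) == [(1, 2), (3, 4)]
```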
Resources
How does it compare to other algorithms in blackjax?
Compared with a single-chain MCMC sampler:
Compared with a tempered SMC sampler:
Where does it fit in blackjax
BlackJAX offers a large collection of MCMC kernels. They can be leveraged to build non-reversible parallel tempering samplers, which explore different temperatures simultaneously, leading to faster mixing and allowing one to sample from multimodal posteriors.
Are you willing to open a PR?
Yes. I have a prototype implemented in a blog post, which I would be willing to refactor and contribute.
I am however unsure about two design choices:
Whether the kernels at the different temperatures should be applied with jax.vmap, rather than a for loop. I guess this problem is less apparent in tempered SMC samplers, where at temperature $T$ all particles are moved by the same kernel $K_T$. (Even though kernels $K_T$ may differ for different temperatures.)