tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
Apache License 2.0

Optimal way to run multiple chains for a bayesian neural network trained with HMC (tfp.mcmc.HamiltonianMonteCarlo) #1496

Open reneaas opened 2 years ago

reneaas commented 2 years ago


I'm working on a master's thesis where we want to sample the exact posterior of a Bayesian neural network on regression tasks using HMC and the No-U-Turn sampler (NUTS). The current implementation does not take full advantage of the fact that several chains can be run simultaneously when calling `tfp.mcmc.sample_chain`.

I'll walk through the code. First, the weights of the network are stored in a list as

```
weights = [kernel:0, bias:0, kernel:1, bias:1, ...]
```

The kernels have shape `[num_chains, n, m]` and the biases have shape `[num_chains, m]`, so that several chains can run simultaneously, as described in the TensorFlow Probability docs.
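As a quick check of this layout, the sketch below (sizes are arbitrary, chosen only for illustration) shows how `tf.matmul` broadcasts a shared input of shape `[num_points, n]` against a kernel with a leading chain axis, and that the batched result matches a plain per-chain loop.

```python
import tensorflow as tf

# Arbitrary sizes for illustration only.
num_chains, num_points, n, m = 4, 7, 3, 5

x = tf.random.normal([num_points, n])          # shared training features
kernel = tf.random.normal([num_chains, n, m])  # one kernel per chain
bias = tf.random.normal([num_chains, m])       # one bias per chain

# tf.matmul broadcasts the rank-2 x over the leading chain axis of kernel,
# giving shape [num_chains, num_points, m].
h = tf.matmul(x, kernel)
# bias[..., None, :] has shape [num_chains, 1, m] and broadcasts over the points.
h = h + bias[..., None, :]

# The same computation done one chain at a time, for comparison.
h_loop = tf.stack([tf.matmul(x, kernel[c]) + bias[c] for c in range(num_chains)])
```

So a single batched matmul per layer evaluates all chains at once, without a Python loop over chains.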

The log prior is defined as

```python
import tensorflow as tf

def log_prior(weights, lamb=1e-3):
    kernel = weights[::2]
    bias = weights[1::2]
    res = 0
    for w, b in zip(kernel, bias):
        res += tf.reduce_sum(w ** 2, axis=(-1, -2))
        res += tf.reduce_sum(b ** 2, axis=-1)
    return -0.5 * lamb * res
```
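Because each `reduce_sum` only runs over the trailing (per-layer) axes, the prior returns one value per chain. A small check with hand-countable numbers, using a hypothetical 2-4-1 network with every weight set to one:

```python
import tensorflow as tf

def log_prior(weights, lamb=1e-3):
    # Unnormalized log density of an isotropic Gaussian prior with variance 1/lamb.
    kernel = weights[::2]
    bias = weights[1::2]
    res = 0.0
    for w, b in zip(kernel, bias):
        res += tf.reduce_sum(w ** 2, axis=(-1, -2))  # sums over [n, m], keeps the chain axis
        res += tf.reduce_sum(b ** 2, axis=-1)        # sums over [m], keeps the chain axis
    return -0.5 * lamb * res

# Hypothetical 2-4-1 network with 3 chains, all weights equal to one.
num_chains = 3
weights = [tf.ones([num_chains, 2, 4]), tf.ones([num_chains, 4]),
           tf.ones([num_chains, 4, 1]), tf.ones([num_chains, 1])]

lp = log_prior(weights)
# 17 parameters, each contributing 1.0 to res, so every entry is -0.5 * 1e-3 * 17 = -0.0085
```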

The log likelihood is defined as

```python
def log_likelihood(x, y, weights, activation=tf.nn.relu):
    kernel = weights[::2]
    bias = weights[1::2]
    for w, b in zip(kernel[:-1], bias[:-1]):
        x = activation(tf.matmul(x, w) + b[..., None, :])
    y_pred = tf.matmul(x, kernel[-1]) + bias[-1][..., None, :]

    return -0.5 * tf.reduce_sum((y_pred - y) ** 2, axis=(-1, -2))
```
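One way to confirm that the batched forward pass does what is intended is to compare it with evaluating each chain separately. In this sketch (layer sizes and random data are hypothetical), slicing `w[c:c + 1]` keeps a chain axis of size one, so the same `log_likelihood` works unchanged for a single chain.

```python
import tensorflow as tf

def log_likelihood(x, y, weights, activation=tf.nn.relu):
    # Gaussian log-likelihood with unit noise variance, up to an additive constant.
    kernel = weights[::2]
    bias = weights[1::2]
    for w, b in zip(kernel[:-1], bias[:-1]):
        x = activation(tf.matmul(x, w) + b[..., None, :])  # [num_chains, num_points, m]
    y_pred = tf.matmul(x, kernel[-1]) + bias[-1][..., None, :]
    return -0.5 * tf.reduce_sum((y_pred - y) ** 2, axis=(-1, -2))  # [num_chains]

# Hypothetical sizes: 3 inputs, one hidden layer of 8 units, 2 outputs, 4 chains.
num_chains, num_points = 4, 20
layers = [3, 8, 2]
x = tf.random.normal([num_points, layers[0]])
y = tf.random.normal([num_points, layers[-1]])
weights = []
for n, m in zip(layers[:-1], layers[1:]):
    weights += [tf.random.normal([num_chains, n, m]), tf.random.normal([num_chains, m])]

batched = log_likelihood(x, y, weights)
# Evaluate each chain on its own by slicing out a chain axis of size 1.
per_chain = tf.concat(
    [log_likelihood(x, y, [w[c:c + 1] for w in weights]) for c in range(num_chains)],
    axis=0)
```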

And the target log probability function is defined by

```python
def get_target_log_prob_fn(x, y):
    def target_log_prob_fn(*weights):
        return log_prior(weights) + log_likelihood(x, y, weights)
    return target_log_prob_fn
```

Thus, given training features x of shape [num_points, num_features] and training targets y of shape [num_points, num_outputs], we can extract the target log probability function as

```python
target_log_prob_fn = get_target_log_prob_fn(x, y)
```

While this code provides adequate results in terms of a proper non-linear regression model, it runs poorly when `num_chains > 1`. In fact, `num_chains = 1` with `num_results = 100` runs significantly faster than `num_chains = 10` with `num_results = 10`, even though they produce exactly the same number of results.

As an example, the model weights can be created with the following function:

```python
def get_weights(layers, num_chains):
    weights = []
    for n, m in zip(layers[:-1], layers[1:]):
        w = tf.random.normal(shape=(num_chains, n, m))
        b = tf.random.normal(shape=(num_chains, m))
        weights.extend([w, b])
    return weights
```
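Before handing `target_log_prob_fn` to a kernel, it can help to check the contract the batched layout relies on: `tfp.mcmc` treats the shape of the returned log-probability as the batch of independent chains, so the result should have shape `[num_chains]`, and HMC needs a gradient for every weight part with the same shape as that part. The sketch below repeats the functions from this post in compact form; the sine toy data and the layer sizes are assumptions.

```python
import tensorflow as tf

def log_prior(weights, lamb=1e-3):
    res = 0.0
    for w, b in zip(weights[::2], weights[1::2]):
        res += tf.reduce_sum(w ** 2, axis=(-1, -2)) + tf.reduce_sum(b ** 2, axis=-1)
    return -0.5 * lamb * res

def log_likelihood(x, y, weights, activation=tf.nn.relu):
    kernel, bias = weights[::2], weights[1::2]
    for w, b in zip(kernel[:-1], bias[:-1]):
        x = activation(tf.matmul(x, w) + b[..., None, :])
    y_pred = tf.matmul(x, kernel[-1]) + bias[-1][..., None, :]
    return -0.5 * tf.reduce_sum((y_pred - y) ** 2, axis=(-1, -2))

def get_target_log_prob_fn(x, y):
    def target_log_prob_fn(*weights):
        return log_prior(weights) + log_likelihood(x, y, weights)
    return target_log_prob_fn

def get_weights(layers, num_chains):
    weights = []
    for n, m in zip(layers[:-1], layers[1:]):
        weights.extend([tf.random.normal([num_chains, n, m]), tf.random.normal([num_chains, m])])
    return weights

# Hypothetical toy data: 50 points, 1 feature, 1 output.
x = tf.random.normal([50, 1])
y = tf.sin(3.0 * x)
num_chains = 10
weights = get_weights([1, 10, 10, 1], num_chains)
target_log_prob_fn = get_target_log_prob_fn(x, y)

with tf.GradientTape() as tape:
    tape.watch(weights)
    lp = target_log_prob_fn(*weights)  # one log-probability per chain
    total = tf.reduce_sum(lp)          # computed inside the tape so it is recorded
grads = tape.gradient(total, weights)
```

Because chain `c`'s log-probability depends only on the `c`-th slice of each weight tensor, differentiating the sum over chains yields each chain's own gradient, so one gradient evaluation serves all chains.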

And run the chain with adaptive HMC like so:

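```python
import tensorflow_probability as tfp

num_results = 100
num_burnin_steps = 1000
num_chains = 10
layers = [input_sz, 10, 10, output_sz]
weights = get_weights(layers, num_chains)

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=get_target_log_prob_fn(x_train, y_train),
    # ... (remaining arguments not shown in the original post)
)

kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    # ... (remaining arguments not shown in the original post)
    num_adaptation_steps=int(0.8 * num_burnin_steps),
)

chain = sample_chain(
    # ... (arguments not shown in the original post)
)
```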

I'll leave a link to a Google Colab notebook that implements the code and demonstrates the problem: https://colab.research.google.com/drive/1oc5czHsGSsi0XC9T7267EfJUR1OdHtlY?usp=sharing

Is there a better way to structure the model parameters such that it better utilizes the parallelization offered by the tfp.mcmc kernels and tfp.mcmc.sample_chain?

ColCarroll commented 2 years ago

It looks like this should be running multiple chains just fine -- is there a reason you think it is not?

A few tips on performance --

1. You might wrap `main` (or even just `sample_chain`) in `tf.function(jit_compile=True)`, as in

   ```python
   @tf.function(jit_compile=True)
   def main(...):
       ...
   ```

   or

   ```python
   # inside main()
   chain, trace = tf.function(jit_compile=True)(lambda: sample_chain(...))()
   ```

   The second one looks quite ugly because we're inlining a decorator.

2. Use a step size per chain. Setting just `step_size=tf.fill((n_chains, 1), 0.01)` will supply a batch of step sizes, one per chain. Right now, step size adaptation pools results across all your chains, which can lead to faster convergence, but if a single chain gets stuck it can lead to bad results.

3. `tfp.experimental.mcmc.windowed_adaptive_nuts` and `tfp.experimental.mcmc.windowed_adaptive_hmc` exist to do a lot of what I'm describing for you, but they are easiest to use if you write your model down as a `tfd.JointDistribution` (so maybe this won't work for you).
reneaas commented 2 years ago

Thanks for the response.

The code does run several chains (at least `tf.size(target_log_prob_fn(*current_state)) == num_chains`), but performance degrades. For a fixed total number of results (i.e., keeping `tot_num_results = num_results * num_chains` fixed), performance is significantly better with `num_chains = 1`.

This may very well be a limitation of the model I'm working with. I suspect the issue lies in how `target_log_prob_fn` is defined and how the kernel uses it, but exactly how is beyond my knowledge of TFP. That is why I turned here, hoping for helpful suggestions or alternative ways to write `target_log_prob_fn`.

I already XLA-compile the `sample_chain` call on the GPU, which does significantly improve performance.

The step size tip was useful, but I had to store the step sizes as

```python
step_size = [tf.fill(w.shape, 0.001) for w in weights]
```

to match the shapes of the network weights.
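The shape issue behind this can be checked directly. Under right-aligned broadcasting, a `[num_chains, 1]` step size lines up with a bias of shape `[num_chains, m]`, but against a kernel of shape `[num_chains, n, m]` its chain axis is paired with `n`. Keeping one step size per chain while padding with singleton axes (`[num_chains, 1, 1]` for kernels, `[num_chains, 1]` for biases) broadcasts to every weight part. The layer sizes are hypothetical, picked so that `n != num_chains`; whether a given step-size adaptation kernel adapts these per-chain shapes the way you want is not checked here.

```python
import tensorflow as tf

num_chains = 4
layers = [3, 10, 10, 2]  # hypothetical sizes
# Same shapes as get_weights: kernels [num_chains, n, m], biases [num_chains, m].
weights = []
for n, m in zip(layers[:-1], layers[1:]):
    weights += [tf.zeros([num_chains, n, m]), tf.zeros([num_chains, m])]

# Does a single [num_chains, 1] step size broadcast against each weight part?
flat = tf.TensorShape([num_chains, 1])
compatible = []
for w in weights:
    try:
        tf.broadcast_static_shape(flat, w.shape)
        compatible.append(True)
    except ValueError:  # raised when the shapes cannot be broadcast
        compatible.append(False)

# One step size per chain and per weight part: keep the chain axis, pad the rest with 1s.
step_size = [tf.fill([num_chains] + [1] * (len(w.shape) - 1), 1e-3) for w in weights]
# Each padded step size broadcasts to exactly the shape of its weight part.
broadcast_shapes = [tf.broadcast_static_shape(s.shape, w.shape)
                    for s, w in zip(step_size, weights)]
```

Compared with `tf.fill(w.shape, 0.001)`, this stores one value per chain rather than one per parameter.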

ColCarroll commented 2 years ago

Possibly helpful as background is this paper (joint work with some of the researchers on the team), and a related competition.

reneaas commented 2 years ago

Thank you, I will take a look.

reneaas commented 2 years ago

I have a follow-up question.

What they have done appears to be mapping a function that runs the Monte Carlo chains over multiple GPUs using `jax.pmap`. I saw a tutorial on distributed inference with TFP using JAX as the backend (https://www.tensorflow.org/probability/examples/Distributed_Inference_with_JAX). I happen to have access to multiple GPUs, which makes a similar strategy a viable option for my research. However, porting my code to the JAX backend is non-trivial, since it is currently written in TensorFlow. Therefore I wonder: is it possible to distribute sampling with `tfp.mcmc.sample_chain` over multiple GPUs? The dataset I use is not very large, so the reason I want multiple GPUs is not to shard the data across them, but rather to place a copy of the dataset on each GPU, sample independent chains simultaneously (allowing longer trajectories with HMC/NUTS), and gather the results at the end.

I should mention that I tried using `tf.distribute.MirroredStrategy` to distribute the sampling chain across several GPUs, by applying the strategy to `tfp.mcmc.sample_chain` inside a function compiled with `tf.function(jit_compile=True)`. This did, however, lead to a slight performance degradation (which I hypothesise is due to synchronization, but I can't say for sure): a single GPU outperformed two GPUs generating the same total number of results.