reneaas opened this issue 2 years ago
It looks like this should be running multiple chains just fine -- is there a reason you think it is not?
A few tips on performance --
1. Wrap `main` (or even just `sample_chain`) in `tf.function(jit_compile=True)`, as in

   ```python
   @tf.function(jit_compile=True)
   def main(...):
       ...
   ```

   or

   ```python
   # inside main()
   chain, trace = tf.function(jit_compile=True)(lambda: sample_chain(...))()
   ```

   The second one looks quite ugly because we're inlining a decorator.
2. Use a step size per chain. Setting just `step_size=tf.fill((n_chains, 1), 0.01)` will supply a batch of step sizes, one per chain (see the sketch after this list). Right now, step size adaptation will pool results across all your chains, potentially leading to faster convergence, but sometimes if a single chain gets stuck, it leads to bad results.
3. `tfp.experimental.mcmc.windowed_adaptive_nuts` and `tfp.experimental.mcmc.windowed_adaptive_hmc` exist to do a lot of what I'm describing for you, but they are easiest to use if you write your model down as a `tfd.JointDistribution` (so maybe this won't work for you).
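To make the second tip concrete, here is a toy sketch (a stand-in target density, not your model) of per-chain step sizes combined with step size adaptation and an XLA-compiled `sample_chain`:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Toy target density: a standard normal, one log-prob per chain.
def target_log_prob_fn(x):
    return -0.5 * tf.reduce_sum(x ** 2, axis=-1)

n_chains = 10
init_state = tf.zeros((n_chains, 2))

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=tf.fill((n_chains, 1), 0.01),  # one step size per chain
    num_leapfrog_steps=10,
)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    kernel, num_adaptation_steps=400)

@tf.function(jit_compile=True)
def run():
    return tfp.mcmc.sample_chain(
        num_results=500,
        num_burnin_steps=500,
        current_state=init_state,
        kernel=kernel,
        trace_fn=None,
    )

samples = run()  # shape [500, n_chains, 2]
```

With a batched step size like this, each chain's step size is adapted on its own rather than being pooled across chains.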
Thanks for the response.
The code does run several chains (at least `tf.size(target_log_prob_fn(*current_state)) = num_chains`), but it degrades performance. For a fixed total number of results (i.e., keeping `tot_num_results = num_results * num_chains` fixed), the performance is significantly better when run with `num_chains = 1`.
This may very well be a limitation of the model I'm working with. I suspect it has to do with the way `target_log_prob_fn` is defined and how the kernel uses it, but exactly how is beyond my knowledge of TFP. Therefore, I turned here in the hope of helpful suggestions or alternative ways to write the `target_log_prob_fn`.
I already use XLA compilation on the GPU for the sample-chain function, which does indeed significantly improve performance.
The step size tip was useful, but I had to store the step sizes as

```python
step_size = [tf.fill(w.shape, 0.001) for w in weights]
```

to match the shapes of the weights of the neural network.
Possibly helpful as background is this paper (joint work with some of the researchers on the team), and a related competition.
Thank you, I will take a look.
I have a follow-up question.
What they have done appears to be mapping a function that runs the Monte Carlo chains on multiple GPUs using `jax.pmap`.
I saw a tutorial for distributed inference using TFP with JAX as the backend (https://www.tensorflow.org/probability/examples/Distributed_Inference_with_JAX). I happen to have access to multiple GPUs, which makes a similar strategy a viable option for my research. However, porting my code to the JAX backend is a non-trivial matter, as it currently uses TensorFlow.
Therefore I wonder: is it possible to distribute sampling with `tfp.mcmc.sample_chain` over multiple GPUs? The dataset I use is not very large, so the reason I want multiple GPUs is not to shard the data among them, but rather to distribute a copy of the dataset to each GPU, sample independent chains simultaneously to allow for longer trajectories with HMC/NUTS, and gather the results in the end.
I can mention that I tried using `tf.distribute.MirroredStrategy` to distribute the sampling across several GPUs, by applying the strategy to the `tfp.mcmc.sample_chain` call inside a function compiled with `tf.function(jit_compile=True)`. This did, however, lead to a slight performance degradation (which I hypothesise is due to synchronization, but I can't say for sure), where a single GPU would outperform two GPUs generating the same total number of results.
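Schematically, the attempt looked something like the following (simplified, with a placeholder target density instead of my model; names like `run_chain` are just illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

strategy = tf.distribute.MirroredStrategy()

# Placeholder target; stands in for the BNN posterior.
def target_log_prob_fn(x):
    return -0.5 * tf.reduce_sum(x ** 2, axis=-1)

n_chains_per_gpu = 4
init_state = tf.zeros((n_chains_per_gpu, 2))

@tf.function(jit_compile=True)
def run_chain():
    # Derive a per-replica seed so each GPU samples different chains.
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    seed = tf.stack([tf.cast(replica_id, tf.int32), 0])
    kernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=0.1,
        num_leapfrog_steps=10,
    )
    return tfp.mcmc.sample_chain(
        num_results=100,
        current_state=init_state,
        kernel=kernel,
        trace_fn=None,
        seed=seed,
    )

# Each GPU runs its own chains; gather the per-GPU results at the end.
per_replica_samples = strategy.run(run_chain)
samples = strategy.experimental_local_results(per_replica_samples)
```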
Hi,
I'm working on a master's thesis where we want to sample the exact posterior of a Bayesian neural network for regression tasks using HMC and the No-U-Turn sampler. The current code implementation does not optimally utilize the fact that several chains can be run simultaneously when calling `tfp.mcmc.sample_chain`.
I'll walk through the code implementation. First, the weights of the network are stored in a list.
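Schematically, it looks something like this (layer sizes are arbitrary here; the exact code is in the Colab notebook linked at the end):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_chains = 10

# Flat list of parameter tensors for a small 1-50-1 network;
# every tensor carries a leading chain dimension.
weights = [
    tf.random.normal((num_chains, 1, 50)),  # kernel, hidden layer
    tf.random.normal((num_chains, 50)),     # bias, hidden layer
    tf.random.normal((num_chains, 50, 1)),  # kernel, output layer
    tf.random.normal((num_chains, 1)),      # bias, output layer
]
```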
The kernels have shape `[num_chains, n, m]` and the biases have shape `[num_chains, m]` to allow several chains to run simultaneously, as per the TensorFlow Probability docs.
The log prior is defined as
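a sum of independent standard-normal log densities over every kernel and bias, reduced so that one value per chain remains (a simplified sketch; the exact prior is in the notebook):

```python
def log_prior_fn(weights):
    # Sum over all parameter axes of each tensor, keep the leading chain axis,
    # so the result has shape [num_chains].
    log_prob = 0.0
    for w in weights:
        log_prob += tf.reduce_sum(
            tfd.Normal(0.0, 1.0).log_prob(w),
            axis=list(range(1, len(w.shape))),
        )
    return log_prob
```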
The log likelihood is defined as
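a Gaussian likelihood of the targets given a batched forward pass of the network (a sketch; the activation and noise scale below are placeholders):

```python
def forward(x, weights):
    # Batched forward pass; activations have shape [num_chains, num_points, units].
    a = tf.expand_dims(x, axis=0)
    kernels, biases = weights[::2], weights[1::2]
    for kernel, bias in zip(kernels[:-1], biases[:-1]):
        a = tf.tanh(tf.matmul(a, kernel) + bias[:, None, :])
    return tf.matmul(a, kernels[-1]) + biases[-1][:, None, :]


def log_likelihood_fn(weights, x, y):
    y_pred = forward(x, weights)  # [num_chains, num_points, num_outputs]
    # Sum over data points and outputs, keep the chain axis.
    return tf.reduce_sum(tfd.Normal(y_pred, 0.1).log_prob(y), axis=[-2, -1])
```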
And the target log probability function is defined by
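adding the log prior and the log likelihood, closed over the training data (the factory name `get_target_log_prob_fn` below is just illustrative):

```python
def get_target_log_prob_fn(x, y):
    # The kernel passes the list-of-tensors state as positional arguments,
    # hence *weights.
    def target_log_prob_fn(*weights):
        return log_prior_fn(weights) + log_likelihood_fn(weights, x, y)
    return target_log_prob_fn
```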
Thus, given training features `x` of shape `[num_points, num_features]` and training targets `y` of shape `[num_points, num_outputs]`, we can extract the target log probability function as
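a closure over the data, along these lines (the toy data here stands in for the real training set):

```python
# Toy data standing in for the real training set.
x = tf.random.uniform((100, 1))                           # [num_points, num_features]
y = tf.sin(6.0 * x) + 0.1 * tf.random.normal((100, 1))    # [num_points, num_outputs]

target_log_prob_fn = get_target_log_prob_fn(x, y)
# Sanity check: tf.size(target_log_prob_fn(*weights)) == num_chains
```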
While this code provides adequate results in terms of a proper non-linear regression model, it does not run well when `num_chains > 1`. In fact, the case `num_chains = 1` with `num_results = 100` runs significantly faster than `num_chains = 10` with `num_results = 10`, even though they produce the exact same number of results.

As an example of a model, we can create it with the following function:
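(The layer sizes and initialization below are placeholders; the exact function is in the notebook.)

```python
def get_weights(num_chains, layers):
    # Flat [kernel, bias, kernel, bias, ...] list, one chain dimension per tensor.
    weights = []
    for n, m in zip(layers[:-1], layers[1:]):
        weights.append(tf.random.normal((num_chains, n, m)))
        weights.append(tf.random.normal((num_chains, m)))
    return weights


weights = get_weights(num_chains=10, layers=[1, 50, 50, 1])
```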
And run the chain with adaptive HMC like so:
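(The hyperparameters below are placeholders; the exact values are in the notebook.)

```python
num_results = 100
num_burnin_steps = 1000

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.001,
    num_leapfrog_steps=60,
)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    kernel, num_adaptation_steps=int(0.8 * num_burnin_steps))

@tf.function(jit_compile=True)
def run_chain():
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=weights,
        kernel=kernel,
        trace_fn=None,
    )

chain = run_chain()  # list of tensors, each of shape [num_results, num_chains, ...]
```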
I'll leave a link to a Google Colab notebook that implements the code and demonstrates the problem: https://colab.research.google.com/drive/1oc5czHsGSsi0XC9T7267EfJUR1OdHtlY?usp=sharing
Is there a better way to structure the model parameters such that it better utilizes the parallelization offered by the `tfp.mcmc` kernels and `tfp.mcmc.sample_chain`?