tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

How to batch Hamiltonian Monte Carlo (HMC) for sampling from multiple chains #1093

Closed: MatthiasKohl closed this issue 4 years ago

MatthiasKohl commented 4 years ago

This is more of a question, but when using HMC to infer the parameters of a BNN (similar to what is done in this issue), how do you run multiple chains at a time?

Given what's in this paper, it should be "simple" with TFP to do batch-wise HMC in order to run multiple chains at a time. However, it's not obvious to me what the output of target_log_prob_fn for HMC should look like in this case, nor how to apply a batch of BNNs at the same time using the usual TF ops.

There is an example for this with RandomWalkMetropolis, but it doesn't use any custom log-prob function, so it's not obvious what that function actually returns.

My current code for target_log_prob_fn looks something like this:

def bdnn(x, p):
    nf = num_features
    nt = first_layer_size

    # Unpack model parameters from the flat vector p
    w1 = tf.reshape(p[:nf*nt], [nf, nt])
    b1 = p[nf*nt:nf*nt+1]
    w2 = tf.reshape(p[nf*nt+1:nf*nt+nt+1], [nt, 1])
    b2 = p[nf*nt+nt+1:nf*nt+nt+2]

    # Build layers (matmul + add so the scalar biases broadcast;
    # tf.nn.xw_plus_b would require the bias width to match the layer width)
    x = tf.tanh(tf.matmul(x, w1) + b1)
    x = tf.matmul(x, w2) + b2
    return tf.squeeze(x)

def model_log_prob(p):
    # Parameters of distributions
    prior_scale = 200
    studentT_scale = 100

    # Prior probability distributions on model parameters
    rv_p = tfd.Independent(
        tfd.Normal(loc=tf.zeros(shape=[num_chains, num_model_parameters], dtype=tf.float32),
                   scale=prior_scale * tf.ones(shape=[num_chains, num_model_parameters], dtype=tf.float32)),
        reinterpreted_batch_ndims=1)

    # Likelihood
    alpha_bp_estimate = bdnn(features, p)
    rv_observed = tfd.StudentT(df=2.2, loc=alpha_bp_estimate, scale=studentT_scale)

    # Log-prior plus log-likelihood of the observed data
    # (`labels` stands in for the observed target tensor, defined elsewhere)
    return rv_p.log_prob(p) + tf.reduce_sum(rv_observed.log_prob(labels))

Should I just return a vector of log-probabilities, one for each chain? Thank you!

junpenglao commented 4 years ago

Should I just return a vector of log-probabilities, one for each chain?

Yes - TFP mcmc kernels follow the API contract that they return the same number of chains (i.e., batches) as the target_log_prob evaluated on the initial state:

init_state = ...  # the initial state of all your MCMC chains; this is also the
                  # argument passed as current_state to tfp.mcmc.sample_chain
num_chains = tf.shape(model_log_prob(*init_state))  # note that if the output shape is
                                                    # more than 1-D, e.g. [5, 3],
                                                    # you will get [5, 3] chains
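
For concreteness, here is a self-contained sketch of running several HMC chains in parallel against a toy standard-normal target (a hypothetical example, not from this thread; the shapes illustrate the batch contract described above):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_chains = 8
ndims = 2

def target_log_prob_fn(x):
    # x has shape [num_chains, ndims]; return one log-prob per chain.
    return tf.reduce_sum(tfd.Normal(0., 1.).log_prob(x), axis=-1)

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.1,
    num_leapfrog_steps=3)

samples = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=tf.zeros([num_chains, ndims]),  # one row per chain
    kernel=kernel,
    trace_fn=None)

# samples has shape [1000, num_chains, ndims]: the chain (batch)
# dimension is carried through every step.
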
ColCarroll commented 4 years ago

To pipe in: the tf.reduce_sum should almost always specify the axis you reduce over (in the last line of your model_log_prob function). Very likely the axis should be -1 (it depends on the shape of the log-probabilities of your observed data), so that the chain (batch) dimensions in front are not summed over.
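
For example, with log-probabilities of shape [num_chains, num_datapoints] (illustrative shapes, not taken from the thread):

lp = rv_observed.log_prob(labels)   # shape [num_chains, num_datapoints]
tf.reduce_sum(lp)                   # scalar: collapses the chain dimension too
tf.reduce_sum(lp, axis=-1)          # shape [num_chains]: one value per chain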

MatthiasKohl commented 4 years ago

Thank you very much for your help! Indeed, simply returning a batch of log-probabilities (one per chain) did the trick for me, and as @ColCarroll suggested, I changed the last sum to axis=-1. I hadn't seen this behavior in the documentation, so maybe adding an example or some additional docs to the kernels would be worth it? It could have been an oversight on my part as well, of course.

atsyplikhin commented 3 years ago

@MatthiasKohl - may I ask you to post your updated code? I also wonder how much slower it runs for several chains. Thanks.
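
For readers with the same question, here is a minimal sketch of what the batched version could look like, following junpenglao's and ColCarroll's advice. This is not the original poster's actual code: it assumes the globals from the snippet above (num_features, first_layer_size, num_chains, num_model_parameters, features of shape [N, num_features]) plus a float32 labels tensor of shape [N], and the einsum layout is just one possible choice:

def bdnn_batched(x, p):
    # p has shape [num_chains, num_model_parameters]; slice along the
    # last axis so every chain gets its own set of weights.
    nf = num_features
    nt = first_layer_size
    w1 = tf.reshape(p[..., :nf*nt], [-1, nf, nt])             # [C, nf, nt]
    b1 = p[..., nf*nt:nf*nt+1][..., tf.newaxis]               # [C, 1, 1]
    w2 = tf.reshape(p[..., nf*nt+1:nf*nt+nt+1], [-1, nt, 1])  # [C, nt, 1]
    b2 = p[..., nf*nt+nt+1:nf*nt+nt+2][..., tf.newaxis]       # [C, 1, 1]

    # Apply all C networks to the same inputs x of shape [N, nf].
    h = tf.tanh(tf.einsum('nf,cft->cnt', x, w1) + b1)         # [C, N, nt]
    out = tf.einsum('cnt,cto->cno', h, w2) + b2               # [C, N, 1]
    return tf.squeeze(out, axis=-1)                           # [C, N]

def model_log_prob_batched(p):
    rv_p = tfd.Independent(
        tfd.Normal(loc=tf.zeros([num_chains, num_model_parameters]),
                   scale=200. * tf.ones([num_chains, num_model_parameters])),
        reinterpreted_batch_ndims=1)
    alpha_bp_estimate = bdnn_batched(features, p)             # [C, N]
    rv_observed = tfd.StudentT(df=2.2, loc=alpha_bp_estimate, scale=100.)
    # Sum only over the data axis; keep one log-prob per chain.
    return rv_p.log_prob(p) + tf.reduce_sum(rv_observed.log_prob(labels), axis=-1)

model_log_prob_batched then returns a vector of shape [num_chains], which is exactly the per-chain output that the batch contract described above expects.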