Closed MatthiasKohl closed 4 years ago
Should I just return a vector of log-probabilities, one for each chain?
Yes - TFP MCMC kernels follow the API contract that they return the same number of chains (i.e. the batch shape) as target_log_prob evaluated on the initial state:
init_state = ...  # the initial state of all your MCMC chains, also the argument passed to tfp.mcmc.sample_chain
num_chains = tf.shape(model_log_prob(*init_state))  # note that if the output shape is more than 1-D, e.g. [5, 3],
# you will get [5, 3] chains
To pipe in, the tf.reduce_sum should almost always specify the axis you reduce over (in the last line of your model_log_p function). Very likely, the axis should be -1 (it depends on the shape of the variable return), so that the chain (batch) dimensions in the front are not summed over.
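To make the point about the reduction axis concrete, here is a minimal sketch. It uses NumPy's np.sum as a stand-in for tf.reduce_sum (the axis argument behaves the same way), and the shapes, variable names, and standard-normal density are assumptions chosen for illustration, not taken from the original code.

```python
import numpy as np

# Hypothetical shapes: 4 chains, 3 parameters per chain.
num_chains, num_params = 4, 3
rng = np.random.default_rng(0)
state = rng.normal(size=(num_chains, num_params))

# Elementwise standard-normal log-density, shape [num_chains, num_params].
elementwise_log_prob = -0.5 * state**2 - 0.5 * np.log(2.0 * np.pi)

# Reducing over every axis collapses all chains into a single scalar.
collapsed = np.sum(elementwise_log_prob)

# Reducing over the last axis only keeps one log-probability per chain.
per_chain = np.sum(elementwise_log_prob, axis=-1)

print(np.shape(collapsed))  # ()
print(per_chain.shape)      # (4,)
```

With the scalar version, the sampler would see a single chain whose log-probability mixes all states together; the per-chain version keeps the leading batch dimension intact.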
Thank you very much for your help! Indeed, simply returning a batch of losses did the trick for me, and as @ColCarroll mentioned, I changed the last sum to axis=-1. I haven't seen this feature in the documentation so maybe adding an example or additional doc to the kernels could be worth it? It could have been an oversight on my part as well, of course.
@MatthiasKohl - may I ask you to post your updated code? I also wonder how much slower it runs for several chains. Thanks.
This is more of a question, but when using HMC to infer the parameters of a BNN - similar to what is done in this issue - how do you run multiple chains at a time?
Given what's in this paper, it should be "simple" with TFP to do batch-wise HMC in order to run multiple chains at a time. However, it's not obvious to me what the outputs of target_log_prob_fn for HMC should look like in this case, or how to apply a batch of BNNs at the same time, given the usual TF ops. There is an example for this with RandomWalkMetropolis, but it doesn't use any custom log-prob function, so it's not obvious what that function actually returns.
My current code for target_log_prob_fn looks something like this:
Should I just return a vector of log-probabilities, one for each chain? Thank you
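
For readers with the same question, one way to apply a batch of BNNs at once is to give every weight tensor a leading chain dimension and use einsum for the matrix products. The sketch below is written in NumPy (tf.einsum and tf.reduce_sum accept the same equation strings and axis arguments); the one-hidden-layer architecture, the Gaussian likelihood and prior, and all names such as bnn_log_prob are hypothetical and are not the code from this issue.

```python
import numpy as np

def bnn_log_prob(w1, w2, x, y, noise_scale=1.0):
    """Unnormalized log-posterior of a one-hidden-layer BNN, one value per chain.

    w1: [num_chains, in_dim, hidden], w2: [num_chains, hidden, 1]
    x:  [num_data, in_dim],           y:  [num_data, 1]
    """
    # Apply every chain's network to the same data; "c" indexes chains.
    hidden = np.tanh(np.einsum("ni,cih->cnh", x, w1))
    pred = np.einsum("cnh,cho->cno", hidden, w2)  # [num_chains, num_data, 1]

    # Gaussian log-likelihood (up to a constant), summed over data axes only.
    log_lik = -0.5 * np.sum(((y - pred) / noise_scale) ** 2, axis=(-2, -1))

    # Standard-normal prior on the weights (up to a constant), summed per chain.
    log_prior = -0.5 * (np.sum(w1**2, axis=(-2, -1)) + np.sum(w2**2, axis=(-2, -1)))
    return log_lik + log_prior  # shape [num_chains]

rng = np.random.default_rng(0)
num_chains, num_data, in_dim, hidden_dim = 5, 20, 2, 8
x = rng.normal(size=(num_data, in_dim))
y = rng.normal(size=(num_data, 1))
w1 = rng.normal(size=(num_chains, in_dim, hidden_dim))
w2 = rng.normal(size=(num_chains, hidden_dim, 1))

log_probs = bnn_log_prob(w1, w2, x, y)
print(log_probs.shape)  # (5,)
```

Every reduction names its axes explicitly, so the leading chain dimension is never summed over and the function returns a vector with one log-probability per chain.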