tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Gathering values from the HMC kernel: How to? #1064

Open MSusik opened 4 years ago

MSusik commented 4 years ago

Hi!

I'm new to TFP and I'm playing with HMC implementation included in TFP. Thanks for the library, it's very nice.

I stumbled upon a scenario where my kernel does some heavy computations that I would like to avoid redoing. The computations return a 2D matrix whose first dimension is the chains dimension. For example:

def joint_log_prob(data, param_obs):
    heavy_computation_results = heavy_computation(data, param_obs)
    return tf.reduce_logsumexp(heavy_computation_results, axis=1)

def unnormalized_log_posterior(param_obs):
    return joint_log_prob(data, param_obs)

(...)

res = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=num_burnin,
        current_state=inits,
        kernel=kernel,
        trace_fn=trace_fn
    )

Then I would like to store/access the values of heavy_computation_results for each accepted iteration.

I noticed the acceptance info is available in the previous kernel results. So I need to either add my results to the trace_fn output or store them somewhere on the side. What would be the right way to achieve this goal?

Thanks!

SiegeLordEx commented 4 years ago

The current solution is to put heavy_computation inside trace_fn as well, wrap the sample_chain call inside tf.function, and then hope that CSE (common subexpression elimination) removes the duplicate computation.

MSusik commented 4 years ago

Thanks for the answer!

The current solution is to put heavy_computation inside trace_fn as well

I see. I had assumed that repeating the call outside of the model definition could not be optimised well; I'll experiment with it :).

brianwa84 commented 4 years ago

It might be beneficial to the CSE engine to tf.function-annotate heavy_computation, so that both calls would appear in the TF graph as "PartitionedCall" ops pointing to the same function body.

However, since this computation crosses the boundary of a while loop, I don't foresee CSE saving any work. That is, the inputs to heavy_computation as called by the leapfrog integrator of HMC will be different tensors from those passed when it is called from the trace_fn (which will be conditional on acceptance of the HMC proposal).

If you're up for hacking the code, you could probably hack HMC so that it expects the target_log_prob_fn to return both values, and then stashes heavy_computation_results in a new field you could add to HamiltonianMonteCarloExtraKernelResults. I could imagine this being generalizable if HMC allowed target_log_prob_fn to return auxiliary info (would need to be careful about zeroing gradients to the aux).