MSusik opened this issue 4 years ago
The current solution is to put `heavy_computation` inside `trace_fn` as well, then wrap the `sample_chain` call inside `tf.function` and hope that CSE removes the duplicate computation.
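A TFP-independent sketch of why the duplicate work can disappear (all names hypothetical; `functools.lru_cache` memoization stands in for graph-level CSE): if both the sampler's log prob and the trace call the same cached `heavy_computation` on the same state, the expensive work runs once.

```python
import functools

CALLS = {"count": 0}  # instrument how often the heavy work actually runs

@functools.lru_cache(maxsize=None)
def heavy_computation(state):
    # Stand-in for an expensive per-state computation.
    CALLS["count"] += 1
    return state ** 2

def target_log_prob(state):
    # The sampler's log prob consumes the heavy result...
    return -0.5 * heavy_computation(state)

def trace_fn(state):
    # ...and the trace wants the same result for the same state.
    return heavy_computation(state)

log_prob = target_log_prob(3)
traced = trace_fn(3)
# Both callers shared a single evaluation of heavy_computation.
```

In graph mode the analogous deduplication is performed (or not) by the optimizer rather than by an explicit cache, which is why it is a hope rather than a guarantee.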
Thanks for the answer!
> The current solution is to put `heavy_computation` inside `trace_fn` as well
I see. I had assumed that duplicating the function's execution outside of the model definition could not be optimised well; I will experiment with it :).
It might be beneficial to the CSE engine to `tf.function`-annotate `heavy_computation`, so that both calls would appear in the TF graph as "PartitionedCall" ops pointing to the same function body. However, since this would have to traverse the boundary of a while loop, I don't foresee CSE saving any work. That is, the inputs to `heavy_computation` as called by the leapfrog integrator of HMC will be different tensors from those passed when it is called from `trace_fn` (which will be conditional on acceptance of the HMC proposal).
If you're up for hacking the code, you could probably modify HMC so that it expects the `target_log_prob_fn` to return both values, and then stashes `heavy_computation_results` in a new field you could add to `HamiltonianMonteCarloExtraKernelResults`. I could imagine this being generalizable if HMC allowed `target_log_prob_fn` to return auxiliary info (one would need to be careful about zeroing gradients to the aux).
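A minimal sketch of that hack, outside of TFP (the kernel-results type and the Metropolis step here are hypothetical stand-ins, not the TFP API): the target function returns `(log_prob, aux)`, and the auxiliary value is stashed in the kernel results so the trace can read it without recomputation.

```python
import collections
import math
import random

# Hypothetical analogue of HamiltonianMonteCarloExtraKernelResults,
# extended with a field for the target function's auxiliary output.
KernelResults = collections.namedtuple(
    "KernelResults", ["target_log_prob", "heavy_computation_results"])

def target_log_prob_and_aux(x):
    aux = x * x               # stand-in for heavy_computation(x)
    return -0.5 * aux, aux    # the sampler uses the log prob; aux is stashed

def one_step(state, results, rng):
    # Toy random-walk Metropolis step standing in for an HMC step.
    proposal = state + rng.gauss(0.0, 1.0)
    new_lp, new_aux = target_log_prob_and_aux(proposal)
    if math.log(rng.random()) < new_lp - results.target_log_prob:
        return proposal, KernelResults(new_lp, new_aux)
    return state, results  # rejected: keep the previous results (and aux)

rng = random.Random(0)
state = 0.0
lp, aux = target_log_prob_and_aux(state)
results = KernelResults(lp, aux)
trace = []
for _ in range(100):
    state, results = one_step(state, results, rng)
    # A trace_fn can now read the stashed value from the kernel results.
    trace.append(results.heavy_computation_results)
```

The stashed field always corresponds to the accepted state, which is exactly the conditioning-on-acceptance behaviour described above.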
Hi!
I'm new to TFP and I'm playing with the HMC implementation included in TFP. Thanks for the library, it's very nice.
I stumbled upon a scenario where my kernel does some heavy computations that I would like not to redo. The computation returns a 2D matrix whose first dimension is the chains' dimension. For example:
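(The original snippet did not survive here; the following is a hedged reconstruction of the kind of setup described, with hypothetical names and NumPy standing in for the TF computation.)

```python
import numpy as np

NUM_CHAINS = 4

def heavy_computation(states):
    # Expensive step; returns a matrix whose leading dimension indexes
    # the chains (hypothetical stand-in: one row per chain).
    return np.outer(states, np.arange(3.0))  # shape (NUM_CHAINS, 3)

def target_log_prob_fn(states):
    # Each chain's log prob depends on the heavy result.
    heavy_computation_results = heavy_computation(states)
    return -0.5 * np.sum(heavy_computation_results ** 2, axis=-1)

states = np.linspace(-1.0, 1.0, NUM_CHAINS)
chain_log_probs = target_log_prob_fn(states)  # one log prob per chain
```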
Then I would like to store/access the values of `heavy_computation_results` for each accepted iteration. I noticed the acceptance info is available in the previous kernel results, so I need to either add my results to the `trace_fn` or store them somehow on the side. What would be the right way to achieve this goal?
Thanks!