Open davmre opened 4 years ago
Hi @davmre,

I see these warnings with regular joint distributions as well. See, for example, https://ksachdeva.github.io/rethinking-tensorflow-probability/05_the_many_variables_and_the_spurious_waffles.html (search for "tracing" on that page and you will find many instances of it), or for that matter any notebook at https://ksachdeva.github.io/rethinking-tensorflow-probability/.

I have tried almost all of the suggestions given by the warning message (i.e. passing tensors instead of Python objects, etc.), but nothing worked. It has been like this for quite some time now, and it happens in general, i.e. it is not specific to AutoBatched distributions.

If you notice something I am doing incorrectly that could help get rid of these warnings, that would be really great.

Regards,
Kapil
@ksachdeva: I don't think the warnings you're seeing are coming from `JointDistribution` (or from inside TFP at all). Your `run_hmc_chain` method is wrapped with `tf.function`, and the warnings are telling you that that method is being retraced.

In general, any `tf.function` that takes non-`Tensor` arguments will be retraced whenever any of those objects changes. In your case, `run_hmc_chain` takes a `target_log_prob_fn` and `bijectors`, so it'll be retraced whenever those change. Since those args are created inside of `sample_posterior`, you'll get retracing whenever `sample_posterior` is called.

In cases where you're calling `sample_posterior` with a genuinely different model, there's no way to avoid retracing, because different models will in general be traced to different graphs. But if you're sampling from the same model multiple times, you can probably restructure your code to avoid retracing. A general rule of thumb is to put the `tf.function` wrapping around pieces where you expect to reuse the same computation graph multiple times, e.g., the `target_log_prob_fn` for a particular model.
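To make the retracing rule concrete, here is a hypothetical pure-Python sketch (not TensorFlow's actual implementation) of a `tf.function`-style cache keyed on the identity of non-`Tensor` arguments: passing a freshly created callable on every call forces a new trace, while reusing one callable traces only once.

```python
# Hypothetical sketch of tf.function's caching behavior (not the real
# implementation): traced "graphs" are cached per distinct non-Tensor
# argument, so a new Python callable per call means a new trace per call.

def fake_tf_function(fn):
    cache = {}

    def wrapper(target_log_prob_fn, x):
        # Non-Tensor args are keyed by object identity.
        key = id(target_log_prob_fn)
        if key not in cache:
            wrapper.trace_count += 1      # a "trace" happens here
            cache[key] = fn               # pretend we built a graph
        return cache[key](target_log_prob_fn, x)

    wrapper.trace_count = 0
    return wrapper

@fake_tf_function
def run_chain(target_log_prob_fn, x):
    return target_log_prob_fn(x)

# A fresh callable each call -> one trace per call.
fns = [lambda x: -x * x for _ in range(3)]  # keep refs so ids stay distinct
for f in fns:
    run_chain(f, 1.0)
assert run_chain.trace_count == 3

# The same callable each call -> traced once, then reused.
run_chain.trace_count = 0
tlp = lambda x: -x * x
for _ in range(3):
    run_chain(tlp, 1.0)
assert run_chain.trace_count == 1
```

This is the same reason passing a freshly constructed `target_log_prob_fn` or `bijectors` list into a `tf.function` on every call triggers the warning.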
Thanks Dave.

Clearly I was misusing `tf.function`, and your explanation has helped me understand it better. The key part I was not paying attention to was that decorating a function with `tf.function` makes TensorFlow trace and cache it, and if it is invoked multiple times with different Python objects/references, it will be traced and cached again each time, hence the warning.

> In cases where you're calling sample_posterior with a genuinely different model, there's no way to avoid retracing,

Indeed, that is my case: I use `sample_posterior` (a utility method of mine) and pass it different models. A (rethinking) notebook contains many models, hence this structure. I am not sure how to avoid the retracing in my case.
Calling `.sample` or `.log_prob` repeatedly on an autobatched joint distribution (as you would in, say, an eager-mode optimization loop) prints the warning:

```
WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7f16b7927e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
```
Reproduction: https://colab.research.google.com/drive/15RRL7pjgORFzcjWH2i74BQte61EUtBGc?usp=sharing
Autobatched JDs should probably cache the pfor'd `tf.function`s to avoid this.
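The proposed caching could look something like the following sketch (assumed structure, not TFP's actual code): memoize the vectorized function on the distribution instance so repeated `.log_prob` calls reuse one trace.

```python
import functools

# Sketch of the proposed fix (assumed structure, not TFP's actual code):
# cache the vectorized function on the distribution instance so repeated
# .log_prob calls reuse a single trace instead of re-tracing via pfor.
class AutoBatchedJDSketch:
    def __init__(self):
        self.trace_count = 0

    @functools.cached_property
    def _vectorized_log_prob(self):
        # In TFP this would be the pfor-vectorized function; built once.
        self.trace_count += 1
        return lambda x: -x * x   # stand-in for the traced graph

    def log_prob(self, x):
        return self._vectorized_log_prob(x)

d = AutoBatchedJDSketch()
for _ in range(5):
    d.log_prob(2.0)
assert d.trace_count == 1   # traced once, reused on every later call
```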