tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Excessive retracing in autobatched joint distributions #1076

Open davmre opened 4 years ago

davmre commented 4 years ago

Calling .sample or .log_prob repeatedly on an autobatched joint distribution (as you would in, say, an eager-mode optimization loop) prints the following warning:

WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7f16b7927e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
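Cause (2) in the warning is shape-driven retracing, which the rest of this thread does not cover. As a minimal sketch (the function total and the trace counter are illustrative, not taken from the issue), an input_signature with an unknown dimension lets one trace serve 1-D inputs of any length. The counter increments only while TensorFlow traces, because plain Python side effects are not recorded in the graph.

```python
import tensorflow as tf

trace_count = 0  # counts how many times TensorFlow traces `total`

# shape=[None] accepts 1-D float32 tensors of any length with a single trace.
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def total(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return tf.reduce_sum(x)

for n in [2, 3, 4, 5]:
    total(tf.ones([n]))  # different lengths, same concrete function

print(trace_count)  # 1
```

Newer TensorFlow releases deprecate the warning's experimental_relax_shapes option in favour of reduce_retracing=True, which asks tf.function to generalize shapes on its own instead of requiring an explicit signature.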

Reproduction: https://colab.research.google.com/drive/15RRL7pjgORFzcjWH2i74BQte61EUtBGc?usp=sharing

Autobatched JDs should probably cache the pfor'd tf.functions to avoid this.
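As a toy illustration of that caching idea (not TFP's actual fix; log_prob, Uncached and Cached are hypothetical names), the sketch below contrasts wrapping a function in a fresh tf.function on every call, which is cause (1) in the warning, with building the tf.function once and storing it on the object. The counter in log_prob increments only during tracing.

```python
import tensorflow as tf

trace_count = 0  # counts traces of log_prob

def log_prob(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return -0.5 * tf.reduce_sum(x ** 2)

class Uncached:
    """Hypothetical: wraps log_prob in a fresh tf.function on every call."""
    def __call__(self, x):
        return tf.function(log_prob)(x)  # new Function object, empty trace cache

class Cached:
    """Hypothetical: builds the tf.function lazily once, then reuses it."""
    def __init__(self):
        self._fn = None

    def __call__(self, x):
        if self._fn is None:
            self._fn = tf.function(log_prob)
        return self._fn(x)

x = tf.ones([3])

uncached = Uncached()
for _ in range(5):
    uncached(x)
uncached_traces = trace_count  # one trace per call

trace_count = 0
cached = Cached()
for _ in range(5):
    cached(x)
cached_traces = trace_count  # traced once, then served from the cache

print(uncached_traces, cached_traces)  # 5 1
```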

ksachdeva commented 4 years ago

Hi @davmre

I see these warnings in regular Joint Distributions as well.

See, for example, https://ksachdeva.github.io/rethinking-tensorflow-probability/05_the_many_variables_and_the_spurious_waffles.html (search the page for "tracing" and you will find many instances of it)

or, for that matter, any notebook at https://ksachdeva.github.io/rethinking-tensorflow-probability/

I have tried almost all the suggestions in the warning message (e.g., passing tensors rather than Python objects), but nothing worked.

It has been like this for quite some time now, and it happens in general, i.e., it is not specific to AutoBatched.

If you notice anything I am doing inappropriately that could be causing these warnings, it would be really great to know.

Regards Kapil

davmre commented 4 years ago

@ksachdeva: I don't think the warnings you're seeing are coming from JointDistribution (or from inside TFP at all). Your run_hmc_chain method is wrapped with tf.function, and the warnings are telling you that that method is being retraced.

In general, any tf.function that takes non-Tensor arguments will be retraced whenever any of those objects changes. In your case, run_hmc_chain takes a target_log_prob_fn and bijectors, so it'll be retraced whenever those change. Since those args are created inside of sample_posterior, you'll get retracing whenever sample_posterior is called.
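A minimal sketch of that effect, assuming a hypothetical run_chain in place of the notebook's run_hmc_chain: the tf.function takes a Python callable as an argument, so each distinct callable object becomes part of the trace cache key. Three different lambdas force three traces, while reusing one lambda traces only once.

```python
import tensorflow as tf

trace_count = 0  # counts traces of run_chain

@tf.function
def run_chain(target_log_prob_fn, x):
    """Hypothetical stand-in for a sampler that takes a callable argument."""
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return target_log_prob_fn(x)

x = tf.ones([3])

# Three distinct lambda objects (kept alive in a list): three cache misses.
fresh_fns = [lambda v: -tf.reduce_sum(v ** 2) for _ in range(3)]
for fn in fresh_fns:
    run_chain(fn, x)
fresh_fn_traces = trace_count

# One lambda object reused: one trace, then cache hits.
trace_count = 0
target = lambda v: -tf.reduce_sum(v ** 2)
for _ in range(3):
    run_chain(target, x)
same_fn_traces = trace_count

print(fresh_fn_traces, same_fn_traces)  # 3 1
```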

In cases where you're calling sample_posterior with a genuinely different model, there's no way to avoid retracing, because different models will in general be traced to different graphs. But if you're sampling from the same model multiple times, you can probably restructure your code to avoid retracing. A general rule of thumb is to try to put tf.function wrapping around pieces where you expect to reuse the same computation graph multiple times, e.g., the target_log_prob_fn for a particular model.

ksachdeva commented 4 years ago

Thanks Dave.

Clearly I was misusing tf.function, and your explanation has helped me understand it better. The key part I had missed is that decorating a function with tf.function makes TensorFlow cache the traced graph, and if the function is invoked again with different Python objects (or references), it traces and caches a new graph, hence the warning.

"In cases where you're calling sample_posterior with a genuinely different model, there's no way to avoid retracing."

Indeed, that is my case: sample_posterior is a utility method that I call with different models. Each (rethinking) notebook contains many models, hence this structure. I'm not sure how to avoid it in my case.
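One possible restructuring, sketched under assumptions (ToyModel, compiled_log_prob and this sample_posterior are hypothetical stand-ins, not the notebook's code): memoize one tf.function per model object, so repeated calls with a model already seen reuse its graph. Each distinct model still traces once, consistent with the point above that different models need different graphs; the cache only removes the repeats.

```python
import functools
import tensorflow as tf

trace_count = 0  # counts traces across all models

class ToyModel:
    """Hypothetical stand-in for one model in a notebook; not a TFP class."""
    def __init__(self, scale):
        self.scale = scale

    def log_prob(self, x):
        global trace_count
        trace_count += 1  # Python side effect: runs only while tracing
        return -0.5 * tf.reduce_sum((x / self.scale) ** 2)

@functools.lru_cache(maxsize=None)
def compiled_log_prob(model):
    # One tf.function per model object (hashed by identity), built on first use.
    return tf.function(model.log_prob)

def sample_posterior(model, x):
    # Hypothetical utility: reuses the cached graph for models seen before.
    return compiled_log_prob(model)(x)

m1, m2 = ToyModel(1.0), ToyModel(2.0)
x = tf.ones([3])
for _ in range(3):
    sample_posterior(m1, x)
    sample_posterior(m2, x)

print(trace_count)  # 2: one trace per distinct model, not one per call
```

functools.lru_cache holds strong references to every model it has seen, which is fine for a notebook with a fixed set of models but would grow without bound if models were created in a loop.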