greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

Explore retracing warning in TF2 #546

Open njtierney opened 2 years ago

njtierney commented 2 years ago

In the TF2 branch - #543

We were getting a warning to the effect of "TF functions are being retraced 6 times, this is a problem". However, I can't seem to replicate the message. I think it appeared at one level of debugging and the warning is now being swallowed somewhere. I'm documenting it here in case it comes up again.

njtierney commented 1 year ago

Example of message:

WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x2e881c430> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

njtierney commented 1 year ago

One band-aid solution is to suppress these warnings for the current release and deal with them at a later date. I will reference this commit with my current approach, and also flag the relevant parts of the code with

## TF1/2 retracing

to make them easier to find.
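As a minimal sketch of what suppressing only these warnings could look like on the Python side: the `WARNING:tensorflow:` prefix suggests the message goes through Python's standard `logging` module under a logger named `tensorflow`, so a filter on that logger can drop retracing messages and leave other warnings visible. The names `RetracingFilter` and `silence_retracing_warnings` are hypothetical, and this is not necessarily what the referenced commit does. Calling it from R (for example through reticulate) is not shown.

```python
import logging


class RetracingFilter(logging.Filter):
    """Drop log records that mention tf.function retracing; keep the rest."""

    def filter(self, record: logging.LogRecord) -> bool:
        # getMessage() applies any %-style arguments before we inspect the text.
        return "triggered tf.function retracing" not in record.getMessage()


def silence_retracing_warnings(logger_name: str = "tensorflow") -> logging.Filter:
    """Attach the filter to the named logger and return it for later removal.

    Assumption: the retracing warning is emitted on the "tensorflow" logger.
    """
    retracing_filter = RetracingFilter()
    logging.getLogger(logger_name).addFilter(retracing_filter)
    return retracing_filter
```

Keeping the returned filter makes the suppression easy to undo with `logging.getLogger("tensorflow").removeFilter(...)` once the underlying retracing is fixed.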

njtierney commented 3 months ago

I can't get it to reproduce via reprex, but this code gives me a retracing warning:

library(greta)
# in opt
sd <- runif(5)
x <- rnorm(5, 2, 0.1)
z1 <- variable(dim = 1)
z2 <- variable(dim = 1)
z3 <- variable(dim = 1)
z4 <- variable(dim = 1)
z5 <- variable(dim = 1)
z <- c(z1, z2, z3, z4, z5)
distribution(x) <- normal(z, sd)

m <- model(z1, z2, z3, z4, z5)
o <- opt(m, hessian = TRUE)

WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x29cfff9c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

njtierney commented 3 months ago

After reading up on this in

https://tensorflow.rstudio.com/guides/tensorflow/intro_to_graphs.html

we might want to explore tf_function(jit_compile = TRUE).

njtierney commented 3 months ago

There are some more hints to control retracing here: https://www.tensorflow.org/guide/function#controlling_retracing

Currently the main places we are using tf_function are:

define_tf_trace_values_batch = function() {
  self$tf_trace_values_batch <- tensorflow::tf_function(
    f = self$define_trace_values_batch
  )
},

define_tf_log_prob_function = function() {
  self$tf_log_prob_function <- tensorflow::tf_function(
    # TF1/2 check
    # need to check in on all cases of `tensorflow::tf_function()`
    # as we are getting lots of warnings about retracing
    f = self$generate_log_prob_function()
  )
},

One of the proposed solutions to retracing is to use input_signature, which we already do in define_tf_evaluate_sample_batch.

However, it isn't clear to me how to express an input signature for the two instances above, as they are functions that don't take arguments. Another suggestion is to use unknown dimensions for flexibility.

The example they give is:

import tensorflow as tf

# shape=[None] leaves the length of the 1-D input unspecified
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

Perhaps then we just give that input signature?

njtierney commented 3 months ago

That did not work.

Unfortunately, neither did setting reduce_retracing to TRUE (https://www.tensorflow.org/guide/function#controlling_retracing).
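One possible reading, offered only as a hypothesis: the warning names `pfor.<locals>.f`, a function defined inside another function. If a new traced function is created on every call (cause (1) in the warning), each one starts with an empty trace cache, so signature or retracing settings on our own tf_function calls would not change anything. The toy model below does not use TensorFlow and is not its implementation; `ToyTracedFunction` and every other name in it are hypothetical. It only illustrates why a relaxed signature helps a long-lived function but not a freshly created one.

```python
from typing import Callable, Dict, List, Optional, Tuple

Shape = Tuple[Optional[int], ...]


class ToyTracedFunction:
    """Toy stand-in for a traced function: one trace cache per wrapper object."""

    def __init__(self, fn: Callable[[List[int]], int],
                 signature: Optional[Shape] = None) -> None:
        self.fn = fn
        self.signature = signature  # a None entry means "any size"
        self.cache: Dict[Shape, Callable[[List[int]], int]] = {}
        self.trace_count = 0

    def _cache_key(self, shape: Shape) -> Shape:
        # Without a signature, every distinct input shape gets its own trace.
        if self.signature is None:
            return shape
        if len(shape) != len(self.signature):
            raise TypeError("input rank does not match the signature")
        for want, got in zip(self.signature, shape):
            if want is not None and want != got:
                raise TypeError("input shape does not match the signature")
        # With a signature, all compatible shapes share a single trace.
        return self.signature

    def __call__(self, values: List[int]) -> int:
        key = self._cache_key((len(values),))
        if key not in self.cache:
            self.trace_count += 1  # a "retrace" in this toy model
            self.cache[key] = self.fn
        return self.cache[key](values)


def total(values: List[int]) -> int:
    return sum(values)


# Long-lived wrapper without a signature: one trace per distinct shape.
strict = ToyTracedFunction(total)
strict([1, 2, 3])
strict([1, 2, 3, 4, 5])

# Long-lived wrapper with a relaxed signature: a single trace is reused.
relaxed = ToyTracedFunction(total, signature=(None,))
relaxed([1, 2, 3])
relaxed([1, 2, 3, 4, 5])


def call_with_fresh_wrapper(values: List[int]) -> int:
    # A new wrapper per call has an empty cache, so it always traces once,
    # even though the signature is relaxed and the input never changes.
    wrapper = ToyTracedFunction(total, signature=(None,))
    wrapper(values)
    return wrapper.trace_count


fresh_traces = sum(call_with_fresh_wrapper([1, 2, 3]) for _ in range(6))
```

In this toy model `strict` traces twice, `relaxed` traces once, and the fresh-wrapper loop traces on all six calls. If the hypothesis holds, the place to look would be wherever the function in the warning gets created, rather than the arguments passed to tf_function.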