ponder-lab / Hybridize-Functions-Refactoring

Refactorings for optimizing imperative TensorFlow clients for greater efficiency.
Eclipse Public License 2.0

Handle "leaky" tensors #281

Open. Opened by khatchad 1 year ago.


Consider the following code:

```python
# From https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values

import tensorflow as tf
from nose.tools import assert_raises

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
```

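Here, `captures_leaked_tensor()` has no explicit Python side effects, but it reads the global variable `x`, which is written by `leaky_function()`, a `tf.function` that does have a Python side effect (assigning to a global). Because that assignment runs while `leaky_function()` is being traced, `x` ends up holding a symbolic graph tensor that is only valid inside that function's graph, so using it anywhere else fails (here with a `TypeError`). Thus, `captures_leaked_tensor()` should not be converted to hybrid.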
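One way to detect the pattern in the snippet above is to (1) collect the names that a `@tf.function` declares `global` and assigns, and (2) flag every function that reads one of those names without binding it locally. The following is a minimal, purely syntactic sketch using Python's standard `ast` module. It is not how the refactoring tool itself performs the analysis; `find_leaky_readers`, its helpers, and `ISSUE_SNIPPET` are hypothetical names. It assumes TensorFlow is imported as `tf` and ignores scoping subtleties such as nested functions, shadowed imports, or a function that both reads and writes the same global.

```python
from __future__ import annotations

import ast

# Hypothetical example input: a trimmed copy of the snippet from this issue.
ISSUE_SNIPPET = '''
import tensorflow as tf

@tf.function
def leaky_function(a):
    global x
    x = a + 1
    return x

def captures_leaked_tensor(b):
    b += x
    return b
'''


def _is_tf_function(decorator: ast.expr) -> bool:
    """True for ``@tf.function`` and ``@tf.function(...)`` (only the ``tf`` alias)."""
    if isinstance(decorator, ast.Call):
        decorator = decorator.func
    return (
        isinstance(decorator, ast.Attribute)
        and decorator.attr == "function"
        and isinstance(decorator.value, ast.Name)
        and decorator.value.id == "tf"
    )


def _leaked_globals(func: ast.FunctionDef) -> set[str]:
    """Names that ``func`` declares ``global`` and also assigns."""
    declared: set[str] = set()
    assigned: set[str] = set()
    for node in ast.walk(func):
        if isinstance(node, ast.Global):
            declared.update(node.names)
        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
            assigned.add(node.id)
    return declared & assigned


def _free_reads(func: ast.FunctionDef) -> set[str]:
    """Names ``func`` loads but never binds itself (roughly: its global reads)."""
    args = func.args
    params = {a.arg for a in args.posonlyargs + args.args + args.kwonlyargs}
    params |= {a.arg for a in (args.vararg, args.kwarg) if a is not None}
    loads: set[str] = set()
    stores: set[str] = set()
    for node in ast.walk(func):
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Store):
                stores.add(node.id)
            elif isinstance(node.ctx, ast.Load):
                loads.add(node.id)
    return loads - stores - params


def find_leaky_readers(source: str) -> dict[str, set[str]]:
    """Map each function that reads a tensor leaked from a tf.function to those names."""
    funcs = [n for n in ast.walk(ast.parse(source)) if isinstance(n, ast.FunctionDef)]
    leaked: set[str] = set()
    for func in funcs:
        if any(_is_tf_function(d) for d in func.decorator_list):
            leaked |= _leaked_globals(func)
    readers: dict[str, set[str]] = {}
    for func in funcs:
        hits = _free_reads(func) & leaked
        if hits:
            readers[func.name] = hits
    return readers


print(find_leaky_readers(ISSUE_SNIPPET))  # {'captures_leaked_tensor': {'x'}}
```

Note that `leaky_function` itself is not reported: it binds `x`, so its own `return x` does not count as a read of a leaked global.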

Expected Behavior

  1. The function `captures_leaked_tensor()` should not be hybridized.
  2. If `captures_leaked_tensor()` is already hybrid (i.e., decorated with `@tf.function`), the refactoring should issue a warning.
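The two expected behaviors amount to a small decision rule. Below is a minimal sketch, assuming some earlier analysis already knows whether a function is hybrid and whether it reads a leaked tensor. `FunctionFacts`, `Outcome`, and `check_leaked_tensors` are hypothetical names and not part of the plug-in's API.

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum


class Outcome(Enum):
    OK = "ok"                              # this check alone does not block hybridization
    DO_NOT_HYBRIDIZE = "do_not_hybridize"  # expected behavior 1
    WARN = "warn"                          # expected behavior 2


@dataclass(frozen=True)
class FunctionFacts:
    """Hypothetical per-function facts an analysis would have to supply."""
    name: str
    is_hybrid: bool            # already decorated with @tf.function
    reads_leaked_tensor: bool  # reads a global assigned inside some tf.function


def check_leaked_tensors(facts: FunctionFacts) -> tuple[Outcome, str]:
    """Apply the leaked-tensor rule to one function; return an outcome and a message."""
    if not facts.reads_leaked_tensor:
        return Outcome.OK, ""
    if facts.is_hybrid:
        return Outcome.WARN, (
            f"{facts.name}() is already a tf.function but reads a tensor "
            "leaked from another tf.function"
        )
    return Outcome.DO_NOT_HYBRIDIZE, (
        f"{facts.name}() reads a tensor leaked from a tf.function; not hybridizing"
    )


outcome, message = check_leaked_tensors(
    FunctionFacts("captures_leaked_tensor", is_hybrid=False, reads_leaked_tensor=True)
)
print(outcome)  # Outcome.DO_NOT_HYBRIDIZE
```

Returning an outcome instead of raising is one possible design choice: it would let this rule be combined with whatever other preconditions the refactoring checks, with `Outcome.OK` meaning only that this particular check raises no objection.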