Persistent GradientTape hangs with Distribution.sample() #1136

Open mmilosav opened 3 years ago

mmilosav commented 3 years ago

Drawing samples from an instance of tfp.distributions.Distribution in a persistent=True GradientTape context seems to hang:

# Minimal reproduction. Only the last block (persistent tape around a TFP
# Distribution.sample() call) hangs; the three blocks before it complete.
# NOTE(review): the earlier blocks may warm up tracing caches used by the
# last one; keep this order when reproducing — unverified.
import tensorflow as tf
import tensorflow_probability as tfp

# Trainable Gamma parameters (alpha = concentration, beta = rate).
alpha = tf.Variable([2., 0.5])
beta = tf.Variable([2., 3.])

# 1) Non-persistent tape + core TF sampler: completes.
with tf.GradientTape() as tape:
    sample = tf.random.gamma(shape=(1,), alpha=alpha, beta=beta)


# 2) Non-persistent tape + TFP sampler: completes.
with tf.GradientTape() as tape:
    sample = tfp.distributions.Gamma(concentration=alpha, rate=beta).sample()


# 3) Persistent tape + core TF sampler: completes.
with tf.GradientTape(persistent=True) as tape:
    sample = tf.random.gamma(shape=(1,), alpha=alpha, beta=beta)


# 4) Persistent tape + TFP sampler: hangs inside .sample() on the affected
#    TF/TFP versions.
with tf.GradientTape(persistent=True) as tape:
    sample = tfp.distributions.Gamma(concentration=alpha, rate=beta).sample()

print(sample)  # FIXME: never reaches here!!!

This seems to happen with TF 2.3.0, 2.3.1, and nightly; TFP 0.11.0, 0.11.1, and nightly; Python 3.7; on both Colab and macOS (not every release/platform combination has been tested).

brianwa84 commented 3 years ago

Interesting find!

The stack trace seems to suggest a deadlock of some sort.

KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-1-790e15807a67> in <module>()
     24 with tf.GradientTape(persistent=True) as tape:
---> 25     sample = tfp.distributions.Gamma(concentration=alpha, rate=beta).sample()
     27 print(sample)  # FIXME: never reachers here!!!

29 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/distribution.py in sample(self, sample_shape, seed, name, **kwargs)
    968       samples: a `Tensor` with prepended dimensions `sample_shape`.
    969     """
--> 970     return self._call_sample_n(sample_shape, seed, name, **kwargs)
    972   def _call_log_prob(self, value, name, **kwargs):

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/distribution.py in _call_sample_n(self, sample_shape, seed, name, **kwargs)
    946           sample_shape, 'sample_shape')
    947       samples = self._sample_n(
--> 948           n, seed=seed() if callable(seed) else seed, **kwargs)
    949       batch_event_shape = ps.shape(samples)[1:]
    950       final_shape = ps.concat([sample_shape, batch_event_shape], 0)

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/internal/distribution_util.py in _fn(*args, **kwargs)
   1362     @functools.wraps(fn)
   1363     def _fn(*args, **kwargs):
-> 1364       return fn(*args, **kwargs)
   1366     if _fn.__doc__ is None:

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gamma.py in _sample_n(self, n, seed)
    242         log_rate=(None if self.log_rate is None else
    243                   tf.convert_to_tensor(self.log_rate)),
--> 244         seed=seed)
    246   def _log_prob(self, x, rate=None):

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gamma.py in random_gamma(shape, concentration, rate, log_rate, seed, log_space)
    654     shape, concentration, rate=None, log_rate=None, seed=None, log_space=False):
    655   return random_gamma_with_runtime(
--> 656       shape, concentration, rate, log_rate, seed, log_space)[0]

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gamma.py in random_gamma_with_runtime(shape, concentration, rate, log_rate, seed, log_space)
    648   seed = samplers.sanitize_seed(seed, salt='random_gamma')
    649   return _random_gamma_gradient(
--> 650       total_shape, concentration, rate, log_rate, seed, log_space)

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/internal/custom_gradient.py in none_wrapper(*args, **kwargs)
     99           return val, vjp_bwd_wrapped
--> 101         return f_wrapped(*trimmed_args, **kwargs)
    103       return none_wrapper

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/custom_gradient.py in __call__(self, *a, **k)
    260   def __call__(self, *a, **k):
--> 261     return self._d(self._f, a, k)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/custom_gradient.py in decorated(wrapped, args, kwargs)
    214     if context.executing_eagerly():
--> 215       return _eager_mode_decorator(wrapped, args, kwargs)
    216     else:
    217       return _graph_mode_decorator(wrapped, args, kwargs)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/custom_gradient.py in _eager_mode_decorator(f, args, kwargs)
    436   """Implement custom gradient decorator for eager mode."""
    437   with tape_lib.VariableWatcher() as variable_watcher:
--> 438     result, grad_fn = f(*args, **kwargs)
    439   args = nest.flatten(args)
    440   all_inputs = list(args) + list(kwargs.values())

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/internal/custom_gradient.py in f_wrapped(*args, **kwargs)
     87               reconstruct_args.append(args[0])
     88               args = args[1:]
---> 89           val, aux = vjp_fwd(*reconstruct_args, **kwargs)
     91           def vjp_bwd_wrapped(*g):

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gamma.py in _random_gamma_fwd(shape, concentration, rate, log_rate, seed, log_space)
    535   """Compute output, aux (collaborates with _random_gamma_bwd)."""
    536   samples, impl = _random_gamma_no_gradient(
--> 537       shape, concentration, rate, log_rate, seed, log_space)
    538   return ((samples, impl),
    539           (samples, shape, concentration, rate, log_rate, log_space))

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/internal/implementation_selection.py in f_wrapped(*args, **kwargs)
     81     try:
     82       tf.config.run_functions_eagerly(False)
---> 83       return f(*args, **kwargs)
     84     finally:
     85       tf.config.run_functions_eagerly(orig)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    860       # In this case we have not created variables on the first call. So we can
    861       # run the first trace but we should fail if variables are created.
--> 862       results = self._stateful_fn(*args, **kwds)
    863       if self._created_variables:
    864         raise ValueError("Creating variables on a non-first call to a function"

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   2947        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   2948     return graph_function._call_flat(
-> 2949         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   2951   @property

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1930         possible_gradient_type,
   1931         executing_eagerly)
-> 1932     forward_function, args_with_tangents = forward_backward.forward()
   1933     if executing_eagerly:
   1934       flat_outputs = forward_function.call(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in forward(self)
   1453     """Builds or retrieves a forward function for this call."""
   1454     forward_function = self._functions.forward(
-> 1455         self._inference_args, self._input_tangents)
   1456     return forward_function, self._inference_args + self._input_tangents

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in forward(self, inference_args, input_tangents)
   1205       (self._forward, self._forward_graph, self._backward,
   1206        self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
-> 1207            self._forward_and_backward_functions(inference_args, input_tangents))
   1208     return self._forward

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _forward_and_backward_functions(self, inference_args, input_tangents)
   1405       outputs = list(self._func_graph.outputs)
   1406       self._build_functions_for_outputs(
-> 1407           outputs, inference_args, input_tangents)
   1408     (forward_function, forward_graph,
   1409      backward_function, output_indices, num_output_tangents) = (

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _build_functions_for_outputs(self, outputs, inference_args, input_tangents)
    908             self._func_graph.inputs,
    909             grad_ys=gradients_wrt_outputs,
--> 910             src_graph=self._func_graph)
    912       if input_tangents:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
    534     xs_set = object_identity.ObjectIdentitySet(xs)
    535     grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
--> 536                              gradient_uid)
    538     # The approach we take here is as follows: Create a list of all ops in the

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_util.py in _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid)
    223                              grad_y.dense_shape)))
    224       else:
--> 225         new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))
    227   return new_grad_ys

/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in identity(input, name)
    285     # variables. Variables have correct handle data when graph building.
    286     input = ops.convert_to_tensor(input)
--> 287   ret = gen_array_ops.identity(input, name=name)
    288   # Propagate handle data for happier shape inference for resource variables.
    289   if hasattr(input, "_handle_data"):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in identity(input, name)
   3941   # Add nodes to the TensorFlow graph.
   3942   _, _, _op, _outputs = _op_def_library._apply_op_helper(
-> 3943         "Identity", input=input, name=name)
   3944   _result = _outputs[:]
   3945   if _execute.must_record_gradient():

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
    748       op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    749                                  name=scope, input_types=input_types,
--> 750                                  attrs=attr_protos, op_def=op_def)
    752     # `outputs` is returned as a separate return value so that the output

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in _create_op_internal(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device)
    590     return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
    591         op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
--> 592         compute_device)
    594   def capture(self, tensor, name=None, shape=None):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_op_internal(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device)
   3525     # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
   3526     # Session.run call cannot occur between creating and mutating the op.
-> 3527     with self._mutation_lock():
   3528       ret = Operation(
   3529           node_def,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _mutation_lock(self)
   5186     See the comment for self._group_lock for more info.
   5187     """
-> 5188     return self._group_lock.group(_MUTATION_LOCK_GROUP)
   5190   def _session_run_lock(self):


@saxenasaurabh do you think you could route this appropriately?

It seems to reproduce with tf.function as well, e.g. sample = tf.function(tfp.distributions.Gamma(concentration=alpha, rate=beta).sample)()

In the meantime, if you need a workaround that avoids retracing the gradient, you can write:

def gamma_sample_and_grad(alpha, beta):
  """Draw a Gamma(alpha, beta) sample together with its gradients.

  Workaround for the persistent-GradientTape hang: rather than opening a tape
  by hand, let `tfp.math.value_and_gradient` differentiate the sampler with
  respect to both parameters.

  Args:
    alpha: concentration parameter of the Gamma distribution.
    beta: rate parameter of the Gamma distribution.

  Returns:
    The result of `tfp.math.value_and_gradient` for the sampler evaluated at
    `(alpha, beta)`.
  """
  def _draw(concentration, rate):
    gamma = tfp.distributions.Gamma(concentration=concentration, rate=rate)
    return gamma.sample()

  return tfp.math.value_and_gradient(_draw, (alpha, beta))
saxenasaurabh commented 3 years ago

Very interesting bug. @allenlavoie any ideas?

cc: @rohan100jain

brianwa84 commented 3 years ago

The TFP sampler has a custom gradient, presumably that's involved somehow.

allenlavoie commented 3 years ago

My guess is that (a) there's a while_loop involved, and (b) the while_loop keeps adding new side outputs to the forward pass when tf.function tries to adapt it to work with the tape, leading to an infinite loop. I have a change pending that will probably make it better; if it doesn't, it will at least start printing warnings.

brianwa84 commented 3 years ago

There is a while loop in the forward computation. However, the gradient of the while loop's outputs with respect to its inputs should be overridden by custom_gradient, so the backward pass should not contain a while loop.

allenlavoie commented 3 years ago

Right now we don't thoroughly apply stop_gradient to everything in the forward pass of custom_gradient, we just override the result. It's possible the while_loop gradient still gets built, e.g. if the custom_gradient backward function captures a tensor from the forward pass that isn't an output (and so for higher-order gradients would either need to throw a "no gradient defined" error or build non-custom gradients for forward pass tensors).

allenlavoie commented 3 years ago

I tried the snippet with https://github.com/tensorflow/tensorflow/commit/07b75ffa453b1ec9189371244d21b6684503340b (the pending change I mentioned), and it no longer goes into an infinite loop. The change should be in tomorrow's nightly.

Brian, is the forward pass of the custom gradient being autodiffed a concern here? Otherwise this can probably be closed.

Darkhunter9 commented 3 years ago

Same here. Specifically, I'm using a custom tfp distribution (the von Mises–Fisher distribution), and my sampling code uses a tensorflow.python.ops.control_flow_ops.while_loop when the dimension is not 3. By comparison, when the dimension is 3 there is no while loop in the sampling, and the gradient tape works fine. Any suggestions? Thank you.

    def _sample_n(self, n, seed=None):
        """Draw `n` samples from the von Mises-Fisher distribution.

        The first coordinate `w` is sampled in closed form when
        `self.__m == 3` and by rejection sampling otherwise. The remaining
        coordinates come from a normalized Gaussian direction `v` scaled by
        `sqrt(1 - w**2)`. The result is then passed through
        `self.__householder_rotation`.

        Args:
          n: number of samples to draw.
          seed: seed forwarded to every random op used below.

        Returns:
          The output of `self.__householder_rotation` applied to the
          concatenation of `w` and the scaled direction.
        """
        normal_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
        is_three_dim = gen_math_ops.equal(self.__m, 3)
        w = control_flow_ops.cond(
            is_three_dim,
            lambda: self.__sample_w3(n, seed),
            lambda: self.__sample_w_rej(n, seed))

        # NOTE(review): the same `seed` feeds the w sampler and random_normal;
        # with a fixed seed these streams may be correlated — consider
        # splitting the seed.
        gaussian = random_ops.random_normal(normal_shape, dtype=self.dtype, seed=seed)
        # The double transpose drops the first entry along the last axis.
        tail = array_ops.transpose(array_ops.transpose(gaussian)[1:])
        v = nn_impl.l2_normalize(tail, axis=-1)

        tail_scale = math_ops.sqrt(1 - w ** 2)
        x = array_ops.concat((w, tail_scale * v), axis=-1)
        return self.__householder_rotation(x)

    def __sample_w3(self, n, seed):
        """Sample `w` in closed form (the m == 3 branch of `_sample_n`).

        Computes ``w = 1 + log(u + (1 - u) * exp(-2 * scale)) / scale`` with
        ``u ~ Uniform``, evaluated as a log-sum-exp for numerical stability.
        The result is also stored on ``self.__w``.
        """
        sample_shape = array_ops.concat(([n], self.batch_shape_tensor()[:-1], [1]), 0)
        uniform = random_ops.random_uniform(sample_shape, dtype=self.dtype, seed=seed)
        log_head = math_ops.log(uniform)
        log_tail = math_ops.log(1 - uniform) - 2 * self.scale
        log_mix = math_ops.reduce_logsumexp([log_head, log_tail], axis=0)
        self.__w = 1 + log_mix / self.scale
        return self.__w

    def __sample_w_rej(self, n, seed):
        """Sample `w` by rejection sampling (the m != 3 branch of `_sample_n`).

        Computes the envelope parameters `b`, `a`, `d` from `self.scale` and
        `self.__mf`, then runs `self.__while_loop`. Stores `b` on
        ``self.__b`` and the loop's `(e, w)` results on ``self.__e`` and
        ``self.__w``.

        Returns:
          The sampled `w`, also stored on ``self.__w``.
        """
        sqrt_term = math_ops.sqrt((4 * (self.scale ** 2)) + (self.__mf - 1) ** 2)
        exact_b = (-2 * self.scale + sqrt_term) / (self.__mf - 1)

        # For large scale the exact expression suffers numerical errors, so it
        # is blended with a Taylor approximation. The blend weight ramps
        # linearly from 0 at scale = 10 to 1 at scale = 11 (a smooth shift).
        approx_b = (self.__mf - 1) / (4 * self.scale)
        blend = gen_math_ops.minimum(gen_math_ops.maximum(0., self.scale - 10), 1.)
        b = approx_b * blend + exact_b * (1 - blend)

        a = (self.__mf - 1 + 2 * self.scale + sqrt_term) / 4
        d = (4 * a * b) / (1 + b) - (self.__mf - 1) * math_ops.log(self.__mf - 1)

        loop_result = self.__while_loop(b, a, d, n, seed)
        self.__b = b
        self.__e, self.__w = loop_result
        return self.__w

    def __while_loop(self, b, a, d, n, seed):
        """Rejection-sample `(e, w)` until every entry has been accepted.

        `b`, `a`, `d` are the envelope parameters computed by the caller. They
        are tiled along a new leading axis of size `n` before the loop runs.

        Returns:
          A tuple `(e, w)` of accepted Beta draws and the matching `w` values.
          Note the order: `e` first, then `w`.
        """
        def __cond(w, e, bool_mask, b, a, d):
            # Keep looping while any entry is still pending (not yet accepted).
            return math_ops.reduce_any(bool_mask)

        def __body(w_, e_, bool_mask, b, a, d):
            # `shape` is assigned later in the enclosing method; the closure
            # reads it when the body runs, after that assignment.
            # NOTE(review): the same `seed` is passed to the Beta sampler and
            # to random_uniform; with a fixed seed these may be correlated.
            e = math_ops.cast(tfd.Beta((self.__mf - 1) / 2, (self.__mf - 1) / 2).sample(
                shape, seed=seed), dtype=self.dtype)

            u = random_ops.random_uniform(shape, dtype=self.dtype, seed=seed)

            # Proposal for w, and the auxiliary t used in the acceptance test.
            w = (1 - (1 + b) * e) / (1 - (1 - b) * e)
            t = (2 * a * b) / (1 - (1 - b) * e)

            # Accept when (mf - 1) * log(t) - t + d > log(u).
            accept = gen_math_ops.greater(((self.__mf - 1) * math_ops.log(t) - t + d), math_ops.log(u))
            reject = gen_math_ops.logical_not(accept)

            # Only entries that are still pending AND accepted this round take
            # the new w/e. They are then marked done: `reject` is False
            # wherever `accept` is True.
            w_ = array_ops.where(gen_math_ops.logical_and(bool_mask, accept), w, w_)
            e_ = array_ops.where(gen_math_ops.logical_and(bool_mask, accept), e, e_)
            bool_mask = array_ops.where(gen_math_ops.logical_and(bool_mask, accept), reject, bool_mask)

            return w_, e_, bool_mask, b, a, d

        shape = array_ops.concat([[n], self.batch_shape_tensor()[:-1], [1]], 0)
        # Tile each parameter along a new leading axis of size n. The
        # comprehension's `e` is local to the comprehension.
        b, a, d = [gen_array_ops.tile(array_ops.expand_dims(e, axis=0), [n] + [1] * len(e.shape)) for e in (b, a, d)]

        # Loop vars: w and e start at zero; bool_mask starts all True
        # (every entry pending); b, a, d are passed through unchanged.
        w, e, bool_mask, b, a, d = control_flow_ops.while_loop(__cond, __body,
                                                               [array_ops.zeros_like(b, dtype=self.dtype),
                                                                array_ops.zeros_like(b, dtype=self.dtype),
                                                                array_ops.ones_like(b, dtypes.bool),
                                                                b, a, d])

        return e, w
allenlavoie commented 3 years ago

@Darkhunter9 I'd suggest using a TF nightly for the fix. It'll be in TF 2.5 when that comes out.

Darkhunter9 commented 3 years ago

@allenlavoie Thanks for the reply! Currently, only CUDA 10.1 is available on the machine where I run training. Is CUDA 11 necessary for TF >= 2.4?