mmilosav opened this issue 4 years ago
Interesting find!
The stack trace seems to suggest a deadlock of some sort.
KeyboardInterrupt Traceback (most recent call last)
<ipython-input-1-790e15807a67> in <module>()
23
24 with tf.GradientTape(persistent=True) as tape:
---> 25 sample = tfp.distributions.Gamma(concentration=alpha, rate=beta).sample()
26
27 print(sample) # FIXME: never reaches here!!!
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/distribution.py in sample(self, sample_shape, seed, name, **kwargs)
968 samples: a `Tensor` with prepended dimensions `sample_shape`.
969 """
--> 970 return self._call_sample_n(sample_shape, seed, name, **kwargs)
971
972 def _call_log_prob(self, value, name, **kwargs):
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/distribution.py in _call_sample_n(self, sample_shape, seed, name, **kwargs)
946 sample_shape, 'sample_shape')
947 samples = self._sample_n(
--> 948 n, seed=seed() if callable(seed) else seed, **kwargs)
949 batch_event_shape = ps.shape(samples)[1:]
950 final_shape = ps.concat([sample_shape, batch_event_shape], 0)
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/internal/distribution_util.py in _fn(*args, **kwargs)
1362 @functools.wraps(fn)
1363 def _fn(*args, **kwargs):
-> 1364 return fn(*args, **kwargs)
1365
1366 if _fn.__doc__ is None:
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gamma.py in _sample_n(self, n, seed)
242 log_rate=(None if self.log_rate is None else
243 tf.convert_to_tensor(self.log_rate)),
--> 244 seed=seed)
245
246 def _log_prob(self, x, rate=None):
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gamma.py in random_gamma(shape, concentration, rate, log_rate, seed, log_space)
654 shape, concentration, rate=None, log_rate=None, seed=None, log_space=False):
655 return random_gamma_with_runtime(
--> 656 shape, concentration, rate, log_rate, seed, log_space)[0]
657
658
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gamma.py in random_gamma_with_runtime(shape, concentration, rate, log_rate, seed, log_space)
648 seed = samplers.sanitize_seed(seed, salt='random_gamma')
649 return _random_gamma_gradient(
--> 650 total_shape, concentration, rate, log_rate, seed, log_space)
651
652
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/internal/custom_gradient.py in none_wrapper(*args, **kwargs)
99 return val, vjp_bwd_wrapped
100
--> 101 return f_wrapped(*trimmed_args, **kwargs)
102
103 return none_wrapper
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/custom_gradient.py in __call__(self, *a, **k)
259
260 def __call__(self, *a, **k):
--> 261 return self._d(self._f, a, k)
262
263
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/custom_gradient.py in decorated(wrapped, args, kwargs)
213
214 if context.executing_eagerly():
--> 215 return _eager_mode_decorator(wrapped, args, kwargs)
216 else:
217 return _graph_mode_decorator(wrapped, args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/custom_gradient.py in _eager_mode_decorator(f, args, kwargs)
436 """Implement custom gradient decorator for eager mode."""
437 with tape_lib.VariableWatcher() as variable_watcher:
--> 438 result, grad_fn = f(*args, **kwargs)
439 args = nest.flatten(args)
440 all_inputs = list(args) + list(kwargs.values())
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/internal/custom_gradient.py in f_wrapped(*args, **kwargs)
87 reconstruct_args.append(args[0])
88 args = args[1:]
---> 89 val, aux = vjp_fwd(*reconstruct_args, **kwargs)
90
91 def vjp_bwd_wrapped(*g):
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gamma.py in _random_gamma_fwd(shape, concentration, rate, log_rate, seed, log_space)
535 """Compute output, aux (collaborates with _random_gamma_bwd)."""
536 samples, impl = _random_gamma_no_gradient(
--> 537 shape, concentration, rate, log_rate, seed, log_space)
538 return ((samples, impl),
539 (samples, shape, concentration, rate, log_rate, log_space))
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/internal/implementation_selection.py in f_wrapped(*args, **kwargs)
81 try:
82 tf.config.run_functions_eagerly(False)
---> 83 return f(*args, **kwargs)
84 finally:
85 tf.config.run_functions_eagerly(orig)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
826 tracing_count = self.experimental_get_tracing_count()
827 with trace.Trace(self._name) as tm:
--> 828 result = self._call(*args, **kwds)
829 compiler = "xla" if self._experimental_compile else "nonXla"
830 new_tracing_count = self.experimental_get_tracing_count()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
860 # In this case we have not created variables on the first call. So we can
861 # run the first trace but we should fail if variables are created.
--> 862 results = self._stateful_fn(*args, **kwds)
863 if self._created_variables:
864 raise ValueError("Creating variables on a non-first call to a function"
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
2947 filtered_flat_args) = self._maybe_define_function(args, kwargs)
2948 return graph_function._call_flat(
-> 2949 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
2950
2951 @property
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1930 possible_gradient_type,
1931 executing_eagerly)
-> 1932 forward_function, args_with_tangents = forward_backward.forward()
1933 if executing_eagerly:
1934 flat_outputs = forward_function.call(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in forward(self)
1453 """Builds or retrieves a forward function for this call."""
1454 forward_function = self._functions.forward(
-> 1455 self._inference_args, self._input_tangents)
1456 return forward_function, self._inference_args + self._input_tangents
1457
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in forward(self, inference_args, input_tangents)
1205 (self._forward, self._forward_graph, self._backward,
1206 self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
-> 1207 self._forward_and_backward_functions(inference_args, input_tangents))
1208 return self._forward
1209
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _forward_and_backward_functions(self, inference_args, input_tangents)
1405 outputs = list(self._func_graph.outputs)
1406 self._build_functions_for_outputs(
-> 1407 outputs, inference_args, input_tangents)
1408 (forward_function, forward_graph,
1409 backward_function, output_indices, num_output_tangents) = (
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _build_functions_for_outputs(self, outputs, inference_args, input_tangents)
908 self._func_graph.inputs,
909 grad_ys=gradients_wrt_outputs,
--> 910 src_graph=self._func_graph)
911
912 if input_tangents:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
534 xs_set = object_identity.ObjectIdentitySet(xs)
535 grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
--> 536 gradient_uid)
537
538 # The approach we take here is as follows: Create a list of all ops in the
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_util.py in _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid)
223 grad_y.dense_shape)))
224 else:
--> 225 new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))
226
227 return new_grad_ys
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
199 """Call target, and fall back on dispatchers if there is a TypeError."""
200 try:
--> 201 return target(*args, **kwargs)
202 except (TypeError, ValueError):
203 # Note: convert_to_eager_tensor currently raises a ValueError, not a
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in identity(input, name)
285 # variables. Variables have correct handle data when graph building.
286 input = ops.convert_to_tensor(input)
--> 287 ret = gen_array_ops.identity(input, name=name)
288 # Propagate handle data for happier shape inference for resource variables.
289 if hasattr(input, "_handle_data"):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in identity(input, name)
3941 # Add nodes to the TensorFlow graph.
3942 _, _, _op, _outputs = _op_def_library._apply_op_helper(
-> 3943 "Identity", input=input, name=name)
3944 _result = _outputs[:]
3945 if _execute.must_record_gradient():
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
748 op = g._create_op_internal(op_type_name, inputs, dtypes=None,
749 name=scope, input_types=input_types,
--> 750 attrs=attr_protos, op_def=op_def)
751
752 # `outputs` is returned as a separate return value so that the output
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in _create_op_internal(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device)
590 return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
591 op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
--> 592 compute_device)
593
594 def capture(self, tensor, name=None, shape=None):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_op_internal(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device)
3525 # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
3526 # Session.run call cannot occur between creating and mutating the op.
-> 3527 with self._mutation_lock():
3528 ret = Operation(
3529 node_def,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _mutation_lock(self)
5186 See the comment for self._group_lock for more info.
5187 """
-> 5188 return self._group_lock.group(_MUTATION_LOCK_GROUP)
5189
5190 def _session_run_lock(self):
KeyboardInterrupt:
@saxenasaurabh do you think you could route this appropriately?
It seems to reproduce with tf.function as well, e.g. sample = tf.function(tfp.distributions.Gamma(concentration=alpha, rate=beta).sample)()
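Spelled out as a standalone snippet, that reproduction looks roughly like this (the actual alpha and beta are not shown in the thread, so the values below are placeholders):

import tensorflow as tf
import tensorflow_probability as tfp

alpha = tf.constant(2.0)  # placeholder values; the original alpha/beta are not shown
beta = tf.constant(3.0)

# wrap the bound sample method in tf.function, as suggested above
sample = tf.function(tfp.distributions.Gamma(concentration=alpha, rate=beta).sample)()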
In the meantime if you need a workaround to avoid retracing the gradient, you can write:
@tf.function
def gamma_sample_and_grad(alpha, beta):
  return tfp.math.value_and_gradient(
      lambda a, b: tfp.distributions.Gamma(concentration=a, rate=b).sample(),
      (alpha, beta))
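A hypothetical call of that workaround, again with placeholder alpha and beta (tfp.math.value_and_gradient returns the sampled value together with gradients matching the structure of the arguments):

alpha = tf.constant(2.0)  # placeholder values
beta = tf.constant(3.0)
sample, (grad_alpha, grad_beta) = gamma_sample_and_grad(alpha, beta)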
Very interesting bug. @allenlavoie any ideas?
cc: @rohan100jain
The TFP sampler has a custom gradient, presumably that's involved somehow.
My guess is that (a) there's a while_loop involved, and (b) the while_loop keeps adding new side outputs to the forward pass when tf.function tries to adapt it to work with the tape, leading to an infinite loop. I have a change pending that will probably make it better, but if not will at least start printing warnings.
There is a while loop in the forward computation, but the backprop of the while loop's outputs with respect to its inputs should be overridden by custom_gradient, so the backward pass should not contain a while loop.
Right now we don't thoroughly apply stop_gradient to everything in the forward pass of custom_gradient; we just override the result. It's possible the while_loop gradient still gets built, e.g. if the custom_gradient backward function captures a tensor from the forward pass that isn't an output (in which case higher-order gradients would either need to throw a "no gradient defined" error or build non-custom gradients for forward-pass tensors).
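As a toy sketch of that pattern (a hypothetical illustration, not TFP's actual sampler): the backward function below captures t, a while_loop intermediate from the forward pass that is not an output of the custom_gradient, which is exactly the kind of capture that can pull the while_loop back into gradient construction.

import tensorflow as tf

@tf.custom_gradient
def toy_sample(x):
  # The forward pass contains a while_loop; `t` is an intermediate of the
  # forward computation, not an output of the custom_gradient.
  i0 = tf.constant(0)
  _, t = tf.while_loop(lambda i, v: i < 5,
                       lambda i, v: (i + 1, v + x * x),
                       (i0, tf.zeros_like(x)))
  y = tf.stop_gradient(t)  # the returned value's gradient is overridden below

  def grad(dy):
    # `t` is captured here, so gradient machinery may still need to
    # differentiate the ops that produced it (the while_loop above) rather
    # than relying only on this override.
    return dy * t

  return y, grad

Whether this toy actually triggers the hang depends on the TF build; it is only meant to show the capture pattern being described.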
I tried the snippet with https://github.com/tensorflow/tensorflow/commit/07b75ffa453b1ec9189371244d21b6684503340b (the pending change I mentioned), and it no longer goes into an infinite loop. The change should be in tomorrow's nightly.
Brian, is the forward pass of the custom gradient being autodiffed a concern here? Otherwise this can probably be closed.
Same here. Specifically, I'm using a customized TFP distribution (von Mises-Fisher), and my sampling code contains a tensorflow.python.ops.control_flow_ops.while_loop when the dimension is not 3. By comparison, when the dimension is 3 there is no while loop in the sampling and the gradient tape works fine. Any suggestions? Thank you.
def _sample_n(self, n, seed=None):
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    # m == 3 has a closed-form sampler for w; other dimensions use the
    # rejection-sampling while loop in __sample_w_rej.
    w = control_flow_ops.cond(gen_math_ops.equal(self.__m, 3),
                              lambda: self.__sample_w3(n, seed),
                              lambda: self.__sample_w_rej(n, seed))
    v = nn_impl.l2_normalize(array_ops.transpose(
        array_ops.transpose(random_ops.random_normal(shape, dtype=self.dtype, seed=seed))[1:]), axis=-1)
    x = array_ops.concat((w, math_ops.sqrt(1 - w ** 2) * v), axis=-1)
    z = self.__householder_rotation(x)
    return z

def __sample_w3(self, n, seed):
    shape = array_ops.concat(([n], self.batch_shape_tensor()[:-1], [1]), 0)
    u = random_ops.random_uniform(shape, dtype=self.dtype, seed=seed)
    self.__w = 1 + math_ops.reduce_logsumexp([math_ops.log(u), math_ops.log(1 - u) - 2 * self.scale], axis=0) / self.scale
    return self.__w

def __sample_w_rej(self, n, seed):
    c = math_ops.sqrt((4 * (self.scale ** 2)) + (self.__mf - 1) ** 2)
    b_true = (-2 * self.scale + c) / (self.__mf - 1)
    # using Taylor approximation with a smooth shift from 10 < scale < 11
    # to avoid numerical errors for large scale
    b_app = (self.__mf - 1) / (4 * self.scale)
    s = gen_math_ops.minimum(gen_math_ops.maximum(0., self.scale - 10), 1.)
    b = b_app * s + b_true * (1 - s)
    a = (self.__mf - 1 + 2 * self.scale + c) / 4
    d = (4 * a * b) / (1 + b) - (self.__mf - 1) * math_ops.log(self.__mf - 1)
    self.__b, (self.__e, self.__w) = b, self.__while_loop(b, a, d, n, seed)
    return self.__w

def __while_loop(self, b, a, d, n, seed):

    def __cond(w, e, bool_mask, b, a, d):
        return math_ops.reduce_any(bool_mask)

    def __body(w_, e_, bool_mask, b, a, d):
        e = math_ops.cast(tfd.Beta((self.__mf - 1) / 2, (self.__mf - 1) / 2).sample(
            shape, seed=seed), dtype=self.dtype)
        u = random_ops.random_uniform(shape, dtype=self.dtype, seed=seed)
        w = (1 - (1 + b) * e) / (1 - (1 - b) * e)
        t = (2 * a * b) / (1 - (1 - b) * e)
        accept = gen_math_ops.greater(((self.__mf - 1) * math_ops.log(t) - t + d), math_ops.log(u))
        reject = gen_math_ops.logical_not(accept)
        w_ = array_ops.where(gen_math_ops.logical_and(bool_mask, accept), w, w_)
        e_ = array_ops.where(gen_math_ops.logical_and(bool_mask, accept), e, e_)
        bool_mask = array_ops.where(gen_math_ops.logical_and(bool_mask, accept), reject, bool_mask)
        return w_, e_, bool_mask, b, a, d

    shape = array_ops.concat([[n], self.batch_shape_tensor()[:-1], [1]], 0)
    b, a, d = [gen_array_ops.tile(array_ops.expand_dims(e, axis=0), [n] + [1] * len(e.shape)) for e in (b, a, d)]
    # rejection sampling: keep looping until every element has been accepted
    w, e, bool_mask, b, a, d = control_flow_ops.while_loop(__cond, __body,
                                                           [array_ops.zeros_like(b, dtype=self.dtype),
                                                            array_ops.zeros_like(b, dtype=self.dtype),
                                                            array_ops.ones_like(b, dtypes.bool),
                                                            b, a, d])
    return e, w
@Darkhunter9 I'd suggest using a TF nightly for the fix. It'll be in TF 2.5 when that comes out.
@allenlavoie Thanks for the reply! Only CUDA 10.1 is available on the machine where I run training. Is CUDA 11 required for TF >= 2.4?
Drawing samples from an instance of tfp.distributions.Distribution inside a persistent=True GradientTape context seems to hang. This seems to happen for TF in [2.3.0, 2.3.1, nightly], TFP in [0.11.0, 0.11.1, nightly], and Python 3.7, on Colab and macOS (I have not tested all release and platform combinations).
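For reference, a minimal reproduction along the lines of the traceback above (the concrete alpha and beta are not shown in the thread, so placeholder values are used):

import tensorflow as tf
import tensorflow_probability as tfp

alpha = tf.constant(2.0)  # placeholder values
beta = tf.constant(3.0)

with tf.GradientTape(persistent=True) as tape:
  sample = tfp.distributions.Gamma(concentration=alpha, rate=beta).sample()

print(sample)  # FIXME: never reaches here!!!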