christiaanjs / treeflow

GNU General Public License v3.0
13 stars 4 forks source link

Custom tree gradient working in function mode #32

Closed christiaanjs closed 2 years ago

christiaanjs commented 3 years ago

Tests for tree likelihood gradients currently fail when both the custom gradient and TensorFlow function mode (`tf.function`) are used, e.g.:

../../miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:885: in __call__
    result = self._call(*args, **kwds)
../../miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:929: in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
../../miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:760: in _initialize
    *args, **kwds))
../../miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/function.py:3059: in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
../../miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/function.py:3456: in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
../../miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/function.py:3301: in _create_graph_function
    capture_by_value=self._capture_by_value),
../../miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:1007: in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
../../miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:668: in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<tf.Tensor 'x:0' shape=(4,) dtype=float64>,), kwargs = {}

    def wrapper(*args, **kwargs):
      """Calls a converted version of original_func."""
      # TODO(mdan): Push this block higher in tf.function's call stack.
      try:
        return autograph.converted_call(
            original_func,
            args,
            kwargs,
            options=autograph.ConversionOptions(
                recursive=True,
                optional_features=autograph_options,
                user_requested=True,
            ))
      except Exception as e:  # pylint:disable=broad-except
        if hasattr(e, "ag_error_metadata"):
>         raise e.ag_error_metadata.to_exception(e)
E         ValueError: in user code:
E         
E             /home/cswa648/dev/treeflow/test/test_sequences.py:134 grad  *
E                 return t.gradient(y, x)
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/backprop.py:1090 gradient  **
E                 unconnected_gradients=unconnected_gradients)
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/imperative_grad.py:77 imperative_grad
E                 compat.as_str(unconnected_gradients.value))
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/function.py:845 _backward_function
E                 return self._rewrite_forward_and_call_backward(call_op, *args)
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/function.py:760 _rewrite_forward_and_call_backward
E                 forward_function, backwards_function = self.forward_backward(len(doutputs))
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/function.py:693 forward_backward
E                 forward, backward = self._construct_forward_backward(num_doutputs)
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/function.py:741 _construct_forward_backward
E                 func_graph=backwards_graph)
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
E                 func_outputs = python_func(*func_args, **func_kwargs)
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/eager/function.py:731 _backprop_function
E                 src_graph=self._func_graph)
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/ops/gradients_util.py:682 _GradientsHelper
E                 lambda: grad_fn(op, *out_grads))
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/ops/gradients_util.py:338 _MaybeCompile
E                 return grad_fn()  # Exit early
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/ops/gradients_util.py:682 <lambda>
E                 lambda: grad_fn(op, *out_grads))
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:384 _WhileGrad
E                 stateful_parallelism)
E             /home/cswa648/miniconda3/envs/libsbn/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:705 _create_grad_func
E                 (str(internal_capture), str(external_capture)))
E         
E             ValueError: Tensor Tensor("gradients/map/while_grad/gradients/map/while/IdentityN_grad/GatherV2/indices:0", shape=(4,), dtype=int32, device=/job:localhost/replica:0/task:0/device:CPU:0) which captures tf.Tensor([2 3 0 1], shape=(4,), dtype=int32) is in list of internal_captures but not in internal_capture_to_output.