tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

BasicDecoder used in tf.keras functional model in inference mode causes autograph error #2372

Status: Open. breadbread1984 opened this issue 3 years ago.

breadbread1984 commented 3 years ago

System information

Describe the bug

When building a tf.keras.Model in the functional style, a BasicDecoder running in inference mode triggers AutoGraph and fails with the following error:

Traceback (most recent call last):
  File "models.py", line 207, in <module>
    nmt = NMT(100, 200, 64, infer_params = infer_params);
  File "models.py", line 140, in NMT
    output = Decoder(inputs, targets if is_train == True else None, hidden, cell, decoder_cell, tgt_vocab_size, input_dims, is_train, infer_params);
  File "models.py", line 117, in Decoder
    output, state, lengths = decoder(None, start_tokens = start_tokens, end_token = infer_params['end_token'], initial_state = initial_state);
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 951, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 1090, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 822, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 863, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/autograph/impl/api.py", line 670, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:

    /usr/local/lib/python3.8/dist-packages/tensorflow_addons/seq2seq/decoder.py:163 call  *
        self,
    /usr/local/lib/python3.8/dist-packages/typeguard/__init__.py:262 wrapper  *
        retval = func(*args, **kwargs)
    /usr/local/lib/python3.8/dist-packages/tensorflow_addons/seq2seq/decoder.py:321 dynamic_decode  *
        tf.debugging.assert_greater(
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py:205 wrapper  **
        result = dispatch(wrapper, args, kwargs)
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py:122 dispatch
        result = dispatcher.handle(op, args, kwargs)
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/layers/core.py:1450 handle
        return TFOpLambda(op)(*args, **kwargs)
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/base_layer.py:951 __call__
        return self._functional_construction_call(inputs, args, kwargs,
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/base_layer.py:1090 _functional_construction_call
        outputs = self._keras_tensor_symbolic_call(
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/base_layer.py:822 _keras_tensor_symbolic_call
        return self._infer_output_signature(inputs, args, kwargs, input_masks)
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/base_layer.py:868 _infer_output_signature
        outputs = nest.map_structure(
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/util/nest.py:659 map_structure
        structure[0], [func(*x) for x in entries],
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/util/nest.py:659 <listcomp>
        structure[0], [func(*x) for x in entries],
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/keras_tensor.py:606 keras_tensor_from_tensor
        out = keras_tensor_cls.from_tensor(tensor)
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/keras_tensor.py:205 from_tensor
        type_spec = type_spec_module.type_spec_from_value(tensor)
    /usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/type_spec.py:553 type_spec_from_value
        raise TypeError("Could not build a TypeSpec for %r with type %s" %

    TypeError: Could not build a TypeSpec for <tf.Operation 'tf.debugging.assert_greater/assert_greater/Assert/Assert' type=Assert> with type Operation

Code to reproduce the issue

Running the code here reproduces the error at line 117 of models.py. The code is a refactoring of tensorflow/nmt using the tf.keras APIs.

Other info / logs


guillaumekln commented 3 years ago

This is similar to previous issues related to Keras (see e.g. https://github.com/tensorflow/addons/issues/1898).

As far as I know, the main culprits are the tf.debugging functions, which don't accept Keras tensors. This is apparently by design (see https://github.com/tensorflow/tensorflow/issues/41627). The workaround is probably to wrap these calls in a Lambda layer when they are used in a Keras context.
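A minimal sketch of the Lambda-layer workaround, assuming TensorFlow 2.x with tf.keras. The function name `assert_nonempty_batch` and its batch-size check are hypothetical illustrations, not code from tensorflow_addons. The idea is that a Lambda layer hands its function regular tensors rather than Keras symbolic tensors, so `tf.debugging` asserts can run there instead of being dispatched on a KerasTensor during functional model construction.

```python
import numpy as np
import tensorflow as tf


def assert_nonempty_batch(x):
    """Hypothetical check: fail if the batch dimension is empty.

    Inside a Lambda layer `x` is a plain tensor (symbolic while the
    functional model is being built, eager when the model is called),
    so tf.debugging.assert_greater accepts it.
    """
    check = tf.debugging.assert_greater(
        tf.shape(x)[0], 0, message="batch must not be empty")
    # In graph mode `check` is an Operation; in eager mode the assert has
    # already run and returned None, so only add it as a dependency if present.
    deps = [check] if check is not None else []
    with tf.control_dependencies(deps):
        return tf.identity(x)


# Functional model: the assert lives inside a Lambda layer instead of being
# called directly on the Keras tensor `inputs`.
inputs = tf.keras.Input(shape=(4,))
checked = tf.keras.layers.Lambda(assert_nonempty_batch)(inputs)
outputs = tf.keras.layers.Dense(2)(checked)
model = tf.keras.Model(inputs, outputs)

y = model(np.ones((3, 4), dtype="float32"))
print(tuple(y.shape))  # (3, 2)
```

Calling `model` on an empty batch, e.g. `np.ones((0, 4), dtype="float32")`, makes the assertion fire at call time. Applying this to the BasicDecoder case would mean keeping the library's `tf.debugging` calls out of direct contact with Keras tensors; this sketch only demonstrates the pattern and does not patch tensorflow_addons.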