tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0
1.69k stars 611 forks source link

Saving models with a BasicDecoder Layer? #2432

Open klaimans opened 3 years ago

klaimans commented 3 years ago

System information

Describe the bug

When trying to save a subclassed model that contains a BasicDecoder layer, the model cannot be saved with either the model's save method or tf.saved_model.save.

Code to reproduce the issue

Using the colab tutorial

https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb

If one tries to save the decoder, it doesn't work.

Our subclassed model is more involved but we expect that if we cannot save the decoder already in this example it cannot work for us either.

Provide a reproducible test case that is the bare minimum necessary to generate the problem.

See the colab above and add the following line after the training cell:

decoder.save("./test")


TypeError Traceback (most recent call last)

in () ----> 1 decoder.save("./test") 25 frames /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces) 2000 # pylint: enable=line-too-long 2001 save.save_model(self, filepath, overwrite, include_optimizer, save_format, -> 2002 signatures, options, save_traces) 2003 2004 def save_weights(self, /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces) 155 else: 156 saved_model_save.save(model, filepath, overwrite, include_optimizer, --> 157 signatures, options, save_traces) 158 159 /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces) 87 with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access 88 with utils.keras_option_scope(save_traces): ---> 89 save_lib.save(model, filepath, signatures, options) 90 91 if not include_optimizer: /usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options) 1031 1032 _, exported_graph, object_saver, asset_info = _build_meta_graph( -> 1033 obj, signatures, options, meta_graph_def) 1034 saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION 1035 /usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def) 1196 1197 with save_context.save_context(options): -> 1198 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def) /usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def) 1131 if signatures is None: 1132 signatures = 
signature_serialization.find_function_to_export( -> 1133 checkpoint_graph_view) 1134 1135 signatures, wrapped_functions = ( /usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view) 73 # If the user did not specify signatures, check the root object for a function 74 # that can be made into a signature. ---> 75 functions = saveable_view.list_functions(saveable_view.root) 76 signature = functions.get(DEFAULT_SIGNATURE_ATTR, None) 77 if signature is not None: /usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj, extra_functions) 149 if obj_functions is None: 150 obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access --> 151 self._serialization_cache) 152 self._functions[obj] = obj_functions 153 if extra_functions: /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache) 2611 self.predict_function = None 2612 functions = super( -> 2613 Model, self)._list_functions_for_serialization(serialization_cache) 2614 self.train_function = train_function 2615 self.test_function = test_function /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache) 3085 def _list_functions_for_serialization(self, serialization_cache): 3086 return (self._trackable_saved_model_saver -> 3087 .list_functions_for_serialization(serialization_cache)) 3088 3089 def __getstate__(self): /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache) 92 return {} 93 ---> 94 fns = self.functions_to_serialize(serialization_cache) 95 96 # The parent AutoTrackable class saves all user-defined tf.functions, and 
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache) 77 def functions_to_serialize(self, serialization_cache): 78 return (self._get_serialized_attributes( ---> 79 serialization_cache).functions_to_serialize) 80 81 def _get_serialized_attributes(self, serialization_cache): /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache) 93 94 object_dict, function_dict = self._get_serialized_attributes_internal( ---> 95 serialization_cache) 96 97 serialized_attr.set_and_validate_objects(object_dict) /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache) 49 # cache (i.e. this is the root level object). 50 if len(serialization_cache[constants.KERAS_CACHE_KEY]) == 1: ---> 51 default_signature = save_impl.default_save_signature(self.obj) 52 53 # Other than the default signature function, all other attributes match with /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in default_save_signature(layer) 203 original_losses = _reset_layer_losses(layer) 204 fn = saving_utils.trace_model_call(layer) --> 205 fn.get_concrete_function() 206 _restore_layer_losses(original_losses) 207 return fn /usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs) 1297 ValueError: if this object has not yet been called on concrete values. 
1298 """ -> 1299 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs) 1300 concrete._garbage_collector.release() # pylint: disable=protected-access 1301 return concrete /usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs) 1203 if self._stateful_fn is None: 1204 initializers = [] -> 1205 self._initialize(args, kwargs, add_initializers_to=initializers) 1206 self._initialize_uninitialized_variables(initializers) 1207 /usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to) 724 self._concrete_stateful_fn = ( 725 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access --> 726 *args, **kwds)) 727 728 def invalid_creator_scope(*unused_args, **unused_kwds): /usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs) 2967 args, kwargs = None, None 2968 with self._lock: -> 2969 graph_function, _ = self._maybe_define_function(args, kwargs) 2970 return graph_function 2971 /usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs) 3359 3360 self._function_cache.missed.add(call_context_key) -> 3361 graph_function = self._create_graph_function(args, kwargs) 3362 self._function_cache.primary[cache_key] = graph_function 3363 /usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes) 3204 arg_names=arg_names, 3205 override_flat_arg_shapes=override_flat_arg_shapes, -> 3206 capture_by_value=self._capture_by_value), 3207 self._function_attributes, 3208 function_spec=self.function_spec, /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, 
kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes) 988 _, original_func = tf_decorator.unwrap(python_func) 989 --> 990 func_outputs = python_func(*func_args, **func_kwargs) 991 992 # invariant: `func_outputs` contains only Tensors, CompositeTensors, /usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds) 632 xla_context.Exit() 633 else: --> 634 out = weak_wrapped_fn().__wrapped__(*args, **kwds) 635 return out 636 /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saving_utils.py in _wrapped_model(*args) 133 with base_layer_utils.call_context().enter( 134 model, inputs=inputs, build_graph=False, training=False, saving=True): --> 135 outputs = model(inputs, training=False) 136 137 # Outputs always has to be a flat dict. /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs) 1010 with autocast_variable.enable_auto_cast_variables( 1011 self._compute_dtype_object): -> 1012 outputs = call_fn(inputs, *args, **kwargs) 1013 1014 if self._activity_regularizer: /usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs) 618 def wrapper(*args, **kwargs): 619 with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED): --> 620 return func(*args, **kwargs) 621 622 if inspect.isfunction(func) or inspect.ismethod(func): TypeError: call() missing 1 required positional argument: 'initial_state' Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
bhack commented 3 years ago

Check https://stackoverflow.com/questions/60930158/tensorflow-saving-subclass-model-which-has-multiple-arguments-to-call-method

klaimans commented 3 years ago

Hi @bhack, thank you very much for the very quick reply. Indeed, that gets us past that error, but if you try the same for the decoder rather than the encoder (which is what we are trying to do), it still fails. It is also unclear how to specify the signature for the AttentionWrapperState input to the decoder, which is a NamedTuple and not a simple list.

I would appreciate any additional suggestions you might have. What we see in our subclassed model is that, while we can create a concrete function for our predict step, the save_model function still fails: it tries to trace the model with wrongly shaped tensors, even though our specific predict_step works without problems and we define the signatures explicitly.

Thank you in advance!

guillaumekln commented 3 years ago

Did you manage to solve the issue? Did you try saving the decoder in a checkpoint instead?

If you need to save a self-contained graph, I think you should define a separate tf.function wrapping the decoder with a simpler signature. Then you could export this function with tf.saved_model.save.

klaimans commented 3 years ago

Hi @guillaumekln, sorry for the belated reply. Unfortunately, we are still struggling with this. Using a checkpoint works without a problem and, as I mentioned, we can even create a concrete function which works perfectly. However, sending the model to tf.saved_model.save still fails. We have tried different signatures, and also overriding some of the built-in signatures in tfa.Sampler and tfa.BaseDecoder, but we couldn't resolve it yet.

We ended up solving this by using the logic for saving estimator-based models. It is ugly, but at least we got it to work, and we could load the model and use it for inference.

We would really appreciate any additional suggestions.

guillaumekln commented 3 years ago

Here's an example that exports a BasicDecoder to a SavedModel:

import tempfile

import tensorflow as tf
import tensorflow_addons as tfa

class MyModel(tf.keras.layers.Layer):
    """Keras layer wrapping a ``tfa.seq2seq.BasicDecoder`` for export.

    The layer owns an embedding, an LSTM cell, a
    ``GreedyEmbeddingSampler`` built on that embedding, and a dense output
    projection, all of which are handed to a ``BasicDecoder``. Decoding is
    exposed through :meth:`run`, a ``tf.function`` with a fixed input
    signature taking only two plain int32 tensors, so it can be passed as a
    signature to ``tf.saved_model.save``.
    """

    def __init__(self, vocab_size, num_units):
        """Create the sub-layers and assemble the decoder.

        Args:
            vocab_size: First argument of the ``Embedding`` layer and number
                of units of the ``Dense`` output layer.
            num_units: Second argument of the ``Embedding`` layer and number
                of units of the ``LSTMCell``.
        """
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, num_units)
        self.cell = tf.keras.layers.LSTMCell(num_units)
        # The sampler is given the embedding layer itself.
        self.sampler = tfa.seq2seq.GreedyEmbeddingSampler(self.embedding)
        self.output_layer = tf.keras.layers.Dense(vocab_size)
        self.decoder = tfa.seq2seq.BasicDecoder(
            self.cell, self.sampler, self.output_layer, maximum_iterations=10
        )

    @tf.function(
        input_signature=(
            tf.TensorSpec(shape=[None], dtype=tf.int32),
            tf.TensorSpec(shape=[], dtype=tf.int32),
        )
    )
    def run(self, start_tokens, end_token):
        """Run the decoder from ``start_tokens`` and return its output.

        Args:
            start_tokens: int32 tensor of rank 1 (see the input signature);
                its length is used as the batch size.
            end_token: Scalar int32 tensor forwarded to the decoder.

        Returns:
            The first of the three values returned by the decoder call.
        """
        # The initial state is created here from the cell, so callers do not
        # have to supply any state structure through the signature.
        zero_state = self.cell.get_initial_state(
            batch_size=tf.shape(start_tokens)[0], dtype=tf.float32
        )
        decoded, _final_state, _lengths = self.decoder(
            None,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=zero_state,
        )
        return decoded

# Round trip: export the model to a temporary SavedModel directory, drop the
# Python object, reload from disk and call the exported signature.
with tempfile.TemporaryDirectory() as export_dir:
    model = MyModel(512, 64)
    tf.saved_model.save(
        model,
        export_dir,
        signatures=model.run.get_concrete_function(),
    )
    # Delete the original object so the call below can only go through the
    # artifacts loaded from ``export_dir``.
    del model

    imported = tf.saved_model.load(export_dir)
    function = imported.signatures["serving_default"]

    # The signature is called with keyword arguments named after the
    # parameters of ``MyModel.run``.
    inputs = {
        "start_tokens": tf.constant([1, 2, 3], dtype=tf.int32),
        "end_token": tf.constant(5, dtype=tf.int32),
    }
    output = function(**inputs)

    print(output)