huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Can't save model in SavedModel format when fine-tuning BERT in TensorFlow 2 #13742

Closed SysuJayce closed 2 years ago

SysuJayce commented 2 years ago

Environment info

Who can help

@LysandreJik @Rocketknight1

Information

Model I am using (Bert, XLNet ...): roberta

The problem arises when using:

The task I am working on is:

To reproduce

Steps to reproduce the behavior:

import tensorflow as tf
from transformers import TFBertPreTrainedModel
from transformers.modeling_tf_utils import get_initializer
from transformers.models.bert.modeling_tf_bert import TFBertMainLayer


class TFBertForMultilabelClassification(TFBertPreTrainedModel):

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.bert = TFBertMainLayer(config, name='bert')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        # Sigmoid activation: one independent probability per label (multi-label classification)
        self.classifier = tf.keras.layers.Dense(config.num_labels,
                                                kernel_initializer=get_initializer(config.initializer_range),
                                                name='classifier',
                                                activation='sigmoid')

    def call(self, inputs, **kwargs):
        outputs = self.bert(inputs, **kwargs)
        pooled_output = outputs[1]  # pooled [CLS] representation
        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
        logits = self.classifier(pooled_output)  # already sigmoid-activated, i.e. probabilities
        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        return outputs  # logits, (hidden_states), (attentions)


model = TFBertForMultilabelClassification.from_pretrained("bert-base-uncased")
model.save("/tmp/model")  # fails with the TypeError shown in the error messages

Error messages:

Some layers from the model checkpoint at bert-base-uncased were not used when initializing TFBertForMultilabelClassification: ['nsp___cls', 'mlm___cls']
- This IS expected if you are initializing TFBertForMultilabelClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForMultilabelClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFBertForMultilabelClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['dropout_2326', 'classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-222-0c8afc02744c> in <module>
     20 
     21 model = TFBertForMultilabelClassification.from_pretrained("bert-base-uncased")
---> 22 model.save("test")

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
   1994     """
   1995     # pylint: enable=line-too-long
-> 1996     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
   1997                     signatures, options, save_traces)
   1998 

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
    154         model, filepath, overwrite, include_optimizer)
    155   else:
--> 156     saved_model_save.save(model, filepath, overwrite, include_optimizer,
    157                           signatures, options, save_traces)
    158 

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)
     87     with distribution_strategy_context._get_default_replica_context():  # pylint: disable=protected-access
     88       with utils.keras_option_scope(save_traces):
---> 89         save_lib.save(model, filepath, signatures, options)
     90 
     91   if not include_optimizer:

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
   1030   meta_graph_def = saved_model.meta_graphs.add()
   1031 
-> 1032   _, exported_graph, object_saver, asset_info = _build_meta_graph(
   1033       obj, signatures, options, meta_graph_def)
   1034   saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
   1196 
   1197   with save_context.save_context(options):
-> 1198     return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
   1145   # Note we run this twice since, while constructing the view the first time
   1146   # there can be side effects of creating variables.
-> 1147   _ = _SaveableView(checkpoint_graph_view, options)
   1148   saveable_view = _SaveableView(checkpoint_graph_view, options,
   1149                                 wrapped_functions)

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in __init__(self, checkpoint_view, options, wrapped_functions)
    223           #  variables on first run.
    224           concrete_functions = (
--> 225               function._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
    226         else:
    227           concrete_functions = [function]

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _list_all_concrete_functions_for_serialization(self)
   1160       A list of instances of `ConcreteFunction`.
   1161     """
-> 1162     concrete_functions = self._list_all_concrete_functions()
   1163     seen_signatures = []
   1164     for concrete_function in concrete_functions:

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _list_all_concrete_functions(self)
   1142     """Returns all concrete functions."""
   1143     if self.input_signature is not None:
-> 1144       self.get_concrete_function()
   1145     concrete_functions = []
   1146     # pylint: disable=protected-access

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1297       ValueError: if this object has not yet been called on concrete values.
   1298     """
-> 1299     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1300     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1301     return concrete

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1203       if self._stateful_fn is None:
   1204         initializers = []
-> 1205         self._initialize(args, kwargs, add_initializers_to=initializers)
   1206         self._initialize_uninitialized_variables(initializers)
   1207 

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    723     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    724     self._concrete_stateful_fn = (
--> 725         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    726             *args, **kwds))
    727 

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2967       args, kwargs = None, None
   2968     with self._lock:
-> 2969       graph_function, _ = self._maybe_define_function(args, kwargs)
   2970     return graph_function
   2971 

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3359 
   3360           self._function_cache.missed.add(call_context_key)
-> 3361           graph_function = self._create_graph_function(args, kwargs)
   3362           self._function_cache.primary[cache_key] = graph_function
   3363 

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3194     arg_names = base_arg_names + missing_arg_names
   3195     graph_function = ConcreteFunction(
-> 3196         func_graph_module.func_graph_from_py_func(
   3197             self._name,
   3198             self._python_function,

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    988         _, original_func = tf_decorator.unwrap(python_func)
    989 
--> 990       func_outputs = python_func(*func_args, **func_kwargs)
    991 
    992       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    632             xla_context.Exit()
    633         else:
--> 634           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    635         return out
    636 

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/eager/function.py in bound_method_wrapper(*args, **kwargs)
   3885     # However, the replacer is still responsible for attaching self properly.
   3886     # TODO(mdan): Is it possible to do it here instead?
-> 3887     return wrapped_fn(*args, **kwargs)
   3888   weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)
   3889 

/usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e, "ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise

TypeError: in user code:

    /usr/local/Caskroom/miniconda/base/envs/tms/lib/python3.8/site-packages/transformers/modeling_tf_utils.py:682 serving  *
        return self.serving_output(output)

    TypeError: tf__serving_output() takes 1 positional argument but 2 were given

Expected behavior

The model can be saved in SavedModel format without errors.
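The final TypeError is an ordinary Python calling-convention error. The traceback suggests that the `serving_output` the custom model inherits accepts only one positional argument, so when `serving` calls `self.serving_output(output)`, Python supplies two (`self` and `output`); the `tf__` prefix is, as far as I know, added by AutoGraph when it converts the method. The sketch below reproduces that mechanism in plain Python. `BaseModel`, `CustomModelWithoutOverride`, `CustomModelWithOverride` and `try_serving` are hypothetical stand-ins, not transformers classes.

```python
class BaseModel:
    """Hypothetical stand-in for a library base class (not the real TFPreTrainedModel)."""

    def call(self, inputs):
        # Placeholder forward pass: wrap the inputs in a dict.
        return {"logits": inputs}

    def serving(self, inputs):
        output = self.call(inputs)
        # Bound-method call: Python passes `self` implicitly, so serving_output
        # receives two positional arguments, (self, output).
        return self.serving_output(output)

    def serving_output(output):  # stub declared without `self`; subclasses are meant to override it
        raise NotImplementedError


class CustomModelWithoutOverride(BaseModel):
    """Like the custom model in this report: it inherits the stub unchanged."""


class CustomModelWithOverride(BaseModel):
    """Overrides serving_output with the correct (self, output) signature."""

    def serving_output(self, output):
        return output


def try_serving(model, inputs):
    """Return the serving result, or the TypeError message if the call fails."""
    try:
        return model.serving(inputs)
    except TypeError as exc:
        return str(exc)


print(try_serving(CustomModelWithoutOverride(), [1, 2]))  # message ends with "takes 1 positional argument but 2 were given"
print(try_serving(CustomModelWithOverride(), [1, 2]))     # {'logits': [1, 2]}
```

If this reading is right, a plausible workaround for the real model is to give TFBertForMultilabelClassification its own `serving_output(self, output)` method that returns the tensors you want exported. That is an inference from the traceback, not a fix confirmed in this thread.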

LysandreJik commented 2 years ago

cc @Rocketknight1

The models from transformers should be saved using the save_pretrained method. For TensorFlow models, you can obtain the result as a SavedModel by using the saved_model keyword argument:

model.save_pretrained("/tmp/model", saved_model=True)
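For reference, in the transformers 4.x source I am familiar with, `save_pretrained(..., saved_model=True)` writes the regular checkpoint to the target directory and the SavedModel to a `saved_model/<version>` subdirectory, with `version` defaulting to 1. The helper below, `expected_saved_model_dir`, is hypothetical and only reproduces that assumed path; check `modeling_tf_utils.py` in your installed version before relying on it.

```python
import os


def expected_saved_model_dir(save_directory: str, version: int = 1) -> str:
    """Return where save_pretrained(..., saved_model=True) is assumed to write the SavedModel.

    Assumption: the transformers 4.x layout <save_directory>/saved_model/<version>.
    """
    return os.path.join(save_directory, "saved_model", str(version))


print(expected_saved_model_dir("/tmp/model"))  # /tmp/model/saved_model/1 (POSIX paths)
```

Because that export also passes the model's `serving` method as the signature (again, as far as I know for 4.x), a custom subclass without its own `serving_output` may hit the same TypeError that `model.save()` raised.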
SysuJayce commented 2 years ago

@LysandreJik Thanks for your reply. I still have a couple of questions; please take a look.

  1. What is the difference between model.save_pretrained(path, saved_model=True) and model.save(path)? If I can save the model with the latter, should I switch to the former?

  2. Also, I noticed that a warning appears when saving the model. Can it be safely ignored? What can I do to fix it?

    WARNING:absl:Found untraced functions such as embeddings_layer_call_and_return_conditional_losses, embeddings_layer_call_fn, encoder_layer_call_and_return_conditional_losses, encoder_layer_call_fn, pooler_layer_call_and_return_conditional_losses while saving (showing 5 of 1055). These functions will not be directly callable after loading.
Rocketknight1 commented 2 years ago

Hi @SysuJayce, yes, you can ignore that warning. It can also pop up when saving a large model with model.save; it's just telling you that the model has some methods that weren't saved/traced, which is normal. Don't worry about fixing it.

Also, in general we don't support model.save because our 'standard' way of saving/loading models is to use save_pretrained and then from_pretrained to load it again, like model = TFBertForMultilabelClassification.from_pretrained("/tmp/model").

The reasons we do this instead of using SavedModel are a bit long and confusing. The key issue is that SavedModel saves the model graph but not necessarily all the code, so you won't necessarily have all the capabilities of the model when you reload the SavedModel file; this is basically what the warning is telling you. If you just want to call the model with new data, SavedModel should work fine for you, but if you want to load the model perfectly and have it work just like it did originally, pass the save directory to the from_pretrained method.

SysuJayce commented 2 years ago

Hello @Rocketknight1,

In my situation, I'd like to train with Hugging Face Transformers and serve with TensorFlow Serving. That's why I want to save and load the trained model with model.save().

Now I know that the preferred way to save and load a trained model is save_pretrained() and from_pretrained(), but could we also use TensorFlow's native model.save() and tf.keras.models.load_model()?

Alternatively, could I save the trained model with model.save_pretrained(path, saved_model=True) and serve that with TensorFlow Serving? What would you advise?

Thanks for your reply.
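TensorFlow Serving loads a model from a base path whose integer-named subdirectories are versions, each containing a `saved_model.pb` (or the text form `saved_model.pbtxt`); by default it serves the highest version. If `save_pretrained(path, saved_model=True)` writes to `path/saved_model/1` (my understanding for 4.x; verify locally), then `path/saved_model` would be the directory to pass as `--model_base_path`. The hypothetical helper `servable_versions` below is a cheap sanity check that lists the version directories that look loadable before you start the server.

```python
import os
from typing import List


def servable_versions(model_base_path: str) -> List[int]:
    """List integer-named version directories that contain a SavedModel proto.

    Non-numeric directories and directories without saved_model.pb(txt)
    are skipped. Versions are returned in ascending numeric order.
    """
    versions = []
    for name in os.listdir(model_base_path):
        path = os.path.join(model_base_path, name)
        if not (name.isdigit() and os.path.isdir(path)):
            continue
        # A SavedModel may be stored as a binary or a text protobuf.
        if any(os.path.isfile(os.path.join(path, f))
               for f in ("saved_model.pb", "saved_model.pbtxt")):
            versions.append(int(name))
    return sorted(versions)
```

Sorting the integers rather than the directory names matters: TF Serving treats `10` as newer than `2`, which a plain string sort would get wrong.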

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.