keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

Unable to restore a layer of class TextVectorization - Text Classification #116

Open sachinprasadhs opened 2 years ago

sachinprasadhs commented 2 years ago

Moving user issue from: https://github.com/tensorflow/tensorflow/issues/45231

Describe the problem.

When I run the example from the official TensorFlow Basic text classification tutorial, everything runs fine until I save the model. But when I load the model, I get this error:

RuntimeError: Unable to restore a layer of class TextVectorization. Layers of class TextVectorization require that the class be provided to the model loading code, either by registering the class using @keras.utils.register_keras_serializable on the class def and including that file in your program, or by passing the class in a keras.utils.CustomObjectScope that wraps this load call.

Expected behavior: the model should load successfully and process raw input.

https://colab.research.google.com/gist/amahendrakar/8b65a688dc87ce9ca07ffb0ce50b84c7/44199.ipynb#scrollTo=fEjmSrKIqiiM

Example Link: https://tensorflow.google.cn/tutorials/keras/text_classification

sachinprasadhs commented 2 years ago

Attaching the gist reproducing the error here.

The reported error can be avoided by registering the custom standardization function with @keras.utils.register_keras_serializable(); here is the working gist.

However, this comment from the user https://github.com/tensorflow/tensorflow/issues/45231#issuecomment-1026512621 does not agree with the above approach.

LukeWood commented 2 years ago

Notes from triage: the error message could be improved here, since the issue is with the standardize argument, not the layer itself.

mattdangerw commented 2 years ago

When passing a custom callable to either the standardize or split argument of TextVectorization, it is expected that the function will need to be registered with register_keras_serializable or passed in the custom_objects argument during loading.

We should improve the error message here though, and make it clear this is an issue with serializing the argument to the layer and not the layer itself.
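To make the point above concrete, here is a simplified, hypothetical model of name-based serialization in plain Python. It is not Keras's implementation: `register_serializable`, `serialize_arg`, `deserialize_arg`, and `BUILTIN_OPTIONS` are names made up for this sketch. It assumes, as described above, that a custom callable is saved only by name, so loading needs either a registry entry or an explicit custom_objects mapping, and it shows an error message that names the failing argument rather than the layer.

```python
from typing import Callable, Dict, Optional

# Hypothetical, simplified registry illustrating name-based lookup.
# This is NOT Keras's implementation; all names here are made up for the sketch.
_REGISTRY: Dict[str, Callable] = {}
BUILTIN_OPTIONS = {"lower_and_strip_punctuation", "whitespace"}


def register_serializable(package: str = "Custom") -> Callable[[Callable], Callable]:
    """Decorator that records `fn` in the registry under '<package>><name>'."""
    def decorator(fn: Callable) -> Callable:
        key = f"{package}>{fn.__name__}"
        _REGISTRY[key] = fn
        fn._registered_name = key  # so serialization emits the same key
        return fn
    return decorator


def serialize_arg(value):
    """Built-in options and None stay as-is; callables are stored by name only."""
    if value is None or isinstance(value, str):
        return value
    return getattr(value, "_registered_name", value.__name__)


def deserialize_arg(arg_name: str, stored,
                    custom_objects: Optional[Dict[str, Callable]] = None):
    """Turn a stored name back into a callable, or fail naming the argument."""
    if stored is None or stored in BUILTIN_OPTIONS:
        return stored
    short_name = stored.split(">")[-1]
    if custom_objects:
        for key in (stored, short_name):
            if key in custom_objects:
                return custom_objects[key]
    if stored in _REGISTRY:
        return _REGISTRY[stored]
    raise ValueError(
        f"Cannot restore argument {arg_name!r}: callable {stored!r} is neither "
        "registered nor passed via custom_objects."
    )


def my_standardize(text: str) -> str:
    return text.lower()


# Saving side: the function is not registered, so only its name is stored.
config = {"standardize": serialize_arg(my_standardize),
          "split": serialize_arg("whitespace")}

# Loading side: without registration or custom_objects, the name cannot be resolved.
try:
    deserialize_arg("standardize", config["standardize"])
except ValueError as err:
    message = str(err)
    print(message)

# Passing the function explicitly resolves the stored name.
fn = deserialize_arg("standardize", config["standardize"],
                     custom_objects={"my_standardize": my_standardize})
print(fn("Hello"))  # hello


# A registered function resolves through the registry instead.
@register_serializable()
def registered_standardize(text: str) -> str:
    return text.strip().lower()


stored = serialize_arg(registered_standardize)  # 'Custom>registered_standardize'
print(deserialize_arg("standardize", stored) is registered_standardize)  # True
```

In this sketch the registry entry only exists if the decorated definition has run in the current process, which mirrors why a fresh notebook still needs the function's definition (or an import of the module that defines it) before loading.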

tmbluth commented 2 years ago

This approach helps those loading the model in the same notebook they trained it in, but it still does not address loading the model from a different notebook. If you open a new notebook that can access the saved model and run the last cell from the training notebook, it will error out.

To load the model in the new notebook, you must run:

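import re
import string

import tensorflow as tf

# Re-declare and register the same function that was passed as
# TextVectorization's standardize argument during training.
@tf.keras.utils.register_keras_serializable()
def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
  return tf.strings.regex_replace(stripped_html,
                                  '[%s]' % re.escape(string.punctuation),
                                  '')

# Load the exported model; summary() prints on its own and returns None.
loaded_model = tf.keras.models.load_model('./model/test/basic-text-class-export')
loaded_model.summary()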
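Because the loading code has to re-declare custom_standardization, it can help to pin down what the function is supposed to do with a few known input/output pairs. Below is a hypothetical plain-Python mirror of the function above (`custom_standardization_py` is a name made up for this sketch, and it does not use TensorFlow). It assumes ASCII input: Python's `str.lower()` also lowercases non-ASCII characters, which may not match `tf.strings.lower` with its default settings.

```python
import re
import string

# Hypothetical plain-Python mirror of custom_standardization, for checking
# expected outputs without TensorFlow. Assumes ASCII input.
_PUNCTUATION = re.compile('[%s]' % re.escape(string.punctuation))


def custom_standardization_py(text: str) -> str:
    lowercase = text.lower()                          # mirrors tf.strings.lower
    stripped_html = lowercase.replace('<br />', ' ')  # mirrors the '<br />' replacement
    return _PUNCTUATION.sub('', stripped_html)        # drops ASCII punctuation


print(custom_standardization_py("Great movie!<br />Loved it."))  # great movie loved it
```

Feeding the same inputs to the re-declared TensorFlow function before calling load_model is one way to catch a re-declaration that has drifted from the original.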

This seems like a poor solution for someone who is trying to reload the model from a different notebook, especially if they don't know how custom_standardization was constructed in the first place. In that case, they are stuck.

mihailyanchev commented 2 years ago

I can confirm that this is a problem not only when working with notebooks, but also with custom model deployments and possibly TensorFlow Serving.

Justin-monteilhet commented 1 year ago

+1