tensorflow / lattice

Lattice methods in TensorFlow
Apache License 2.0

Can't save and load lattice model #77

Open win13676 opened 9 months ago

win13676 commented 9 months ago

Hello, I'm having a problem loading the premade models from this tutorial (https://www.tensorflow.org/lattice/tutorials/premade_models).

When I save the model in the .tf format and then load it, I get:

KeyError: 'layers' 

When I save the model in the .keras format and then load it, I get:

ValueError: Unknown object: 'CalibratedLatticeEnsembleConfig'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

I see this issue in the Colab notebook linked from that page (https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/premade_models.ipynb), on Databricks, and when running locally.

Code I use to save and load the model (following https://www.tensorflow.org/lattice/api_docs/python/tfl/premade/CalibratedLattice#save):

import tensorflow as tf

# Native Keras format (path ends in .keras)
rtl_layer_ensemble_model.save("model.keras")
loaded_model = tf.keras.models.load_model("model.keras")

# SavedModel directory named model.tf
rtl_layer_ensemble_model.save("model.tf")
loaded_model = tf.keras.models.load_model("model.tf")

# SavedModel directory at the filesystem root
rtl_layer_ensemble_model.save("/")
loaded_model = tf.keras.models.load_model("/")

None of the models in the example can be loaded:

linear_model, lattice_model, explicit_ensemble_model, random_ensemble_model, rtl_layer_ensemble_model, prefitting_model, crystals_ensemble_model

Full stack traces:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-40-bc47832831ee> in <cell line: 2>()
      1 rtl_layer_ensemble_model.save("model.tf")
----> 2 loaded_model = tf.keras.models.load_model("model.tf")

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py in load_model(filepath, custom_objects, compile, safe_mode, **kwargs)
    236 
    237     # Legacy case.
--> 238     return legacy_sm_saving_lib.load_model(
    239         filepath, custom_objects=custom_objects, compile=compile, **kwargs
    240     )

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/keras/src/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1487 
   1488     # First, we create all layers and enqueue nodes to be processed
-> 1489     for layer_data in config["layers"]:
   1490         process_layer(layer_data)
   1491     # Then we process nodes in order of layer depth.

KeyError: 'layers'
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-39-c20ad91f21ca> in <cell line: 2>()
      1 rtl_layer_ensemble_model.save("model.keras")
----> 2 loaded_model = tf.keras.models.load_model("model.keras")

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py in load_model(filepath, custom_objects, compile, safe_mode, **kwargs)
    228                 f"with the native Keras format: {list(kwargs.keys())}"
    229             )
--> 230         return saving_lib.load_model(
    231             filepath,
    232             custom_objects=custom_objects,

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py in load_model(filepath, custom_objects, compile, safe_mode)
    273 
    274     except Exception as e:
--> 275         raise e
    276     else:
    277         return model

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py in load_model(filepath, custom_objects, compile, safe_mode)
    238             # Construct the model from the configuration file in the archive.
    239             with ObjectSharingScope():
--> 240                 model = deserialize_keras_object(
    241                     config_dict, custom_objects, safe_mode=safe_mode
    242                 )

/usr/local/lib/python3.10/dist-packages/keras/src/saving/serialization_lib.py in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    702     safe_mode_scope = SafeModeScope(safe_mode)
    703     with custom_obj_scope, safe_mode_scope:
--> 704         instance = cls.from_config(inner_config)
    705         build_config = config.get("build_config", None)
    706         if build_config:

/usr/local/lib/python3.10/dist-packages/tensorflow_lattice/python/premade.py in from_config(cls, config, custom_objects)
    146   @classmethod
    147   def from_config(cls, config, custom_objects=None):
--> 148     model_config = tf.keras.utils.legacy.deserialize_keras_object(
    149         config.get('model_config'), custom_objects=custom_objects
    150     )

/usr/local/lib/python3.10/dist-packages/keras/src/saving/legacy/serialization.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    478         # In this case we are dealing with a Keras config dictionary.
    479         config = identifier
--> 480         (cls, cls_config) = class_and_config_for_serialized_keras_object(
    481             config, module_objects, custom_objects, printable_module_name
    482         )

/usr/local/lib/python3.10/dist-packages/keras/src/saving/legacy/serialization.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    363     )
    364     if cls is None:
--> 365         raise ValueError(
    366             f"Unknown {printable_module_name}: '{class_name}'. "
    367             "Please ensure you are using a `keras.utils.custom_object_scope` "

ValueError: Unknown object: 'CalibratedLatticeEnsembleConfig'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

win13676 commented 9 months ago

When I changed the version to 2.0.11, loading the .keras file failed with the error below, while loading the .tf format worked. I think the issue affects all versions after 2.0.11.

!pip install tensorflow-lattice==2.0.11 pydot
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-23-c20ad91f21ca> in <cell line: 2>()
      1 rtl_layer_ensemble_model.save("model.keras")
----> 2 loaded_model = tf.keras.models.load_model("model.keras")

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py in load_model(filepath, custom_objects, compile, safe_mode, **kwargs)
    228                 f"with the native Keras format: {list(kwargs.keys())}"
    229             )
--> 230         return saving_lib.load_model(
    231             filepath,
    232             custom_objects=custom_objects,

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py in load_model(filepath, custom_objects, compile, safe_mode)
    273 
    274     except Exception as e:
--> 275         raise e
    276     else:
    277         return model

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py in load_model(filepath, custom_objects, compile, safe_mode)
    238             # Construct the model from the configuration file in the archive.
    239             with ObjectSharingScope():
--> 240                 model = deserialize_keras_object(
    241                     config_dict, custom_objects, safe_mode=safe_mode
    242                 )

/usr/local/lib/python3.10/dist-packages/keras/src/saving/serialization_lib.py in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    702     safe_mode_scope = SafeModeScope(safe_mode)
    703     with custom_obj_scope, safe_mode_scope:
--> 704         instance = cls.from_config(inner_config)
    705         build_config = config.get("build_config", None)
    706         if build_config:

/usr/local/lib/python3.10/dist-packages/tensorflow_lattice/python/premade.py in from_config(cls, config, custom_objects)
    145   @classmethod
    146   def from_config(cls, config, custom_objects=None):
--> 147     model = super(CalibratedLatticeEnsemble, cls).from_config(
    148         config, custom_objects=custom_objects)
    149     try:

/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py in from_config(cls, config, custom_objects)
   3242                 # constructor of the class.
   3243                 try:
-> 3244                     model = cls(**config)
   3245                 except TypeError as e:
   3246                     raise TypeError(

/usr/local/lib/python3.10/dist-packages/tensorflow_lattice/python/premade.py in __init__(self, model_config, dtype, **kwargs)
     98     # Check that proper config has been given.
     99     if not isinstance(model_config, configs.CalibratedLatticeEnsembleConfig):
--> 100       raise ValueError('Invalid config type: {}'.format(type(model_config)))
    101     # Verify that the config is fully specified.
    102     premade_lib.verify_config(model_config)

ValueError: Invalid config type: <class 'dict'>

siriuz42 commented 9 months ago

Try

import tensorflow as tf
import tensorflow_lattice as tfl

rtl_layer_ensemble_model.save("model_keras")
loaded_model = tf.keras.models.load_model(
    "model_keras",
    custom_objects=tfl.premade.get_custom_objects(),
)
  1. tf.keras.models.load_model reconstructs the Keras model, so you need to pass in the custom objects the model uses in order to recompile it. tfl.premade.get_custom_objects() returns all TensorFlow Lattice custom objects. If you only need the model for inference, i.e. a functional __call__, you can instead use:

    rtl_layer_ensemble_model.save("model_keras")
    loaded_model = tf.saved_model.load("model_keras")
  2. The period in model.keras is likely causing problems during variable name matching / parsing. Escaping it (as in model_keras above) solves the problem.
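To illustrate item 1, here is a toy model in plain Python (not Keras's actual implementation) of name-based deserialization. A saved config records only a class name such as 'CalibratedLatticeEnsembleConfig' plus its constructor arguments, so the loader can rebuild the object only if that name resolves to a class, either from its built-in registry or from a custom_objects mapping passed in by the caller. The classes and the names BUILTINS, serialize and deserialize below are hypothetical stand-ins.

```python
# Toy model of name-based (de)serialization; NOT Keras's real implementation.

class Dense:
    """Stand-in for a class the loader knows natively."""
    def __init__(self, units):
        self.units = units


class CalibratedLatticeEnsembleConfig:
    """Stand-in for a library class the loader does not know by default."""
    def __init__(self, num_lattices):
        self.num_lattices = num_lattices


# Hypothetical built-in registry: only the natively known classes.
BUILTINS = {"Dense": Dense}


def serialize(obj):
    # Only the class *name* and constructor arguments are written out.
    return {"class_name": type(obj).__name__, "config": dict(vars(obj))}


def deserialize(config, custom_objects=None):
    # Resolve the stored name against built-ins plus caller-supplied classes.
    lookup = {**BUILTINS, **(custom_objects or {})}
    cls = lookup.get(config["class_name"])
    if cls is None:
        raise ValueError(f"Unknown object: '{config['class_name']}'")
    return cls(**config["config"])


saved = serialize(CalibratedLatticeEnsembleConfig(num_lattices=5))
try:
    deserialize(saved)  # fails: the name is not in the built-in registry
except ValueError as err:
    print(err)
restored = deserialize(
    saved,
    custom_objects={"CalibratedLatticeEnsembleConfig": CalibratedLatticeEnsembleConfig},
)
print(restored.num_lattices)
```

In the suggested fix, tfl.premade.get_custom_objects() plays the role of the custom_objects mapping for the TensorFlow Lattice classes.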

win13676 commented 9 months ago

Thank you for the suggestion.

With custom_objects=tfl.premade.get_custom_objects(), load_model no longer throws an error for models saved with .save("model.tf") or .save("model"), but a model saved with .save("model.keras") gives this error:

ValueError: Input keypoints are invalid for feature age: {'class_name': '__numpy__', 'config': {'value': [29.0, 44.0, 54.0, 65.0, 100.0], 'dtype': 'float64'}}
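The keypoints in this error are not an array but the dict that the .keras format writes for a NumPy array ({'class_name': '__numpy__', 'config': {...}}). The first traceback shows TensorFlow Lattice's from_config going through tf.keras.utils.legacy.deserialize_keras_object, so one plausible reading (an inference, not confirmed in this thread) is that this placeholder reaches the keypoint check undecoded. The sketch below only shows what turning such a placeholder back into an array looks like; decode_numpy_placeholders is a hypothetical helper, not a Keras or TensorFlow Lattice API.

```python
import numpy as np


def decode_numpy_placeholders(obj):
    """Hypothetical helper: recursively replace {'class_name': '__numpy__', ...}
    dicts (the form seen in the error above) with real NumPy arrays."""
    if isinstance(obj, dict):
        if obj.get("class_name") == "__numpy__":
            cfg = obj["config"]
            return np.array(cfg["value"], dtype=cfg["dtype"])
        return {key: decode_numpy_placeholders(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [decode_numpy_placeholders(item) for item in obj]
    return obj


# The keypoints value copied from the error message for feature "age".
serialized = {
    "class_name": "__numpy__",
    "config": {"value": [29.0, 44.0, 54.0, 65.0, 100.0], "dtype": "float64"},
}
keypoints = decode_numpy_placeholders(serialized)
print(type(keypoints).__name__, keypoints.tolist())
```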

Also, in version 2.0.13, the model loaded with custom_objects raises the following error when I call .evaluate(x, y), while .predict(x) seems to work:

RuntimeError: You must compile your model before training/testing. Use `model.compile(optimizer, loss)`.

In version 2.0.11, the saved and reloaded model can call .evaluate(x, y).

Escaping the period doesn't seem to do anything. Is there any special syntax for escaping the period besides "model\.keras"?
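On the escaping question: in the tracebacks above, load_model("model.tf") ends in the "# Legacy case." branch of saving_api.py, while load_model("model.keras") goes to saving_lib.load_model. As far as we can tell for the Keras 2.13-era tf.keras shown here, that branch is picked from the path suffix, so "model_keras" changes behavior because it no longer ends in ".keras", not because anything is escaped; "model\.keras" still ends in ".keras". The sketch below is a simplified, hypothetical restatement of that rule (pick_format is not a Keras function) and ignores details such as an explicit save_format argument or HDF5 file objects.

```python
def pick_format(filepath: str) -> str:
    """Hypothetical, simplified restatement of how tf.keras (Keras 2.13 era)
    appears to choose a save/load format from the path alone."""
    path = str(filepath)
    if path.endswith(".keras"):
        return "keras"  # native Keras zip archive (saving_lib)
    if path.endswith((".h5", ".hdf5")):
        return "h5"     # legacy HDF5 file
    return "tf"         # legacy SavedModel directory (TF2 default)


for name in ["model.keras", r"model\.keras", "model_keras", "model.tf", "/"]:
    print(f"{name:<14} -> {pick_format(name)}")
```

Under this reading, "model_keras", "model.tf" and "/" all take the SavedModel path, which is consistent with the observation above that those saves load once custom_objects is supplied.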