keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Load a model saved using keras 2.3.1 in keras 3 (containing bidirectional LSTM) #19898

Open tiatariene opened 1 week ago

tiatariene commented 1 week ago

Hi,

I need to load an old model trained using keras 2.3 (I don't know the tensorflow version), which contains two bidirectional LSTM layers, but loading stops at the first layer.

Is there any hope I could still load this model using keras 3?

Thank you.

Here is the error message I get:

  return keras.models.load_model(path_model, compile=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/saving/saving_api.py", line 183, in load_model
    return legacy_h5_format.load_model_from_hdf5(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/legacy/saving/legacy_h5_format.py", line 133, in load_model_from_hdf5
    model = saving_utils.model_from_config(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/legacy/saving/saving_utils.py", line 85, in model_from_config
    return serialization.deserialize_keras_object(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/legacy/saving/serialization.py", line 495, in deserialize_keras_object
    deserialized_obj = cls.from_config(
                       ^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/models/model.py", line 517, in from_config
    return functional_from_config(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/models/functional.py", line 517, in functional_from_config
    process_layer(layer_data)
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/models/functional.py", line 497, in process_layer
    layer = saving_utils.model_from_config(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/legacy/saving/saving_utils.py", line 85, in model_from_config
    return serialization.deserialize_keras_object(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/legacy/saving/serialization.py", line 495, in deserialize_keras_object
    deserialized_obj = cls.from_config(
                       ^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/layers/rnn/bidirectional.py", line 314, in from_config
    config["layer"] = serialization_lib.deserialize_keras_object(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/saving/serialization_lib.py", line 694, in deserialize_keras_object
    cls = _retrieve_class_or_fn(
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env312/lib/python3.12/site-packages/keras/src/saving/serialization_lib.py", line 812, in _retrieve_class_or_fn
    raise TypeError(
TypeError: Could not locate class 'LSTM'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. Full object config: {'class_name': 'LSTM', 'config': {'name': 'lstm_1', 'trainable': True, 'dtype': 'float32', 'return_sequences': True, 'return_state': False, 'go_backwards': False, 'stateful': False, 'unroll': False, 'units': 64, 'activation': 'tanh', 'recurrent_activation': 'sigmoid', 'use_bias': True, 'kernel_initializer': {'class_name': 'VarianceScaling', 'config': {'scale': 1.0, 'mode': 'fan_avg', 'distribution': 'uniform', 'seed': None}}, 'recurrent_initializer': {'class_name': 'Orthogonal', 'config': {'gain': 1.0, 'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'unit_forget_bias': True, 'kernel_regularizer': None, 'recurrent_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'recurrent_constraint': None, 'bias_constraint': None, 'dropout': 0.0, 'recurrent_dropout': 0.3, 'implementation': 2}}
mehtamansi29 commented 1 week ago

Hi @tiatariene,

Could you share a code snippet that reproduces the issue?

tiatariene commented 1 week ago

To reproduce the issue, you will need to create two conda environments: an "old" environment and a "new" one.

Then create and save a model using the old environment:

# Run in the old environment (shell: conda activate old_env)
import tensorflow

inputs = tensorflow.keras.Input(shape=(25, 128))
x = tensorflow.keras.layers.Bidirectional(
    tensorflow.keras.layers.LSTM(64, return_sequences=True), name="lstm00"
)(inputs)
x = tensorflow.keras.layers.Bidirectional(
    tensorflow.keras.layers.LSTM(64, return_sequences=True), name="lstm01"
)(x)
model = tensorflow.keras.Model(inputs, x)
model.save("model.h5")

Then try to load it in the new environment, and the error should occur:

# Run in the new environment (shell: conda activate new_env)
import keras

keras.models.load_model('model.h5', compile=False)
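
Before attempting the load, it can help to look at what the legacy file actually records. The following is a minimal sketch, assuming the file is a legacy HDF5 model whose architecture is stored as JSON in the `model_config` attribute and that `h5py` is installed. The helper names `read_model_config` and `nested_layer_classes` are made up for illustration, and the `example` dict is a hand-written, trimmed stand-in for the real config, not output copied from a file.

```python
import json


def read_model_config(path):
    """Return the architecture stored in a legacy Keras HDF5 file as a dict."""
    import h5py  # imported lazily so the helper below also works without h5py

    with h5py.File(path, "r") as f:
        raw = f.attrs["model_config"]
    # Depending on the h5py version, the attribute is returned as bytes or str.
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    return json.loads(raw)


def nested_layer_classes(model_config):
    """List (layer name, wrapper class, wrapped class) for each wrapper layer.

    Assumes a config of the form {"config": {"layers": [...]}}.
    """
    found = []
    for layer in model_config["config"]["layers"]:
        inner = layer["config"].get("layer")
        if isinstance(inner, dict) and "class_name" in inner:
            found.append(
                (layer["config"]["name"], layer["class_name"], inner["class_name"])
            )
    return found


# Hypothetical, trimmed config mirroring the model built above.
example = {
    "class_name": "Functional",
    "config": {
        "layers": [
            {"class_name": "InputLayer", "config": {"name": "input_1"}},
            {
                "class_name": "Bidirectional",
                "config": {
                    "name": "lstm00",
                    "layer": {"class_name": "LSTM", "config": {"name": "lstm"}},
                },
            },
            {
                "class_name": "Bidirectional",
                "config": {
                    "name": "lstm01",
                    "layer": {"class_name": "LSTM", "config": {"name": "lstm_1"}},
                },
            },
        ]
    },
}

print(nested_layer_classes(example))
# [('lstm00', 'Bidirectional', 'LSTM'), ('lstm01', 'Bidirectional', 'LSTM')]
```

On a real file, `nested_layer_classes(read_model_config("model.h5"))` would list the wrapped layers whose class names the loader has to resolve.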
mehtamansi29 commented 4 days ago

Hi @tiatariene -

Thanks for the code snippet. You are getting this error because model.save() and keras.models.load_model() are no longer supported in keras2, so you need to upgrade keras from keras2 to keras3 in old_env.

In keras3, you can save the model with keras.saving.save_model(model, filepath, overwrite=True, **kwargs) and load it in new_env with keras.saving.load_model(filepath, custom_objects=None, compile=True, safe_mode=True).

Here you can find more details regarding model saving and loading in keras3.
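
For reference, the two keras3 calls mentioned above could be combined as follows. This is a minimal sketch, assuming keras 3 is installed and using the native `.keras` file format. The function name `roundtrip` and the file name are illustrative, and it only applies to models that can be rebuilt or re-saved from a keras 3 environment.

```python
def roundtrip(path="model.keras"):
    """Build the two-layer biLSTM model, save it with keras 3, and load it back."""
    import keras  # keras 3 assumed; imported lazily so the definition has no dependency

    inputs = keras.Input(shape=(25, 128))
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(64, return_sequences=True), name="lstm00"
    )(inputs)
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(64, return_sequences=True), name="lstm01"
    )(x)
    model = keras.Model(inputs, x)

    # Save in the native keras 3 format, then reload without compiling.
    keras.saving.save_model(model, path, overwrite=True)
    return keras.saving.load_model(path, compile=False)
```

Calling `roundtrip()` in a keras 3 environment writes `model.keras` to the current directory and returns the reloaded model.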

tiatariene commented 4 days ago

Hi,

Thank you for your answer.

It seems that you haven't understood my problem.

To be more precise, I have a set of models that were developed in keras 2.3 or 2.4 that I want to load in newer keras versions such as keras3. I cannot retrain them using newer keras versions.

When a model saved using keras 2.3 doesn't contain biLSTMs, keras.saving.load_model(filepath, custom_objects=None, compile=True, safe_mode=True) loads it without problems.

However, when I load a model that contains a biLSTM, it fails with the error below, which I think is related to the fact that LSTM weights are saved differently in newer keras versions.

  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/saving/saving_api.py", line 183, in load_model
    return legacy_h5_format.load_model_from_hdf5(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/legacy/saving/legacy_h5_format.py", line 133, in load_model_from_hdf5
    model = saving_utils.model_from_config(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/legacy/saving/saving_utils.py", line 85, in model_from_config
    return serialization.deserialize_keras_object(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/legacy/saving/serialization.py", line 495, in deserialize_keras_object
    deserialized_obj = cls.from_config(
                       ^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/models/model.py", line 517, in from_config
    return functional_from_config(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/models/functional.py", line 517, in functional_from_config
    process_layer(layer_data)
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/models/functional.py", line 497, in process_layer
    layer = saving_utils.model_from_config(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/legacy/saving/saving_utils.py", line 85, in model_from_config
    return serialization.deserialize_keras_object(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/legacy/saving/serialization.py", line 495, in deserialize_keras_object
    deserialized_obj = cls.from_config(
                       ^^^^^^^^^^^^^^^^
  File "/home/dxbz2376/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/layers/rnn/bidirectional.py", line 314, in from_config
    config["layer"] = serialization_lib.deserialize_keras_object(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 694, in deserialize_keras_object
    cls = _retrieve_class_or_fn(
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/main_env/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 812, in _retrieve_class_or_fn
    raise TypeError(
TypeError: Could not locate class 'LSTM'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. Full object config: {'class_name': 'LSTM', 'config': {'name': 'lstm_1', 'trainable': True, 'dtype': 'float32', 'return_sequences': True, 'return_state': False, 'go_backwards': False, 'stateful': False, 'unroll': False, 'units': 64, 'activation': 'tanh', 'recurrent_activation': 'sigmoid', 'use_bias': True, 'kernel_initializer': {'class_name': 'VarianceScaling', 'config': {'scale': 1.0, 'mode': 'fan_avg', 'distribution': 'uniform', 'seed': None}}, 'recurrent_initializer': {'class_name': 'Orthogonal', 'config': {'gain': 1.0, 'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'unit_forget_bias': True, 'kernel_regularizer': None, 'recurrent_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'recurrent_constraint': None, 'bias_constraint': None, 'dropout': 0.0, 'recurrent_dropout': 0.3, 'implementation': 2}}

The end of the message says in particular:

Could not locate class 'LSTM'. Make sure custom classes are decorated with @keras.saving.register_keras_serializable(). Full object config: {'class_name': 'LSTM', 'config': {'name': 'lstm_1', 'trainable': True, 'dtype': 'float32', 'return_sequences': True, 'return_state': False, 'go_backwards': False, 'stateful': False, 'unroll': False, 'units': 64, 'activation': 'tanh', 'recurrent_activation': 'sigmoid', 'use_bias': True, 'kernel_initializer': {'class_name': 'VarianceScaling', 'config': {'scale': 1.0, 'mode': 'fan_avg', 'distribution': 'uniform', 'seed': None}}, 'recurrent_initializer': {'class_name': 'Orthogonal', 'config': {'gain': 1.0, 'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'unit_forget_bias': True, 'kernel_regularizer': None, 'recurrent_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'recurrent_constraint': None, 'bias_constraint': None, 'dropout': 0.0, 'recurrent_dropout': 0.3, 'implementation': 2}}

This suggests that keras3 load_model doesn't recognize the LSTM layer saved using keras2 save_model.

As I said, I cannot retrain this model trained in keras 2.3 and containing biLSTMs. But I would like to load it using keras3.

I hope this makes my problem clearer. Thank you.
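
The error text above names the lookup that fails: the bare class name 'LSTM' nested inside the Bidirectional wrapper is not resolved. One thing that could be tried, as an untested sketch rather than a confirmed fix, is to hand that name to the loader explicitly through the `custom_objects` argument of `load_model`. The function name `load_legacy_bilstm` is hypothetical. Even if the class is then found, other parts of the old layer config (for example the `implementation` key or the initializer entries) may still be rejected.

```python
def load_legacy_bilstm(path):
    """Try to load a legacy .h5 model, mapping the bare name 'LSTM' to the keras 3 class."""
    import keras  # keras 3 assumed; imported lazily so the definition has no dependency

    return keras.models.load_model(
        path,
        compile=False,
        # Assumption: the nested lookup inside Bidirectional honours custom_objects
        # for built-in class names. This has not been verified here.
        custom_objects={"LSTM": keras.layers.LSTM},
    )
```

Usage would be `model = load_legacy_bilstm("model.h5")` in the new environment.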