keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Loading weights into custom LSTM layer fails: Layer 'lstm_cell' expected 3 variables, but received 0 variables during loading. Expected: ['kernel', 'recurrent_kernel', 'bias'] #20322

Open lbortolotti opened 3 days ago

lbortolotti commented 3 days ago

I'm using the official TF 2.17 container (tensorflow/tensorflow:2.17.0-gpu-jupyter) + keras==3.5.0.

The following code saves a model that contains a (dummy) custom LSTM layer, then initializes a new copy of the model (with a vanilla LSTM) and tries to load the weights from the first model into the second.

Code:

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import keras
from keras import layers

# An extremely uninteresting custom layer
@keras.saving.register_keras_serializable()
class MyCustomLSTM(keras.layers.LSTM):
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)

def make_model(
    use_custom_lstm=True,
):
    inputs = layers.Input(shape=(None, 4), name="inputs")

    if use_custom_lstm:
        lstm = MyCustomLSTM
    else:
        lstm = layers.LSTM

    outputs = lstm(
        units=8,
        return_sequences=True,
        name="my_LSTM",
    )(inputs)

    model = keras.models.Model(inputs=inputs, outputs=outputs)

    return model

weights_file = "this_is_a_test.weights.h5"
if os.path.exists(weights_file):
    os.remove(weights_file)

model = make_model(use_custom_lstm=True)
model.compile()
model.save_weights(weights_file)

new_model = make_model(use_custom_lstm=False)
new_model.load_weights(weights_file)

Output:

Traceback (most recent call last):
  File "scratch_1.py", line 45, in <module>
    new_model.load_weights(weights_file)
  File "venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "venv/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 593, in _raise_loading_failure
    raise ValueError(msg)
ValueError: A total of 1 objects could not be loaded. Example error message for object <LSTMCell name=lstm_cell, built=True>:

Layer 'lstm_cell' expected 3 variables, but received 0 variables during loading. Expected: ['kernel', 'recurrent_kernel', 'bias']

List of objects that could not be loaded:
[<LSTMCell name=lstm_cell, built=True>]
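One plausible reading of this error is stated here as an assumption, not a confirmed diagnosis. The Keras 3 `.weights.h5` format (written by saving_lib.py, which appears in the traceback) may file each layer of a model under a key derived from its class name rather than from the `name="my_LSTM"` argument. If so, the saved cell weights sit under a `my_custom_lstm` path, the vanilla model looks under an `lstm` path, and the cell finds nothing, which would match "expected 3 variables, but received 0". The stdlib-only toy below models that lookup. The key scheme, the path layout and the helper names (`class_key`, `save_by_path`, `load_by_path`) are hypothetical and are not Keras APIs.

```python
import re


def class_key(class_name: str) -> str:
    """Hypothetical: turn a class name into a snake_case key ('MyCustomLSTM' -> 'my_custom_lstm')."""
    name = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", class_name)
    name = re.sub(r"([a-z])([A-Z])", r"\1_\2", name)
    return name.lower()


def save_by_path(layers):
    """Flatten (class_name, cell_weights) pairs into a {path: shape} store (toy layout)."""
    store = {}
    for class_name, cell_weights in layers:
        for var_name, shape in cell_weights.items():
            store[f"layers/{class_key(class_name)}/cell/vars/{var_name}"] = shape
    return store


def load_by_path(store, class_name, expected_vars):
    """Return only the variables found under the loading layer's class-derived path."""
    prefix = f"layers/{class_key(class_name)}/cell/vars/"
    return {v: store[prefix + v] for v in expected_vars if prefix + v in store}


# Cell weight shapes for units=8 on 4 input features.
cell = {"kernel": (4, 32), "recurrent_kernel": (8, 32), "bias": (32,)}
store = save_by_path([("MyCustomLSTM", cell)])

for loader_class in ("LSTM", "MyCustomLSTM"):
    found = load_by_path(store, loader_class, list(cell))
    print(f"{loader_class}: expected {len(cell)} variables, received {len(found)}")
# LSTM: expected 3 variables, received 0
# MyCustomLSTM: expected 3 variables, received 3
```

Under this assumption, the layer's `name=` plays no part in the lookup, which would explain why identically named layers are not enough.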

Considering that the custom layer in this case does absolutely nothing of interest, I assume this is a bug. If not, please let me know how one is meant to wrap an LSTM layer to avoid this issue.

Thanks!

mehtamansi29 commented 21 hours ago

Hi @lbortolotti -

Thanks for reporting the issue. In your code, the weights are saved from a model built with use_custom_lstm=True (model = make_model(use_custom_lstm=True)), but they are loaded into a model built with use_custom_lstm=False. That mismatch is why you are seeing the error.

Calling model.summary() before saving the weights gives a better idea of which layers and parameters were used to build the model.

The following change resolves the error:

new_model = make_model(use_custom_lstm=True)
new_model.load_weights(weights_file)

I've attached a gist for your reference as well.

lbortolotti commented 21 hours ago

Hi. I'm perfectly aware that the model I'm loading the weights into has a "different" lstm layer. However, the models have 1) identical weight structure/dimensions and 2) identically named layers. In this situation, keras has always supported loading weights, even if the layer class has changed. This is something I've used all the time, and continues to work with tf-keras.

My example is particularly extreme as MyCustomLSTM is absolutely identical to a vanilla LSTM layer - normally I'd have some custom logic in there (but nothing that affects the weight structure).
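The "identical weight structure" claim can be checked with a little arithmetic. Assuming the standard LSTM parameterization with a bias (four gates stacked along the last axis, each with an input kernel, a recurrent kernel and a bias), a layer with `units=8` on the 4-feature input of make_model gets the shapes below, whichever class builds it. The function names `lstm_weight_shapes` and `param_count` are hypothetical helpers for illustration.

```python
import math


def lstm_weight_shapes(input_dim: int, units: int) -> dict:
    """Shapes of an LSTM cell's weights; the 4 gates are stacked along the last axis."""
    return {
        "kernel": (input_dim, 4 * units),
        "recurrent_kernel": (units, 4 * units),
        "bias": (4 * units,),
    }


def param_count(shapes: dict) -> int:
    """Total number of scalar parameters across all weight tensors."""
    return sum(math.prod(shape) for shape in shapes.values())


shapes = lstm_weight_shapes(input_dim=4, units=8)
print(shapes)               # {'kernel': (4, 32), 'recurrent_kernel': (8, 32), 'bias': (32,)}
print(param_count(shapes))  # 416, i.e. 4 * 8 * (4 + 8 + 1)
```

Since MyCustomLSTM adds no weights of its own, both variants should report these same three tensors and 416 parameters for my_LSTM; only the class name differs.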

I think the official docs confirm that my example should work, since the load_weights entry says "Weights are loaded based on the network's topology":

https://keras.io/api/models/model_saving_apis/weights_saving_and_loading/#loadweights-method
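To make the "topology" reading concrete, here is a stdlib-only toy of positional matching. The weighted layers of the two models are paired in order, and only their weight shapes are compared, so swapping MyCustomLSTM for LSTM is invisible. This is an illustration of the strategy the reporter expects, and an assumption about roughly how the older `.h5` path discussed later in the thread behaves, not a description of Keras internals. `load_by_topology` and the layer tuples are hypothetical.

```python
def load_by_topology(saved, target):
    """Pair layers by position and accept the transfer when every shape matches.

    `saved` and `target` are lists of (layer_name, class_name, [weight shapes]);
    class names are deliberately ignored.
    """
    if len(saved) != len(target):
        raise ValueError(f"layer count mismatch: {len(saved)} vs {len(target)}")
    loaded = []
    for (s_name, _, s_shapes), (t_name, _, t_shapes) in zip(saved, target):
        if s_shapes != t_shapes:
            raise ValueError(f"shape mismatch between {s_name!r} and {t_name!r}")
        loaded.append(t_name)
    return loaded


# kernel, recurrent_kernel, bias for units=8 on 4 input features
cell_shapes = [(4, 32), (8, 32), (32,)]
saved = [("my_LSTM", "MyCustomLSTM", cell_shapes)]
target = [("my_LSTM", "LSTM", cell_shapes)]
print(load_by_topology(saved, target))  # ['my_LSTM']
```

The trade-off is that a positional loader cannot tell two different layers with coincidentally equal shapes apart, which is one reason a format might key weights by something other than order.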

mehtamansi29 commented 18 hours ago

Hi @lbortolotti -

According to the load_weights documentation, weights are loaded based on the network's topology, but the architecture should be the same as when the weights were saved.

When the weights were saved, the model architecture was InputLayer + MyCustomLSTM (as shown by model.summary()), created by the use_custom_lstm=True argument. So when loading the weights, the model should be created with the same use_custom_lstm=True argument to get the same architecture.

I'm attaching a gist that runs the same code with Keras 2 (tf-keras); it gives the same error with new_model = make_model(use_custom_lstm=False).

lbortolotti commented 13 hours ago

Interesting. I've dug a bit deeper and found that, to restore the weight-loading behaviour, I have to change the file name from .weights.h5 to .h5. Literally just removing the ".weights" part of the suffix (which is now a strict requirement in TF 2.17, as far as I can tell) resolves it, and I can transfer weights as I always have.

To replicate, you just need to take your last gist, save the weights as weights_file = "this_is_a_test.h5", and you'll find that they then load correctly.

Which behaviour is the expected behaviour? I was definitely using the functionality...