nengo / keras-lmu

Keras implementation of Legendre Memory Units
https://www.nengo.ai/keras-lmu/

Error when loading model with tf.keras.models.load_model #37

Open bjkomer opened 3 years ago

bjkomer commented 3 years ago

Minimal example to reproduce the issue (tested with TensorFlow 2.4.1 and keras_lmu 0.3.1):

import tensorflow as tf
import keras_lmu
import numpy as np
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

inp = Input(shape=(None, 10))

out = keras_lmu.LMU(
    memory_d=1,
    order=4,
    theta=5,
    hidden_cell=tf.keras.layers.LSTMCell(units=50)
)(inp)

model = Model(inputs=inp, outputs=out)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss = tf.keras.losses.BinaryCrossentropy()

model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=loss,
)

history = model.fit(
    np.zeros((32, 16, 10)),
    np.zeros((32, 50)),
    steps_per_epoch=1,
    epochs=2,
)

model.save("my_model")

del model

model = tf.keras.models.load_model(
    "my_model", custom_objects={"LMU": keras_lmu.LMU},
    compile=False
)

Note that if model.fit is not called, the model loads without error.
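A likely reason, though the thread does not confirm it, is that TensorFlow 2.x creates metric objects lazily. The metrics given to `compile` are only turned into objects during the first training or evaluation step, so a model saved before `fit` has no metric objects for `load_model` to revive. The toy sketch below models that behaviour in plain Python. `ToyMetricsContainer` and `toy_save` are hypothetical names and do not mirror TensorFlow's internals.

```python
# Hypothetical, simplified model of lazy metric creation; this is NOT TensorFlow code.
class ToyMetricsContainer:
    """Stores the metric specs passed to compile() and only builds objects on first use."""

    def __init__(self, metric_specs):
        self._specs = list(metric_specs)  # what was passed as compile(metrics=...)
        self._built = False
        self._metric_objects = []         # stays empty until the first update

    def update_state(self, y_true, y_pred):
        # Metric objects are created the first time training/evaluation runs.
        if not self._built:
            self._metric_objects = [f"metric:{spec}" for spec in self._specs]
            self._built = True

    @property
    def metrics(self):
        return list(self._metric_objects)


def toy_save(container):
    # Only metric objects that already exist end up in the saved artifact.
    return {"metrics": container.metrics}


before_fit = ToyMetricsContainer(["binary_crossentropy"])
print(toy_save(before_fit))  # {'metrics': []}

after_fit = ToyMetricsContainer(["binary_crossentropy"])
after_fit.update_state(y_true=[0.0], y_pred=[0.0])  # stands in for one training step
print(toy_save(after_fit))  # {'metrics': ['metric:binary_crossentropy']}
```

In this toy, nothing metric-related is saved until a training step has run, which is consistent with the model loading fine when fit is skipped.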

The error that occurs is:

Traceback (most recent call last):
  File "loading_minimal_example.py", line 39, in <module>
    model = tf.keras.models.load_model(
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 212, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 138, in load
    keras_loader.load_layers(compile=compile)
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 380, in load_layers
    self.loaded_nodes[node_metadata.node_id] = self._load_layer(
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 417, in _load_layer
    obj, setter = self._revive_from_config(identifier, metadata, node_id)
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 431, in _revive_from_config
    obj = self._revive_metric_from_config(metadata)
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 540, in _revive_metric_from_config
    obj = metrics.deserialize(
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/metrics.py", line 3446, in deserialize
    return deserialize_keras_object(
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 346, in deserialize_keras_object
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 311, in class_and_config_for_serialized_keras_object
    deserialized_objects[key] = deserialize_keras_object(
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 360, in deserialize_keras_object
    return cls.from_config(cls_config)
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/metrics.py", line 642, in from_config
    return super(MeanMetricWrapper, cls).from_config(config)
  File "/home/bjkomer/anaconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 720, in from_config
    return cls(**config)
TypeError: __init__() got an unexpected keyword argument 'reduction'
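The last frames of the traceback show MeanMetricWrapper.from_config handing a saved config straight to `cls(**config)`. One plausible reading, not confirmed in this thread, is that the loss object passed as `metrics=loss` is saved with the loss's config (which includes `reduction`) and later revived through the metrics deserializer. There, a metric class with the same name (`BinaryCrossentropy`) has no `reduction` argument. The pure-Python sketch below reproduces only that mechanism. `LossBinaryCrossentropy`, `MetricBinaryCrossentropy`, `serialize` and `deserialize_as_metric` are hypothetical stand-ins, not Keras classes or functions.

```python
# Hypothetical, simplified illustration; names echo Keras but this is not Keras code.

class LossBinaryCrossentropy:
    """Stand-in for a loss class whose config includes `reduction`."""

    def __init__(self, from_logits=False, label_smoothing=0.0,
                 reduction="auto", name="binary_crossentropy"):
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.name = name

    def get_config(self):
        return {"from_logits": self.from_logits,
                "label_smoothing": self.label_smoothing,
                "reduction": self.reduction,
                "name": self.name}


class MetricBinaryCrossentropy:
    """Stand-in for a metric class with the same serialized name but no `reduction`."""

    def __init__(self, name="binary_crossentropy", dtype=None,
                 from_logits=False, label_smoothing=0.0):
        self.name = name
        self.dtype = dtype
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing

    @classmethod
    def from_config(cls, config):
        # Same pattern as the last frame of the traceback: cls(**config).
        return cls(**config)


# Hypothetical registry: the metrics deserializer resolves the name to the metric class.
METRICS = {"BinaryCrossentropy": MetricBinaryCrossentropy}


def serialize(obj, class_name):
    return {"class_name": class_name, "config": obj.get_config()}


def deserialize_as_metric(serialized):
    cls = METRICS[serialized["class_name"]]
    return cls.from_config(serialized["config"])


# The loss's config is saved, then revived as if it were the metric's config.
saved = serialize(LossBinaryCrossentropy(), "BinaryCrossentropy")
error_message = None
try:
    deserialize_as_metric(saved)
except TypeError as err:
    error_message = str(err)
print(error_message)
```

The exact message text depends on the Python version, but it ends with "unexpected keyword argument 'reduction'", matching the traceback. Under this reading, omitting metrics avoids the error because no such object has to be revived.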
arvoelke commented 3 years ago

I get the same error if I replace the LMU layer with out = tf.keras.layers.LSTM(50)(inp), so this might be a TensorFlow/Keras issue instead?

bjkomer commented 3 years ago

Yeah, it looks like it is actually an issue on their end, so I opened one there (https://github.com/tensorflow/tensorflow/issues/48235). I also noticed that if you don't pass any metrics to compile, it works fine.