keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

TypeError: Could not locate class 'adam'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()` #19661

Closed: mpetteno closed this issue 2 weeks ago

mpetteno commented 2 weeks ago

Hi everyone,

I think there is a problem with loading a model that was compiled with the optimizer provided as a dict. This does not happen with `optimizer="adam"` or `optimizer=keras.optimizers.Adam()`.

Here is the code that reproduces the issue:

import keras

model = keras.models.Sequential()
model.add(keras.layers.Dense(64, input_dim=3, activation='relu'))
model.add(keras.layers.Dense(32, activation='relu'))
model.add(keras.layers.Dense(1, activation='linear'))

model.compile(
    optimizer={
      "class_name": "Adam",
      "config": {
        "learning_rate": 0.01,
        "beta_1": 0.9,
        "beta_2": 0.999,
        "epsilon": 1e-7
      }
    },
    loss='mse'
)

keras.saving.save_model(model, 'model.keras')
loaded_model = keras.saving.load_model("model.keras")

The full traceback is:

Traceback (most recent call last):
  File "/venv/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 152, in load_model
    return _load_model_from_fileobj(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 170, in _load_model_from_fileobj
    model = deserialize_keras_object(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 734, in deserialize_keras_object
    instance.compile_from_config(compile_config)
  File "/venv/lib/python3.11/site-packages/keras/src/trainers/trainer.py", line 870, in compile_from_config
    config = serialization_lib.deserialize_keras_object(config)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 594, in deserialize_keras_object
    return {
           ^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 595, in <dictcomp>
    key: deserialize_keras_object(
         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 694, in deserialize_keras_object
    cls = _retrieve_class_or_fn(
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 812, in _retrieve_class_or_fn
    raise TypeError(
TypeError: Could not locate class 'adam'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. Full object config: {'class_name': 'adam', 'config': {'learning_rate': 0.01, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07}}

I think this happens because, in this case, the `class_name` field is serialized as "adam" (lowercase) rather than "Adam" (capitalized), so `obj` is not resolved in serialization_lib.py at line 803.
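To make the hypothesis above concrete, here is a minimal pure-Python sketch of why a lowercase `class_name` can miss a registry keyed by `"Adam"`. The `REGISTRY`, `resolve_class`, and `deserialize` names and the stand-in `Adam` class are all hypothetical; this is a simplified illustration, not Keras's actual `_retrieve_class_or_fn` logic.

```python
class Adam:
    """Hypothetical stand-in for a built-in optimizer class."""

    def __init__(self, **config):
        self.config = config


# Hypothetical registry keyed by the exact class name "Adam".
REGISTRY = {"Adam": Adam}


def resolve_class(class_name, case_insensitive=False):
    """Return the registered class for class_name, or raise TypeError."""
    if class_name in REGISTRY:  # exact, case-sensitive match
        return REGISTRY[class_name]
    if case_insensitive:
        # Fallback: compare lowercased names so "adam" matches "Adam".
        lowered = {name.lower(): cls for name, cls in REGISTRY.items()}
        if class_name.lower() in lowered:
            return lowered[class_name.lower()]
    raise TypeError(f"Could not locate class {class_name!r}.")


def deserialize(config, case_insensitive=False):
    """Rebuild an object from a {"class_name": ..., "config": ...} dict."""
    cls = resolve_class(config["class_name"], case_insensitive)
    return cls(**config["config"])


saved = {"class_name": "adam", "config": {"learning_rate": 0.01}}
try:
    deserialize(saved)
except TypeError as err:
    print(err)  # Could not locate class 'adam'.
print(deserialize(saved, case_insensitive=True).config)  # {'learning_rate': 0.01}
```

With an exact-match lookup, the lowercase name written at save time fails on load; a case-insensitive fallback (or saving the canonical class name in the first place) avoids the mismatch.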

Thanks for your help.

fchollet commented 2 weeks ago

Thanks for the report; this is fixed at HEAD. Note that passing optimizers as dicts isn't an officially supported API (the officially supported options are to pass the optimizer as a string or as an `Optimizer` instance).