keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Issue loading model containing Dense layer with Identity initializer #20483

Closed: kristoferm94 closed this issue 1 day ago.

kristoferm94 commented 1 day ago

Hi, I've been migrating some model code from Keras 2 to Keras 3, and I think I stumbled upon a bug.

I've noticed that if I save a model containing a Dense layer with an identity kernel initializer in Keras 3, I cannot load it again. Loading fails with an exception saying that Keras cannot interpret the initializer identifier stored in config.json inside the .keras model file. Here is the relevant part of the exception (the full traceback is in the attached testoutput.txt):

Exception encountered: Could not interpret initializer identifier: {'module': 'keras.initializers', 'class_name': 'IdentityInitializer', 'config': {}, 'registered_name': None}
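To see which layers carry such an identifier before attempting a reload, one can read config.json directly from the .keras file (a zip archive) and list the class name of every serialized initializer. This is a minimal sketch: the helper names find_initializer_class_names and list_initializers are hypothetical, and it assumes initializer configs sit under keys ending in _initializer (such as Dense's kernel_initializer and bias_initializer).

```python
import json
from zipfile import ZipFile


def find_initializer_class_names(node, path="$"):
    """Yield (path, class_name) for each dict stored under a key ending in '_initializer'.

    Hypothetical helper: walks an already-parsed config.json structure.
    """
    if isinstance(node, dict):
        for key, value in node.items():
            child_path = f"{path}.{key}"
            if key.endswith("_initializer") and isinstance(value, dict):
                yield child_path, value.get("class_name")
            yield from find_initializer_class_names(value, child_path)
    elif isinstance(node, list):
        for index, item in enumerate(node):
            yield from find_initializer_class_names(item, f"{path}[{index}]")


def list_initializers(keras_path):
    """Read config.json out of a .keras zip archive and list its initializers."""
    with ZipFile(keras_path, "r") as archive:
        config = json.loads(archive.read("config.json"))
    return list(find_initializer_class_names(config))
```

For a model like the one in the tests below, the Dense layer's kernel_initializer entry is the one that should report IdentityInitializer.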

I have reproduced this on Windows with Keras 3.6 + JAX 0.4.35 and on Ubuntu with Keras 3.6 + TensorFlow 2.18.0; both setups raise the same exception.

For now, I am using a hacky workaround: rewriting config.json inside the .keras model file so that the faulty class name 'IdentityInitializer' becomes 'Identity'.

Below are pytest tests that replicate the issue. test_save_read_dense_layer_model_with_identity_initializer raises the exception; the attached testoutput.txt file contains the test output, including the full traceback:

from pathlib import Path
from tempfile import TemporaryDirectory
from zipfile import ZipFile

import keras
import numpy as np


def load_model_with_workaround(model_path: Path) -> keras.Model:
    """Load a .keras file after renaming 'IdentityInitializer' to 'Identity'."""
    # Parenthesized context managers require Python 3.10+.
    with (
        TemporaryDirectory() as tmp_dir,
        ZipFile(model_path, "r") as original_model_file,
    ):
        new_model_path = Path(tmp_dir) / "new.keras"
        # Copy every entry of the .keras zip archive, patching only config.json.
        with ZipFile(new_model_path, "w") as new_model_file:
            for file_name in original_model_file.namelist():
                original_data = original_model_file.read(file_name)

                if file_name == "config.json":
                    original_data = (
                        original_data.decode("utf-8")
                        .replace(
                            'class_name": "IdentityInitializer"',
                            'class_name": "Identity"',
                        )
                        .encode("utf-8")
                    )

                with new_model_file.open(file_name, "w") as f:
                    f.write(original_data)

        # Load while the temporary directory (and the patched file) still exists.
        return keras.models.load_model(new_model_path)


# This test will fail
def test_save_read_dense_layer_model_with_identity_initializer() -> None:
    model = keras.Sequential(
        [
            keras.layers.Input((5,)),
            keras.layers.Dense(5, kernel_initializer=keras.initializers.Identity()),
        ]
    )
    with TemporaryDirectory() as tmp_dir:
        save_path = Path(tmp_dir) / "mymodel.keras"
        model.save(save_path)
        # On Keras 3.6 this raises "Could not interpret initializer identifier".
        model_from_file = keras.models.load_model(save_path)
        for expected, actual in zip(
            model.get_weights(), model_from_file.get_weights()
        ):
            np.testing.assert_allclose(actual, expected)


# This test will pass
def test_save_read_dense_layer_model_with_identity_initializer_using_workaround() -> None:
    model = keras.Sequential(
        [
            keras.layers.Input((5,)),
            keras.layers.Dense(5, kernel_initializer=keras.initializers.Identity()),
        ]
    )
    with TemporaryDirectory() as tmp_dir:
        save_path = Path(tmp_dir) / "mymodel.keras"
        model.save(save_path)
        model_from_file = load_model_with_workaround(save_path)
        # The reloaded weights should match the saved ones.
        for expected, actual in zip(
            model.get_weights(), model_from_file.get_weights()
        ):
            np.testing.assert_allclose(actual, expected)
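The string replacement in load_model_with_workaround only matches if config.json was written with exactly one space after each colon. A somewhat sturdier variant of the same workaround, sketched below, parses the JSON and renames the class name wherever it appears, then writes a patched copy of the archive. Both helper names (rename_class_names, patch_keras_archive) are hypothetical and not part of Keras; the sketch only assumes that config.json is plain JSON inside the zip, as the workaround above already does.

```python
import json
from pathlib import Path
from zipfile import ZipFile


def rename_class_names(node, old, new):
    """Return a copy of parsed JSON where every class_name equal to `old` becomes `new`."""
    if isinstance(node, dict):
        renamed = {key: rename_class_names(value, old, new) for key, value in node.items()}
        if renamed.get("class_name") == old:
            renamed["class_name"] = new
        return renamed
    if isinstance(node, list):
        return [rename_class_names(item, old, new) for item in node]
    return node  # strings, numbers, booleans and None are returned unchanged


def patch_keras_archive(src, dst, old="IdentityInitializer", new="Identity"):
    """Copy the .keras zip archive `src` to `dst`, rewriting only config.json."""
    with ZipFile(src, "r") as original, ZipFile(dst, "w") as patched:
        for name in original.namelist():
            data = original.read(name)
            if name == "config.json":
                config = rename_class_names(json.loads(data), old, new)
                data = json.dumps(config).encode("utf-8")
            patched.writestr(name, data)
    return Path(dst)
```

In the tests above this would be used as keras.models.load_model(patch_keras_archive(save_path, Path(tmp_dir) / "patched.keras")); keep the .keras suffix on the destination path, as the original workaround does with new.keras.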

james77777778 commented 1 day ago

I have submitted a PR for this issue here: #20484

In the meantime, you can also work around it by patching the internal keras.src.initializers.ALL_OBJECTS_DICT mapping in your own code:

import pathlib
import tempfile

import keras
from keras.src import initializers

model = keras.Sequential(
    [
        keras.layers.Input((5,)),
        keras.layers.Dense(
            5, kernel_initializer=keras.initializers.IdentityInitializer()
        ),
    ]
)
# Alias the saved class name to the Identity initializer; must run before load_model().
initializers.ALL_OBJECTS_DICT["IdentityInitializer"] = initializers.Identity

with tempfile.TemporaryDirectory() as tmp_dir:
    save_path = pathlib.Path(tmp_dir) / "mymodel.keras"
    model.save(save_path)
    model_from_file = keras.models.load_model(save_path)
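Why a one-line patch is enough: the patch implies that, on load, the class_name string saved in config.json is resolved through a name-to-class mapping, so adding "IdentityInitializer" as an alias lets the saved name resolve. The snippet below is a deliberately simplified toy model of that idea with hypothetical names (REGISTRY, deserialize, and a stand-in Identity class); it is not Keras's actual deserialization code.

```python
class Identity:
    """Stand-in initializer class (hypothetical, not keras.initializers.Identity)."""

    def __init__(self, gain=1.0):
        self.gain = gain


# Toy registry: the only names the loader can resolve.
REGISTRY = {"Identity": Identity}


def deserialize(config):
    """Look up config["class_name"] in REGISTRY and instantiate it."""
    cls = REGISTRY.get(config["class_name"])
    if cls is None:
        raise ValueError(f"Could not interpret initializer identifier: {config}")
    return cls(**config["config"])


saved = {"class_name": "IdentityInitializer", "config": {}}

try:
    deserialize(saved)
    error_before_alias = None
except ValueError as err:
    error_before_alias = str(err)  # the saved name is unknown to the registry

REGISTRY["IdentityInitializer"] = Identity  # alias, analogous to the ALL_OBJECTS_DICT patch
restored = deserialize(saved)

print(error_before_alias)
print(type(restored).__name__)
```

The design point is that the alias changes only the lookup table, not the saved file, which is why the patch has to be in place before loading.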