keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Saving and restoring the model is changing type of model inputs and outputs #19999

Closed doiko closed 3 days ago

doiko commented 3 months ago

Hi there, I am using keras==3.4.1 and tensorflow==2.17.0. After a saved model is restored, the input_tensors parameter of clone_model has to be passed as a list. This was not the case in version 2.15. Now the same function needs a different parameter type depending on whether the model is fresh or restored.

Saving and restoring the model changes the type of the model's inputs and outputs.

Minimal example:

!pip install --upgrade keras==3.4.1

import numpy as np
import keras

from keras.models import Model, clone_model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate

def create_minimal_unet(input_shape):
    inputs = Input(input_shape, sparse=False)

    # Encoder
    c1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    p1 = MaxPooling2D((2, 2))(c1)

    # Bottleneck
    bn = Conv2D(64, (3, 3), activation='relu', padding='same')(p1)

    # Decoder
    up1 = UpSampling2D((2, 2))(bn)
    concat1 = Concatenate()([up1, c1])
    c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(concat1)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(c2)  # For binary segmentation

    return Model(inputs=inputs, outputs=outputs)

# Create the U-Net model
model = create_minimal_unet((128, 128, 1))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Generate a random input tensor of size 128x128
random_input = np.random.random((1, 128, 128, 1)).astype(np.float32)

predictions = model.predict(random_input)
assert predictions.shape == random_input.shape == (1, 128, 128, 1)

# Generate a random input tensor of size 256x128
random_input_fcn = np.random.random((1, 256, 128, 1)).astype(np.float32)

# Clone the model to accept the new size. Here input_tensors is not in a list.
cloned = clone_model(model, input_tensors=Input((256, 128, 1)))
predictions_fcn = cloned.predict(random_input_fcn)
assert predictions_fcn.shape == random_input_fcn.shape == (1, 256, 128, 1)

# Saving and restoring changes this behavior
model.save('/tmp/saved_model.keras')
saved_model = keras.models.load_model('/tmp/saved_model.keras')
print('\nInput changed from tensor to list')
print(f'model.input: {model.input}')
print(f'saved_model.input: {saved_model.input}')
print('\nOutput changed from tensor to list')
print(f'model.output: {model.output}')
print(f'saved_model.output: {saved_model.output}')
print("\nThe saved model can be cloned only if the input tensor is in a List")
clone_saved = clone_model(saved_model, input_tensors=[Input((256, 128, 1), sparse=False)])
predictions_fcn = clone_saved.predict(random_input_fcn)
assert predictions_fcn.shape == random_input_fcn.shape == (1, 256, 128, 1)

print("\nWhat was working in the fresh model (also in 2.15) above fails in the restored model")
clone_saved = clone_model(saved_model, input_tensors=Input((256, 128, 1), sparse=False))

output is:

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 191ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 104ms/step
/xxx/python3.11/site-packages/keras/src/saving/saving_lib.py:576: UserWarning: Skipping variable loading for optimizer 'adam', because it has 18 variables whereas the saved optimizer has 2 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
Input changed from tensor to list
model.input: <KerasTensor shape=(None, 128, 128, 1), dtype=float32, sparse=False, name=keras_tensor_271>
saved_model.input: [<KerasTensor shape=(None, 128, 128, 1), dtype=float32, sparse=False, name=input_layer_26>]
Output changed from tensor to list
model.output: <KerasTensor shape=(None, 128, 128, 1), dtype=float32, sparse=False, name=keras_tensor_278>
saved_model.output: [<KerasTensor shape=(None, 128, 128, 1), dtype=float32, sparse=False, name=keras_tensor_301>]
The saved model can be cloned only if the input tensor is in a List
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 125ms/step
What was working in the fresh model (also in 2.15) above fails in the restored model
Traceback (most recent call last):
  File "xxx/python3.11/site-packages/keras/src/models/cloning.py", line 374, in _clone_functional_model
    tree.assert_same_structure(input_tensors, model.input)
  File "xxx/python3.11/site-packages/keras/src/tree/tree_api.py", line 205, in assert_same_structure
    return tree_impl.assert_same_structure(a, b, check_types=check_types)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "xxx/python3.11/site-packages/keras/src/tree/optree_impl.py", line 96, in assert_same_structure
    raise ValueError(
ValueError: `a` and `b` don't have the same structure. Received: structure of a=PyTreeSpec(*, NoneIsLeaf), structure of b=PyTreeSpec([*], NoneIsLeaf)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/snap/pycharm-professional/401/plugins/python/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
           ^^^^^^
  File "<input>", line 61, in <module>
  File "xxx/lib/python3.11/site-packages/keras/src/models/cloning.py", line 182, in clone_model
    return _clone_functional_model(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "xxxx/python3.11/site-packages/keras/src/models/cloning.py", line 376, in _clone_functional_model
    raise ValueError(
ValueError: `input_tensors` must have the same structure as model.input
Reference structure: [<KerasTensor shape=(None, 128, 128, 1), dtype=float32, sparse=False, name=input_layer_26>]
Received structure: <KerasTensor shape=(None, 256, 128, 1), dtype=float32, sparse=False, name=keras_tensor_310>
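
The two `PyTreeSpec` values in the first `ValueError` differ only in nesting: `*` is a bare leaf, while `[*]` is a one-element list. A rough way to picture that comparison is sketched below. `structure_of` is a hypothetical helper written for this illustration; it is not how optree builds a `PyTreeSpec`, it ignores dicts and other containers, and plain strings stand in for Keras tensors.

```python
def structure_of(value):
    """Return a simplified nesting signature; every non-sequence is a leaf '*'.

    Hypothetical illustration only, not the optree implementation.
    """
    if isinstance(value, list):
        return "[" + ", ".join(structure_of(v) for v in value) + "]"
    if isinstance(value, tuple):
        return "(" + ", ".join(structure_of(v) for v in value) + ")"
    return "*"


bare = "keras_tensor"        # stand-in for Input((256, 128, 1))
wrapped = ["keras_tensor"]   # stand-in for saved_model.input after load_model

# The signatures differ, which is the kind of mismatch the error reports.
print(structure_of(bare), structure_of(wrapped))
print(structure_of(bare) == structure_of(wrapped))
```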

ghsanti commented 3 months ago

Hi there,

saved_model = keras.models.load_model('/tmp/saved_model.keras', compile=True)

That seems to solve both issues.

Tagging @sachinprasadhs in case someone can double check.

doiko commented 3 months ago

Hi @ghsanti, @sachinprasadhs. The issue is that both the input and output types are changed. I see no difference when I add compile=True. Can you please confirm your suggestion?

ghsanti commented 3 months ago

@doiko Interesting. I just noticed that the install command I used was !pip install --upgrade keras-nightly

So it may have been fixed.

doiko commented 3 months ago

Indeed, @ghsanti, @sachinprasadhs, running pip install --upgrade keras-nightly solved the issue. Any idea when this will be pushed to a release? Models saved with the current release will be tricky to handle in the future.
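
For code that has to work with both fresh and restored models in the meantime, one possible stopgap is to wrap the new input so that it mirrors whatever structure `model.input` reports. The sketch below is only an illustration under stated assumptions: `match_input_structure` is a hypothetical helper, not a Keras API, it handles only the single-input case from this issue, and plain strings stand in for Keras tensors so that it runs without Keras installed.

```python
def match_input_structure(reference, new_input):
    """Wrap or unwrap `new_input` so its nesting mirrors `reference`.

    Hypothetical helper (not part of Keras). `reference` is meant to be
    `model.input`; only single-input models are handled.
    """
    ref_is_seq = isinstance(reference, (list, tuple))
    new_is_seq = isinstance(new_input, (list, tuple))

    if ref_is_seq and not new_is_seq:
        if len(reference) != 1:
            raise ValueError("This sketch only supports single-input models.")
        # Restored model on 3.4.1: model.input is a one-element list.
        return type(reference)([new_input])

    if not ref_is_seq and new_is_seq:
        if len(new_input) != 1:
            raise ValueError("This sketch only supports single-input models.")
        # Fresh model: model.input is a bare tensor.
        return new_input[0]

    # Structures already agree.
    return new_input


# Stand-ins for model.input and saved_model.input.
fresh_reference = "keras_tensor"
restored_reference = ["keras_tensor"]

print(match_input_structure(fresh_reference, "new_input"))
print(match_input_structure(restored_reference, "new_input"))

# Intended use with Keras (not executed here):
# cloned = clone_model(
#     saved_model,
#     input_tensors=match_input_structure(saved_model.input, Input((256, 128, 1))),
# )
```

The helper leaves the input untouched when the structures already agree, so the same call should work whether or not the installed Keras version wraps `model.input` in a list.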

ghsanti commented 3 months ago

They may want to do a minor release idk. I'd guess that it's related to this commit.

Maybe @mehtamansi29 can help.

mehtamansi29 commented 1 month ago

Hi @doiko -

The error occurs while cloning the model. When cloning, there is no need to pass input_tensors as a list. Removing the list from the clone_model call, so that the line reads clone_saved = clone_model(saved_model, input_tensors=Input((256, 128, 1), sparse=False)), will resolve the error in Keras 3.6.0.

Attached gist for your reference.

github-actions[bot] commented 2 weeks ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] commented 3 days ago

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
