keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.

ModelCheckpoint used with `save_best_only` doesn't handle interruptions, even with BackupAndRestore #430

Open nicolasnn opened 2 years ago

nicolasnn commented 2 years ago

Describe the problem. I am using the ModelCheckpoint callback to save the best model, combined with the BackupAndRestore callback to handle interruptions. The problem appears when re-running the training script after an interruption: the model restored by BackupAndRestore does not carry over the previous values of the losses and metrics. As a result, ModelCheckpoint saves the model on the first epoch of the new run regardless of the loss value, and it even overwrites the "best" model with a worse one.

Describe the current behavior. After a restart, BackupAndRestore restores the model weights and the epoch counter, but ModelCheckpoint's best monitored value is reset to inf, so the first epoch of the resumed run is always saved as the "best" model, even when it is worse than the previously saved one.

Describe the expected behavior. The best monitored value should survive the interruption (restored by BackupAndRestore or persisted by ModelCheckpoint), so the saved best model is only overwritten by a genuinely better one.

Standalone code to reproduce the issue.

First training

import tensorflow as tf
from tensorflow import keras
import numpy as np

# Dummy datasets
np.random.seed(12)
train_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(500, 10, 4), np.random.randint(0, 5, (500, 1))))
valid_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(100, 10, 4), np.random.randint(0, 5, (100, 1))))

model = keras.Sequential(
    [
        keras.layers.Dense(40, activation="relu"),
        keras.layers.Dense(100, activation="relu"),
        keras.layers.Dense(400, activation="relu"),
        keras.layers.Dense(10, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(1),
    ]
)

model.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.005))

# Callbacks
backup_cb = keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup_dir')
ckpt_cb = keras.callbacks.ModelCheckpoint('/tmp/best_model', save_best_only=True, monitor='val_loss', verbose=1)

# Callback that fakes an interruption
class InterruptingCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        if epoch == 4:
            raise RuntimeError('Interrupting!')

model.fit(train_ds, epochs=6, validation_data=valid_ds, verbose=1, callbacks=[backup_cb, ckpt_cb, InterruptingCallback()])

The output is:

Epoch 1/6
476/500 [===========================>..] - ETA: 0s - loss: 2.4351  
Epoch 1: val_loss improved from inf to 2.07992, saving model to /tmp/best_model
500/500 [==============================] - 2s 3ms/step - loss: 2.4528 - val_loss: 2.0799
Epoch 2/6
475/500 [===========================>..] - ETA: 0s - loss: 2.2506
Epoch 2: val_loss did not improve from 2.07992
500/500 [==============================] - 1s 1ms/step - loss: 2.2690 - val_loss: 2.0819
Epoch 3/6
476/500 [===========================>..] - ETA: 0s - loss: 2.2212
Epoch 3: val_loss did not improve from 2.07992
500/500 [==============================] - 1s 1ms/step - loss: 2.2440 - val_loss: 2.0859
Epoch 4/6
486/500 [============================>.] - ETA: 0s - loss: 2.2116
Epoch 4: val_loss did not improve from 2.07992
500/500 [==============================] - 1s 1ms/step - loss: 2.2372 - val_loss: 2.0894
Traceback (most recent call last):
  File "1st_training.py", line 36, in <module>
    model.fit(train_ds, epochs=6, validation_data=valid_ds, verbose=1, callbacks=[backup_cb, ckpt_cb, InterruptingCallback()])
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "1st_training.py", line 33, in on_epoch_begin
    raise RuntimeError('Interrupting!')
RuntimeError: Interrupting!

Second training

import tensorflow as tf
from tensorflow import keras
import numpy as np

# Dummy datasets
np.random.seed(12)
train_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(500, 10, 4), np.random.randint(0, 5, (500, 1))))
valid_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(100, 10, 4), np.random.randint(0, 5, (100, 1))))

model = keras.Sequential(
    [
        keras.layers.Dense(40, activation="relu"),
        keras.layers.Dense(100, activation="relu"),
        keras.layers.Dense(400, activation="relu"),
        keras.layers.Dense(10, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(1),
    ]
)

model.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.005))

# Callbacks
backup_cb = keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup_dir')
ckpt_cb = keras.callbacks.ModelCheckpoint('/tmp/best_model', save_best_only=True, monitor='val_loss', verbose=1)

model.fit(train_ds, epochs=6, validation_data=valid_ds, verbose=1, callbacks=[backup_cb, ckpt_cb])

The output is:

Epoch 5/6
476/500 [===========================>..] - ETA: 0s - loss: 2.2097  
Epoch 5: val_loss improved from inf to 2.09097, saving model to /tmp/best_model
500/500 [==============================] - 2s 3ms/step - loss: 2.2322 - val_loss: 2.0910
Epoch 6/6
500/500 [==============================] - ETA: 0s - loss: 2.2200
Epoch 6: val_loss did not improve from 2.09097
500/500 [==============================] - 1s 1ms/step - loss: 2.2200 - val_loss: 2.0978

The problem shows up at `val_loss improved from inf to 2.09097`: the model restored by BackupAndRestore does not carry over the previous value of val_loss, so ModelCheckpoint's internal best value is re-initialized to inf. As a result, ModelCheckpoint does not fulfill what it is supposed to do, and it even overwrites the "best" model (val_loss 2.07992) with a worse one (val_loss 2.09097).
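
Until the callbacks handle this themselves, a possible workaround is to persist ModelCheckpoint's internal best value across runs. Below is a minimal sketch; the ResumableModelCheckpoint class and its state_path argument are my own additions, not part of the Keras API.

import json
import os

from tensorflow import keras

class ResumableModelCheckpoint(keras.callbacks.ModelCheckpoint):
    """ModelCheckpoint that writes its `best` value to disk so it survives restarts."""

    def __init__(self, filepath, state_path='/tmp/ckpt_best.json', **kwargs):
        super().__init__(filepath, **kwargs)
        self.state_path = state_path

    def on_train_begin(self, logs=None):
        super().on_train_begin(logs)
        # Restore the best value recorded before the interruption, if any.
        if os.path.exists(self.state_path):
            with open(self.state_path) as f:
                self.best = json.load(f)['best']

    def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch, logs)
        # Persist the current best value so a resumed run starts from it.
        with open(self.state_path, 'w') as f:
            json.dump({'best': float(self.best)}, f)

Using ckpt_cb = ResumableModelCheckpoint('/tmp/best_model', save_best_only=True, monitor='val_loss', verbose=1) in both scripts should make the resumed run start from the previously recorded best instead of inf, so a worse model no longer overwrites the saved one. (Delete the state file together with the backup directory when starting training from scratch.)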

tilakrayal commented 2 years ago

@gowthamkpr, I was able to reproduce the issue on tensorflow v2.8, v2.9 and nightly. Kindly find the gist of it here.

sampathweb commented 2 years ago

@nicolasnn - Thanks for reporting this issue with detailed examples.

I am working on an update to BackupAndRestore callback to address this specific scenario. I will submit a PR in the next few weeks. @rchao - please assign this issue to me.

rchao commented 2 years ago

Thanks Ramesh!

henrypinkard commented 1 year ago

I have run into the same issue. @sampathweb are you still planning to fix this? @rchao

hvgazula commented 6 months ago

Hello, just curious if this issue has been fixed or if it is still in the works?