tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

Loading model saved with CyclicalLearningRate fails "TypeError: argument "scale_fn" must be a callable" #2728

Closed: lambda-science closed this issue 2 years ago

lambda-science commented 2 years ago

System information

Describe the bug

When a model compiled with a cyclical learning rate schedule (tfa.optimizers.CyclicalLearningRate) is saved in .h5 format, loading it back fails with the error: TypeError: argument "scale_fn" must be a callable

Code to reproduce the issue

Train a model with a cyclical learning rate, then try to load the saved model afterward.

import pickle

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import callbacks, models

# Train the real model with a learning rate cycling between MIN_LR and MAX_LR
MIN_LR = 0.00001
MAX_LR = 0.00003

steps_per_epoch = len(train_images) // 32  # Batch size is 32
clr = tfa.optimizers.CyclicalLearningRate(
    initial_learning_rate=MIN_LR,
    maximal_learning_rate=MAX_LR,
    scale_fn=lambda x: 1.0,  # Triangular scaling method
    # scale_fn=lambda x: 1 / (2.0 ** (x - 1)),  # Triangular2 scaling method
    step_size=4 * steps_per_epoch,
)

checkpoint_cb = callbacks.ModelCheckpoint(BASE_FOLDER+"results/"+MODEL_NAME+"_model.h5", save_best_only=True)
early_stopping_cb = callbacks.EarlyStopping(patience=10, restore_best_weights=True)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=clr),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(train_images, train_labels, epochs=500, batch_size=32,
                    validation_data=(val_images, val_labels), shuffle=True, class_weight=class_weights, 
                    callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb])

model = models.load_model(BASE_FOLDER+"results/"+MODEL_NAME+"_model.h5")

with open(BASE_FOLDER+"results/"+MODEL_NAME+"_history.pickle", 'wb') as file_pi:
    pickle.dump(history.history, file_pi)
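
For reference, the schedule configured above can be sketched in plain Python. This is a minimal illustration of the triangular cyclical policy, assuming the scale function is applied once per cycle (which I believe is the tfa default); the function name cyclical_lr and the 118 steps per epoch (taken from the progress bar in the log below) are only for illustration, not part of tfa.

```python
import math

MIN_LR = 0.00001
MAX_LR = 0.00003


def cyclical_lr(step, step_size, initial_lr=MIN_LR, maximal_lr=MAX_LR,
                scale_fn=lambda cycle: 1.0):
    """Triangular cyclical learning rate at a given optimizer step (sketch)."""
    # Index of the current cycle, starting at 1; one cycle is 2 * step_size steps.
    cycle = math.floor(1 + step / (2 * step_size))
    # Distance from the peak of the current cycle, in units of step_size.
    x = abs(step / step_size - 2 * cycle + 1)
    # Linear ramp between the two bounds, scaled per cycle by scale_fn.
    return initial_lr + (maximal_lr - initial_lr) * max(0.0, 1 - x) * scale_fn(cycle)


# Assumption: 118 steps per epoch, as shown by the progress bar in the log.
step_size = 4 * 118

for step in (0, step_size // 2, step_size, 2 * step_size):
    print(step, cyclical_lr(step, step_size))
```

With scale_fn returning 1.0 the rate climbs from MIN_LR to MAX_LR over step_size steps and then falls back over the next step_size steps, so one full cycle spans 8 epochs here.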

Other info / logs

Traceback:

Epoch 231/500
118/118 [==============================] - 11s 90ms/step - loss: 0.4837 - accuracy: 0.8937 - val_loss: 0.6672 - val_accuracy: 0.8831

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-13-0199af2caf8d> in <module>()
     22                     callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb])
     23 
---> 24 model = models.load_model(BASE_FOLDER+"results/"+MODEL_NAME+"_model.h5")
     25 
     26 with open(BASE_FOLDER+"results/"+MODEL_NAME+"_history.pickle", 'wb') as file_pi:

5 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in wrapper(*args, **kwargs)
    807     def wrapper(*args, **kwargs):
    808         memo = _CallMemo(python_func, _localns, args=args, kwargs=kwargs)
--> 809         check_argument_types(memo)
    810         retval = func(*args, **kwargs)
    811         check_return_type(retval, memo)

/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_argument_types(memo)
    668                 check_type(description, value, expected_type, memo)
    669             except TypeError as exc:  # suppress unnecessarily long tracebacks
--> 670                 raise exc from None
    671 
    672     return True

/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_argument_types(memo)
    666             description = 'argument "{}"'.format(argname)
    667             try:
--> 668                 check_type(description, value, expected_type, memo)
    669             except TypeError as exc:  # suppress unnecessarily long tracebacks
    670                 raise exc from None

/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_type(argname, value, expected_type, memo)
    564         checker_func = origin_type_checkers.get(origin_type)
    565         if checker_func:
--> 566             checker_func(argname, value, expected_type, memo)
    567         else:
    568             check_type(argname, value, origin_type, memo)

/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_callable(argname, value, expected_type, memo)
    229 def check_callable(argname: str, value, expected_type, memo: Optional[_CallMemo]) -> None:
    230     if not callable(value):
--> 231         raise TypeError('{} must be a callable'.format(argname))
    232 
    233     if expected_type.__args__:

TypeError: argument "scale_fn" must be a callable
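
A likely explanation, which is my assumption and is not confirmed anywhere in this report: the optimizer config stored in the .h5 file is JSON, so a lambda cannot be stored as code and only its name survives. On load, the schedule constructor would then receive a string instead of a function, which is exactly what the typeguard check rejects. The sketch below reproduces that effect with the standard library only; to_json_type is a hypothetical stand-in for a serializer fallback, not Keras code.

```python
import json


def to_json_type(obj):
    """Hypothetical serializer fallback: store a callable by its name only."""
    if callable(obj):
        return obj.__name__
    raise TypeError(f"not JSON serializable: {type(obj).__name__}")


# A config shaped like the one in this report (step_size value is illustrative).
config = {"scale_fn": lambda x: 1.0, "step_size": 472}

# Round-trip through JSON, as a saved training config would be.
restored = json.loads(json.dumps(config, default=to_json_type))

print(restored["scale_fn"])            # <lambda>
print(callable(restored["scale_fn"]))  # False
```

After the round trip, scale_fn is the string "<lambda>", so any check equivalent to callable(value) fails.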

lambda-science commented 2 years ago

Never mind, this is a duplicate of https://github.com/tensorflow/addons/issues/2380. I'll close this for now and test the solution proposed there.
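
For anyone landing here, one general way to make a callable survive a JSON round trip is to store a name and resolve it through a registry when loading. The sketch below shows only that pattern in plain Python; it is not the fix from the linked issue, and SCALE_FNS, register_scale_fn, save_config and load_config are hypothetical names.

```python
import json

# Hypothetical registry mapping a stable name to a scale function.
SCALE_FNS = {}


def register_scale_fn(fn):
    """Register a named function so it can be looked up again at load time."""
    SCALE_FNS[fn.__name__] = fn
    return fn


@register_scale_fn
def triangular(x):
    return 1.0


@register_scale_fn
def triangular2(x):
    return 1 / (2.0 ** (x - 1))


def save_config(config):
    """Serialize the config, replacing the function with its registered name."""
    return json.dumps(dict(config, scale_fn=config["scale_fn"].__name__))


def load_config(text):
    """Deserialize the config and resolve the name back to the function."""
    config = json.loads(text)
    config["scale_fn"] = SCALE_FNS[config["scale_fn"]]
    return config


restored = load_config(save_config({"scale_fn": triangular2, "step_size": 472}))
print(callable(restored["scale_fn"]))  # True
```

Named functions are used instead of lambdas because every lambda has the same name, "<lambda>", which cannot identify it. Separately, if only the trained weights are needed for inference, calling load_model with compile=False should skip restoring the optimizer and its schedule, although I have not verified that against this exact setup.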