secondmind-labs / GPflux

Deep GPs built on top of TensorFlow/Keras and GPflow
https://secondmind-labs.github.io/GPflux/
Apache License 2.0
120 stars 24 forks source link

ModelCheckpoint with save_weights_only=False crashes #39

Open izsahara opened 3 years ago

izsahara commented 3 years ago

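Describe the bug: Literally the title — training with a `tf.keras.callbacks.ModelCheckpoint` callback configured with `save_weights_only=False` crashes when the callback tries to save the model.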

import numpy as np
import tensorflow as tf
from gpflux.helpers import (
    construct_basic_inducing_variables,
    construct_basic_kernel,
    construct_mean_function,
    construct_gp_layer
)
from gpflux.models import DeepGP
from gpflux.layers import LikelihoodLayer
from gpflow.kernels import SquaredExponential
from gpflow.likelihoods import Gaussian
from gpflow import set_trainable

def xiong_1d(XX: np.ndarray):
    """Evaluate the 1-D test function used to generate the example data.

    Computes, elementwise,
    ``-0.5 * (sin(40 * (x - 0.85)**4) * cos(2.5 * (x - 0.95)) + 0.5 * (x - 0.9) + 1)``.

    :param XX: Input locations; the result has the same shape as ``XX``.
    :return: Function values at ``XX``.
    """
    oscillation = np.sin(40 * (XX - 0.85) ** 4) * np.cos(2.5 * (XX - 0.95))
    linear_trend = 0.5 * (XX - 0.9)
    return -0.5 * (oscillation + linear_trend + 1)

# 1000-point grid on [0, 1] with its function values (not used by the
# training call below), and a 30-point training set on the same interval.
X_valid = np.linspace(0, 1, 1000).reshape(-1, 1)
Y_valid = xiong_1d(X_valid).reshape(-1, 1)
X = np.linspace(0, 1, 30).reshape(-1, 1)
Y = xiong_1d(X).reshape(-1, 1)

# -------------- DGP MODEL -------------- #

# One inducing point per training point.
n_inducing = int(len(X))

layer1 = construct_gp_layer(
    num_data=X.shape[0],
    num_inducing=n_inducing,
    input_dim=X.shape[1],
    output_dim=X.shape[1],
    kernel_class=SquaredExponential,
)
layer2 = construct_gp_layer(
    num_data=X.shape[0],
    num_inducing=n_inducing,
    input_dim=X.shape[1],
    output_dim=Y.shape[1],
    kernel_class=SquaredExponential,
)
gp_layers = [layer1, layer2]

# Gaussian likelihood with its variance held fixed at 1e-5 (not trained).
likelihood = Gaussian(variance=1e-5)
set_trainable(likelihood.variance, False)

dgp_model = DeepGP(f_layers=gp_layers, likelihood=LikelihoodLayer(likelihood))
train_mode = dgp_model.as_training_model()
train_mode.compile(tf.optimizers.Adam(0.01))

# File path of script "E:/23620029-Faiz/.PROJECTS/AdaptiveDGP/demo"
checkpoint_filepath = "E:/23620029-Faiz/.PROJECTS/AdaptiveDGP/checkpoint"

# NOTE: save_weights_only=False is deliberate here: it is the setting that
# triggers the AssertionError reported in this issue (the whole model is
# saved instead of just its weights). Setting save_weights_only=True avoids
# the crash.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,
    monitor='loss',
    # The monitored quantity is a loss, so the "best" model is the one with
    # the lowest value; 'max' would keep the worst checkpoint instead.
    mode='min',
    save_best_only=True,
)

train_mode.fit(
    {"inputs": X, "targets": Y},
    epochs=5000,
    verbose=1,
    callbacks=[model_checkpoint_callback],
)

The Error/Console Output

Epoch 1/5000
1/1 [==============================] - 4s 4s/step - loss: 59876.1801 - gp_layer_prior_kl: 0.0000e+00 - gp_layer_1_prior_kl: 0.0000e+00
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "E:\23620029-Faiz\PyCharm\PyCharm Community Edition 2021.1.1\plugins\python-ce\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "E:\23620029-Faiz\PyCharm\PyCharm Community Edition 2021.1.1\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:/23620029-Faiz/.PROJECTS/AdaptiveDGP/demo/test_gpflux.py", line 181, in <module>
    train_mode.fit({"inputs": X, "targets": Y}, epochs=5000, verbose=1,  callbacks=[model_checkpoint_callback])
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1145, in fit
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\callbacks.py", line 428, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\callbacks.py", line 1344, in on_epoch_end
    self._save_model(epoch=epoch, logs=logs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\callbacks.py", line 1396, in _save_model
    self.model.save(filepath, overwrite=True, options=self._options)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 2001, in save
    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\save.py", line 156, in save_model
    saved_model_save.save(model, filepath, overwrite, include_optimizer,
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save.py", line 89, in save
    save_lib.save(model, filepath, signatures, options)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\saved_model\save.py", line 1032, in save
    _, exported_graph, object_saver, asset_info = _build_meta_graph(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\saved_model\save.py", line 1198, in _build_meta_graph
    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\saved_model\save.py", line 1132, in _build_meta_graph_impl
    signatures = signature_serialization.find_function_to_export(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\saved_model\signature_serialization.py", line 75, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\saved_model\save.py", line 150, in list_functions
    obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 2612, in _list_functions_for_serialization
    functions = super(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 3086, in _list_functions_for_serialization
    return (self._trackable_saved_model_saver
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\base_serialization.py", line 94, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\layer_serialization.py", line 78, in functions_to_serialize
    return (self._get_serialized_attributes(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\layer_serialization.py", line 94, in _get_serialized_attributes
    object_dict, function_dict = self._get_serialized_attributes_internal(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\model_serialization.py", line 56, in _get_serialized_attributes_internal
    super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\layer_serialization.py", line 104, in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 163, in wrap_layer_functions
    call_fn_with_losses = call_collection.add_function(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 505, in add_function
    self.add_trace(*self._input_signature)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 420, in add_trace
    trace_with_training(True)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 418, in trace_with_training
    fn.get_concrete_function(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 550, in get_concrete_function
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 1299, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 1205, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 527, in wrapper
    ret = method(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 169, in wrap_with_training_arg
    return control_flow_util.smart_cond(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\utils\control_flow_util.py", line 114, in smart_cond
    return smart_module.smart_cond(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\framework\smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 170, in <lambda>
    training, lambda: replace_training_and_call(True),
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 167, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 570, in call_and_return_conditional_losses
    call_output = layer_call(inputs, *args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\engine\functional.py", line 424, in call
    return self._run_internal_graph(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\engine\functional.py", line 560, in _run_internal_graph
    outputs = node.layer(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1012, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 73, in return_outputs_and_add_losses
    outputs, losses = fn(inputs, *args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 169, in wrap_with_training_arg
    return control_flow_util.smart_cond(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\utils\control_flow_util.py", line 114, in smart_cond
    return smart_module.smart_cond(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\framework\smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 170, in <lambda>
    training, lambda: replace_training_and_call(True),
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 167, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 544, in __call__
    self.call_collection.add_trace(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 420, in add_trace
    trace_with_training(True)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 418, in trace_with_training
    fn.get_concrete_function(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 550, in get_concrete_function
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 1299, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 1205, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 527, in wrapper
    ret = method(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 169, in wrap_with_training_arg
    return control_flow_util.smart_cond(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\utils\control_flow_util.py", line 114, in smart_cond
    return smart_module.smart_cond(
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\framework\smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 170, in <lambda>
    training, lambda: replace_training_and_call(True),
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\utils.py", line 167, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\tensorflow\python\keras\saving\saved_model\save_impl.py", line 570, in call_and_return_conditional_losses
    call_output = layer_call(inputs, *args, **kwargs)
  File "E:\23620029-Faiz\.PROJECTS\AdaptiveDGP\venv\lib\site-packages\gpflux\layers\likelihood_layer.py", line 78, in call
    assert isinstance(inputs, tfp.distributions.MultivariateNormalDiag)
AssertionError

System information

st-- commented 3 years ago

The error is triggered when save_weights_only=False. For some reason this ends up calling the likelihood layer with a Tensor rather than a MultivariateNormalDiag...

vdutor commented 3 years ago

Thanks for raising the issue. This is a known bug: it is currently not possible to persist GPflux models, which rely on tfp.layers.DistributionLambda layers, beyond just storing the weights with save_weights_only=True. I will leave this issue open for better visibility.

mo-alaa commented 2 years ago

Is there any solution yet?