tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0
4.27k stars 1.11k forks

Sequential Model Save/Load Problems #755

Open piotrlaczkowski opened 4 years ago

piotrlaczkowski commented 4 years ago

I have a TensorFlow 2.x model which uses the TF preprocessing layer (tf.keras.layers.DenseFeatures) and the distributional layer from TF Probability (DistributionLambda):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def regression_deep1_proba2(preprocessing_layer, feature_layer_inputs, model_name='test_model'):

    model = tf.keras.Sequential([
        preprocessing_layer,
        tf.keras.layers.Dense(100, activation='relu', name='hidden_1'),
        tf.keras.layers.Dense(50, activation='relu', name='hidden_2'),
        tf.keras.layers.Dense(1 + 1, name='output'),
        tfp.layers.DistributionLambda(
            lambda t: tfd.LogNormal(loc=t[..., :1], scale=tf.math.softplus(0.05 * t[..., 1:]))
        ),
    ])

    # ____________________ COMPILE WITH  ____________________________________________
    optimizer = tf.keras.optimizers.Adam()
    negloglik = lambda y, p_y: -p_y.log_prob(y)

    metrics = [
        tf.keras.metrics.MeanAbsolutePercentageError()
        ]

    model.compile(
        loss=negloglik,
        optimizer=optimizer,
        metrics=metrics
    )

    # ____________________ CALLBACKS DEFINITION ___________________________________________
    tbCallBack = tf.keras.callbacks.TensorBoard(
        log_dir=f'./logs_regression/{model_name}',
        update_freq='batch',
        histogram_freq=1,
        embeddings_freq=1,
        write_graph=True,
        write_images=True
    )

    # Create a callback that saves the best weights seen so far
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=f'./weights.{model_name}.hdf5',
        verbose=1,
        save_weights_only=True,
        save_best_only=True,
        monitor='loss'  # must name a compiled loss or metric
    )
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='loss',
        patience=2
    )
    callbacks_list = [tbCallBack, cp_callback, early_stop]

    return model, callbacks_list
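The negloglik lambda above evaluates -p_y.log_prob(y) on the LogNormal distribution emitted by the DistributionLambda layer. As a plain-stdlib sanity check (a sketch of the same density in closed form, not TFP's implementation), the negative log-likelihood of LogNormal(loc=mu, scale=sigma) at y is:

```python
import math

def lognormal_nll(y, mu, sigma):
    """Negative log-density of LogNormal(loc=mu, scale=sigma) at y > 0."""
    return (math.log(y) + math.log(sigma) + 0.5 * math.log(2.0 * math.pi)
            + (math.log(y) - mu) ** 2 / (2.0 * sigma ** 2))

# At y=1, mu=0, sigma=1 every data-dependent term vanishes,
# leaving 0.5 * log(2*pi) ~= 0.9189.
print(lognormal_nll(1.0, 0.0, 1.0))
```

Minimizing this quantity over (mu, sigma) is exactly what compiling with loss=negloglik asks the optimizer to do.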

I can get some nice results for the regression problem with this model, but when I save it for further use I can't load it back anymore. I have tried every online tutorial and solution I could find, but nothing works!

I can save it to a file (h5, tf, json, etc.), e.g.:

tf.keras.models.save_model(model, 'model_name.h5')

but when loading I get:

ValueError: ('We expected a dictionary here. Instead we got: ', <tf.Tensor 'Placeholder:0' shape=(None,) dtype=float32>)

I can't figure out what I am doing wrong - any help would be appreciated!

Also, I have tried all possible save extensions and backends (h5, tf, json, plain weights, and other formats), but none of them works. I have even tried it on different systems (Mac, Ubuntu) and on different TensorFlow versions (2.0 and 2.1).

Of course, all the saving and loading works great for other models I use without the TF Probability layer (even the ones with a DenseFeatures layer).
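One workaround that sidesteps serialising the DistributionLambda layer entirely is to persist only model.get_weights(), then rebuild the architecture in code and call model.set_weights() on load. A minimal, self-contained sketch of that round trip (using pickle and plain lists standing in for the real weight arrays, since the model itself is not reproducible here):

```python
import os
import pickle
import tempfile

# Stand-ins for model.get_weights(): a list of per-layer arrays.
weights = [[[0.1, 0.2], [0.3, 0.4]], [0.0, 0.0]]

path = os.path.join(tempfile.mkdtemp(), "model_weights.pkl")

# Save side: in the real model this would be pickle.dump(model.get_weights(), f)
with open(path, "wb") as f:
    pickle.dump(weights, f)

# Load side: rebuild the model (e.g. via regression_deep1_proba2(...)) first,
# then hand the restored list to model.set_weights(restored).
with open(path, "rb") as f:
    restored = pickle.load(f)

print(restored == weights)  # the round trip is lossless
```

The price is that the architecture must be reconstructed from source on every load, but nothing in the distribution layer has to survive serialisation.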

harsh306 commented 4 years ago

I guess this is the same issue, but it is also not resolved: https://github.com/tensorflow/tensorflow/issues/31927

lixuanze commented 3 years ago

The issue is not resolved on my side either. Can someone from the TensorFlow team take a look and help with this?

piotrlaczkowski commented 3 years ago

I have managed to understand several serialisation issues in TF more deeply.

It seems that everything related to preprocessing layers, custom losses, or custom functions does not get serialised well - or rather, is not rendered correctly by AutoGraph, the tool that handles exporting the correct names into the serialised graph/files.

These issues were finally solved in TF 2.4 but magically reappeared in TF 2.5, so the only version in which I can actually use this kind of model is TF 2.4.

If you are interested in more details about this kind of model or problems with related architectures, do not hesitate to contact me directly. (If enough people are interested, I can publish a blog post about my deep dive into this issue and all the solutions I have implemented.)

Have fun!

JohnTaylor2000 commented 2 years ago

As indicated previously, this continues to be a problem with saving TensorFlow Probability layers.

I was not able to read in a saved model, so I switched to building the model and reading the weights instead. This worked for a model with a tfp.layers.DistributionLambda layer. However, I now have a model with a tfp.layers.DenseVariational layer, and reading the weights also fails. I am wondering if this has been fixed or whether there is a workaround.

I am using:

tensorflow 2.7.0
tensorflow-estimator 2.7.0
tensorflow-probability 0.15.0
h5py 3.6.0

Traceback (most recent call last):
  File "/DOWNSCALE_project/test_tfd_infer.py", line 210, in
    unet.load_weights('/DOWNSCALE_project/unet1_downscale_ERA5_GPU_save_weights.h5')
  File "/miniconda3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "/miniconda3/lib/python3.9/site-packages/h5py/_hl/group.py", line 305, in __getitem__
    oid = h5o.open(self.id, self._e(name), lapl=self._lapl)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5o.pyx", line 190, in h5py.h5o.open
KeyError: "Unable to open object (object 'input_1' doesn't exist)"
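The KeyError above is h5py failing to find a group named after a layer ('input_1'): HDF5 weight files are keyed by layer name, so a rebuilt model whose auto-generated names have drifted (e.g. 'input_2' after re-running a cell in the same session) can no longer be matched against the file. One way around this is to give every layer an explicit name= and persist weights keyed by those names. A stdlib-only sketch of the name-keyed lookup (hypothetical layer names, plain lists standing in for the real weight arrays):

```python
import json

# Save side: in the real model this would be
# {layer.name: layer.get_weights() for layer in model.layers}
saved = {
    "hidden_1": [[[0.1, 0.2]], [0.0, 0.0]],
    "output": [[[0.5]], [0.0]],
}

# Serialise and restore (stand-in for writing/reading the weights file).
blob = json.dumps(saved)
restored = json.loads(blob)

# Load side: look layers up by the names we chose, never by creation order,
# so a drifting auto-generated name ('input_1' vs 'input_2') cannot break it.
for name in ("hidden_1", "output"):
    assert name in restored
```

For HDF5 files, Keras's own model.load_weights(path, by_name=True) applies the same idea, matching weights to layers by name rather than by position.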

Bornesf commented 1 year ago

@JohnTaylor2000 Hello, I have the same problem. Could I have any tips, please?

piotrlaczkowski commented 1 year ago

The issue was fixed in TF versions > 2.6 (if I remember correctly), so if you use the latest version this same model works! ;) It was related to an internal package error that propagated through several versions and was finally fixed in 2022.

Hope this helps

Bornesf commented 1 year ago

Thanks, that really works!