tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Save the architecture of a Bayesian neural network #516

Open zhulingchen opened 5 years ago

zhulingchen commented 5 years ago

I have read the following issue posts: https://github.com/tensorflow/probability/issues/325 and https://github.com/tensorflow/probability/issues/289.

I know I can save and load the weights of a BNN with model.save_weights and model.load_weights (in practice I used tf.keras.callbacks.ModelCheckpoint to save the weights with the best performance metrics).

However, my goal is to save the architecture of a Bayesian neural network (BNN). What I tried are:

- model.to_json(): saw issue https://github.com/tensorflow/probability/issues/325
- model.to_yaml(): can save but cannot load
- model.get_config(): can save but cannot load

Is there any workaround to save the architecture of a model that uses TensorFlow Probability layers to a file on the disk?

zhulingchen commented 5 years ago

I have tested under tfp-nightly with tf-nightly-gpu packages:

Running model_bak_from_yaml = tf.keras.models.model_from_yaml(model.to_yaml()) yields the following error:

Traceback (most recent call last):
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\IPython\core\interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-15-53399e5e974e>", line 1, in <module>
    model_bak_from_yaml = tf.keras.models.model_from_yaml(model.to_yaml())
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\tensorflow_core\python\keras\engine\network.py", line 1430, in to_yaml
    return yaml.dump(self._updated_config(), **kwargs)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\__init__.py", line 290, in dump
    return dump_all([data], stream, Dumper=Dumper, **kwds)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\__init__.py", line 278, in dump_all
    dumper.represent(data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 27, in represent
    node = self.represent_data(data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 199, in represent_list
    return self.represent_sequence('tag:yaml.org,2002:seq', data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 286, in represent_tuple
    return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 286, in represent_tuple
    return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\yaml\representer.py", line 331, in represent_object
    if function.__name__ == '__newobj__':
AttributeError: 'functools.partial' object has no attribute '__name__'
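The last frame of this traceback can be reproduced without TensorFlow: a functools.partial object does not carry the __name__ of the function it wraps. The sketch below shows that behaviour with the standard library only. The function make_distribution is a hypothetical stand-in, not a TFP function, and it is an assumption that giving the partial a name would be enough for PyYAML to serialize the model config.

```python
import functools


def make_distribution(shape, is_singular=False):
    """Hypothetical stand-in for a callable stored in a layer config."""
    return {"shape": shape, "is_singular": is_singular}


# A partial wraps the function but does not copy its metadata.
partial_fn = functools.partial(make_distribution, is_singular=True)
print(hasattr(make_distribution, "__name__"))
print(hasattr(partial_fn, "__name__"))

# Accessing the attribute fails in the same way as the last traceback frame.
try:
    partial_fn.__name__
except AttributeError as err:
    print("AttributeError:", err)

# functools.update_wrapper copies __name__, __doc__, etc. onto the partial.
named_fn = functools.partial(make_distribution, is_singular=True)
functools.update_wrapper(named_fn, make_distribution)
print(named_fn.__name__)
```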

Running model_bak_from_config = tf.keras.models.model_from_config(model.get_config()) yields the following error:

Traceback (most recent call last):
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\IPython\core\interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-16-6c3782a6529e>", line 1, in <module>
    model_bak_from_config = tf.keras.models.model_from_config(model.get_config())
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\tensorflow_core\python\keras\saving\model_config.py", line 55, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "C:\ProgramData\Anaconda2\envs\nightly\lib\site-packages\tensorflow_core\python\keras\layers\serialization.py", line 90, in deserialize
    layer_class_name = config['class_name']
KeyError: 'class_name'
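The KeyError is consistent with a mismatch in the shape of the dictionary: the failing line reads config['class_name'] from the top level, while model.get_config() in TF 2.x returns only the inner configuration without that wrapper. The toy deserialize function below is a simplified stand-in for illustration, not Keras code, and the dictionary contents are invented. In real code, tf.keras.Model.from_config is the counterpart of get_config(); whether TFP layers then deserialize correctly is a separate question that this sketch does not address.

```python
def deserialize(config):
    """Toy stand-in: like the failing frame, it expects a class_name wrapper."""
    class_name = config["class_name"]
    return class_name, config["config"]


# Roughly the shape of what get_config() returns: no "class_name" key.
inner_config = {"name": "model", "layers": []}

try:
    deserialize(inner_config)
except KeyError as err:
    print("KeyError:", err)

# Roughly the shape that to_json() serializes: the same config, wrapped.
wrapped_config = {"class_name": "Model", "config": inner_config}
class_name, restored = deserialize(wrapped_config)
print(class_name, restored)
```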

Any ideas?

JimAva commented 4 years ago

Was there a solution for this?

Himscipy commented 4 years ago

@zhulingchen Did you find a solution? Also, could you share how you saved the weights of the trained BNN model?

malharjajoo commented 4 years ago

@zhulingchen, when you save the weights of a BNN, does that save a particular sample/instantiation of the network's weights, or does it save them as a distribution?

I would hope that the weight distribution is saved ...

nbro commented 4 years ago

@brianwa84, @csuter, @jburnim, @srvasude, @jvdillon, @davmre, @SiegeLordEx, etc. Any plans to solve this issue?

jvdillon commented 4 years ago

Hi @nbro. Did you have in mind something different than the current serialization capability?

nbro commented 4 years ago

@jvdillon Has this issue https://github.com/tensorflow/probability/issues/325 been solved?

I remember facing a similar issue a few weeks ago. In any case, there are still issues. For example, if you execute the following code, you will get the error ValueError: Unknown layer: Conv2DFlipout (with TF 2.0 or 2.1 and TFP 0.8 or 0.9).

from __future__ import print_function

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

file_path = "my_model.h5"
json_file_path = "my_model.json"

def get_bayesian_model(input_shape, num_classes=10):
    model_input = tf.keras.layers.Input(shape=input_shape)

    x = tfp.layers.Convolution2DFlipout(6, kernel_size=5, padding="SAME", activation=tf.nn.relu)(model_input)

    x = tf.keras.layers.Flatten()(x)

    x = tfp.layers.DenseFlipout(84, activation=tf.nn.relu)(x)
    model_output = tfp.layers.DenseFlipout(num_classes)(x)

    model = tf.keras.Model(model_input, model_output)

    return model

def get_mnist_data(normalize=True, to_binary=True, num_classes=10):
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

    if tf.keras.backend.image_data_format() == 'channels_first':
        x_train = np.expand_dims(x_train, 1)
        x_test = np.expand_dims(x_test, 1)
    else:  # This should be TF ordering.
        x_train = np.expand_dims(x_train, -1)
        x_test = np.expand_dims(x_test, -1)

    input_shape = x_train.shape[1:]

    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")

    y_train = y_train.astype("int32")
    y_test = y_test.astype("int32")

    if normalize:
        x_train /= 255
        x_test /= 255

    if to_binary:
        y_train = tf.keras.utils.to_categorical(y_train, num_classes)
        y_test = tf.keras.utils.to_categorical(y_test, num_classes)
    return x_train, y_train, x_test, y_test, input_shape

def save_bayesian_model():
    x_train, y_train, x_test, y_test, input_shape = get_mnist_data()

    model = get_bayesian_model(input_shape)

    model_json = model.to_json()
    with open(json_file_path, "w") as json_file:
        json_file.write(model_json)

    # model.save(file_path)

def load_bayesian_model():
    with open(json_file_path, 'r') as json_file:
        loaded_model_json = json_file.read()
        tf.keras.models.model_from_json(loaded_model_json)

    # model = tf.keras.models.load_model(file_path)

if __name__ == '__main__':
    save_bayesian_model()
    load_bayesian_model()

A similar issue occurs if you use Keras' save method (just uncomment the commented lines). See also my comments https://github.com/tensorflow/probability/issues/325#issuecomment-574922374 and https://github.com/tensorflow/probability/issues/325#issuecomment-576867703.

joaocaldeira commented 4 years ago

On TF 2.0 and TFP 0.8, I get a similar ValueError: Unknown layer: DenseFlipout error when loading a Keras model saved to an HDF5 file using model.save. It is not ideal that no error is raised when saving, because there is no way to recover the model if you no longer have it in memory.

If I instead save it using tfk.models.save_model, I get ValueError: Unknown loss function:<lambda> when loading. Again, there is no error when saving. My loss function is negloglik = lambda y, rv_y: -rv_y.log_prob(y).

Is there a way to save networks with TFP layers that currently works?
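The "Unknown loss function:<lambda>" message matches how Python names lambdas: every lambda has the __name__ "<lambda>", so a loss recorded by name cannot be looked up again at load time. A minimal sketch, assuming the loss is rewritten as a named function; FakeRV is a hypothetical stand-in for a TFP distribution, and the custom_objects dictionary only mirrors the name-to-callable mapping that tf.keras.models.load_model accepts through its custom_objects argument.

```python
def negloglik(y, rv_y):
    """Named equivalent of the lambda loss quoted above."""
    return -rv_y.log_prob(y)


negloglik_lambda = lambda y, rv_y: -rv_y.log_prob(y)

# The lambda has no usable name; the named function does.
print(negloglik_lambda.__name__)
print(negloglik.__name__)

# Name-to-callable mapping of the kind passed as custom_objects when loading.
custom_objects = {negloglik.__name__: negloglik}


class FakeRV:
    """Hypothetical stand-in for a distribution with a log_prob method."""

    def log_prob(self, y):
        return -0.5 * y


# Both definitions compute the same value.
print(negloglik(3.0, FakeRV()), negloglik_lambda(3.0, FakeRV()))
```

If the loss is not needed after loading (for example, for inference only), passing compile=False to tf.keras.models.load_model skips loss deserialization altogether. Neither option addresses the separate "Unknown layer" errors for TFP layers.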

nbro commented 4 years ago

@jvdillon, @brianwa84 In the meantime, is there a workaround to save and load the whole model (architecture, optimizer state, weights, etc.) using model.save and tf.keras.models.load_model, respectively (or any other method)? What's your suggestion?

I was trying to specify the TFP layers as custom objects when loading the model (see https://github.com/keras-team/keras/issues/4871):

import tensorflow as tf
import tensorflow_probability as tfp
tf.keras.models.load_model(model_file_path,
                           custom_objects={"Conv2DFlipout": tfp.layers.Convolution2DFlipout,
                                           "DenseFlipout": tfp.layers.DenseFlipout})

Although I no longer get the errors ValueError: Unknown layer: Conv2DFlipout or ValueError: Unknown layer: DenseFlipout, I get another error while deserializing: TypeError: 'str' object is not callable. You can reproduce it with the example above: just modify it to pass the custom objects.

(I am using TF 2.1 and TFP 0.9).

chrissype commented 4 years ago

@nbro Please do. This also feels like it deserves a new issue, given that the parent post describes a simple missing custom_objects error and this bug is completely separate.

nbro commented 4 years ago

@chrissype Well, this issue is related to the way TFP models can be saved. Right now, you cannot save and load the whole TFP model without getting errors. I may create a more specific issue.

YutianPangASU commented 3 years ago

I am having a similar issue: TypeError: 'str' object is not callable when passing the custom layers. Is there a workaround for this?

cdguarnizo commented 3 years ago

Hi, I'm getting the same issue. Has anyone found a solution to this problem?

waldnerf commented 3 years ago

This post worked for me

lixuanze commented 3 years ago

I still have similar issues

lixuanze commented 3 years ago

Does anyone know how to solve TypeError: 'str' object is not callable?