tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Saving and Loading Bayesian Neural Network #289

Open gioCanelita opened 5 years ago

gioCanelita commented 5 years ago

I've been looking at the TensorFlow Probability library and trying to modify the Bayesian neural network example so that I can save checkpoints and restore them later.

I first tried tf.train.Checkpoint. I got no errors when saving or restoring, but training did not seem to resume from the previous checkpoint: the accuracy after restoring was completely different. I then tried tf.keras.models.Model.save, which does write a file, but restoring fails while deserialising the layer with: ValueError: Unknown layer: Conv2DFlipout.

I then added Conv2DFlipout and DenseFlipout as custom objects, and now I get the following error. I am a bit lost :( on what to do next.

  File "C:/Users/ezzgm/Documents/BayesianNeural Network/bayesian_NN_vKerasSaver.py", line 188, in main
    model = tf.keras.models.load_model(FLAGS.model_dir+'checkpoint.hdf5')
  File "C:\Users\ezzgm\AppData\Local\conda\conda\envs\tensorflow1.12\lib\site-packages\tensorflow\python\keras\engine\saving.py", line 230, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "C:\Users\ezzgm\AppData\Local\conda\conda\envs\tensorflow1.12\lib\site-packages\tensorflow\python\keras\engine\saving.py", line 310, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "C:\Users\ezzgm\AppData\Local\conda\conda\envs\tensorflow1.12\lib\site-packages\tensorflow\python\keras\layers\serialization.py", line 64, in deserialize
    printable_module_name='layer')
  File "C:\Users\ezzgm\AppData\Local\conda\conda\envs\tensorflow1.12\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 173, in deserialize_keras_object
    list(custom_objects.items())))
  File "C:\Users\ezzgm\AppData\Local\conda\conda\envs\tensorflow1.12\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1302, in from_config
    process_node(layer, node_data)
  File "C:\Users\ezzgm\AppData\Local\conda\conda\envs\tensorflow1.12\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1260, in process_node
    layer(input_tensors[0], **kwargs)
  File "C:\Users\ezzgm\AppData\Local\conda\conda\envs\tensorflow1.12\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 746, in __call__
    self.build(input_shapes)
  File "C:\Users\ezzgm\AppData\Local\conda\conda\envs\tensorflow1.12\lib\site-packages\tensorflow_probability\python\layers\conv_variational.py", line 188, in build
    self.trainable, self.add_variable)
  File "C:/Users/ezzgm/AppData/Local/conda/conda/envs/tensorflow1.12/lib/site-packages/tensorflow_probability/python/layers/util.py", line 186, in _fn
    loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)
TypeError: 'str' object is not callable

This is the modified code of the Bayesian neural network example. (For debugging purposes I save within the first 10 steps, then rerun the script to try to load the saved model.) Any help would be very much appreciated!

```python
# Excerpt from main(); assumes os, tensorflow as tf, tensorflow_probability as tfp,
# and tfd = tfp.distributions are already imported, as in the original example.
if FLAGS.architecture == "resnet":
    model_fn = bayesian_resnet.bayesian_resnet
else:
    model_fn = bayesian_vgg.bayesian_vgg

model = model_fn(
    IMAGE_SHAPE,
    num_classes=4,
    kernel_posterior_scale_mean=FLAGS.kernel_posterior_scale_mean,
    kernel_posterior_scale_constraint=FLAGS.kernel_posterior_scale_constraint)

# Check whether a saved checkpoint exists.
checkpoint_path = FLAGS.model_dir + "checkpoint.hdf5"
if os.path.isfile(checkpoint_path):
    # Register the TFP layers so that Keras can deserialise them.
    with tf.keras.utils.CustomObjectScope({
            "DenseFlipout": tfp.layers.DenseFlipout,
            "Conv2DFlipout": tfp.layers.Convolution2DFlipout}):
        # This is the call that raises the TypeError shown above.
        model = tf.keras.models.load_model(checkpoint_path)

logits = model(images)
labels_distribution = tfd.Categorical(logits=logits)
```

Himscipy commented 4 years ago

Hi @gioCanelita, were you able to find a fix for this problem? I am facing a similar issue. Any input would be really appreciated.

JonvoWoo commented 4 years ago

I had a similar problem, but following this comment fixed it: https://github.com/tensorflow/probability/issues/325#issuecomment-477213850. Before loading the model, I create the same model in code, then call model.compile() and model.build(); after that, I loaded the model successfully.