Saving of BatchNormalization layer fails #30811

Closed olesalscheider closed 5 years ago

olesalscheider commented 5 years ago

System information

Describe the current behavior

When I try to save a BatchNormalization layer as in the example code it fails with the following error:

Traceback (most recent call last):
  File "test_bn.py", line 31, in <module>
    tf.saved_model.save(infer, saved_model_dir, signature_dict)
  File "/home/salscheider/tf2/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py", line 840, in save
    meta_graph_def, saveable_view, signatures)
  File "/home/salscheider/tf2/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py", line 536, in _fill_meta_graph_def
    object_map, resource_map, asset_info = saveable_view.map_resources()
  File "/home/salscheider/tf2/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py", line 270, in map_resources
    "supported.").format(concrete_function.name, capture))
ValueError: Attempted to save a function b'__inference_batch_normalization_layer_call_and_return_conditional_losses_414' which references a symbolic Tensor Tensor("batch_normalization_trainable:0", dtype=bool) that is not a simple constant. This is not supported.

Describe the expected behavior

Saving succeeds without error.

Code to reproduce the issue

The following test case can be used to reproduce the issue:

import tensorflow as tf

class Outer(tf.keras.Model):
    def __init__(self):
        super().__init__()

        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, x, train_bn=False):
        return self.bn(x, training=train_bn)

class Infer(tf.Module):
    def __init__(self):
        super().__init__()

        # Decorate the inference function with tf.function
        self.infer_ = tf.function(self.infer, input_signature=[
             tf.TensorSpec([1, 64, 64, 8], tf.float32, 'prev_img')])

        self.outer = Outer()

    def infer(self, input):
        return self.outer(input, train_bn=False)

# Create model
infer = Infer()

# Save the trained model
signature_dict = {'infer': infer.infer_}
saved_model_dir = '/tmp/saved_model'
tf.saved_model.save(infer, saved_model_dir, signature_dict)

gadagashwini-zz commented 5 years ago

@olesalscheider I tried running the code on Colab with TensorFlow 1.14.0 and did not get any error. Please take a look at the Colab gist. Thanks!

olesalscheider commented 5 years ago

Oh, I forgot to mention: the code above used to work with older versions of TensorFlow (I think including 1.14.0). This is a regression in the current master branch (with the TF2 API).

gadagashwini-zz commented 5 years ago

I am able to reproduce the issue with TensorFlow 2.0.0-beta1. Please take a look at the gist here. Thanks!

tgs266 commented 5 years ago

I am getting the exact same error. Does anybody have a way to fix this? I am using the nightly previews.

k-w-w commented 5 years ago

I believe this should be fixed in TF 2.0 RC 0. Closing this, but if you are still having issues, please reopen this issue.
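
For anyone trying to tell whether an installed build falls in the range reported in this thread (failing on 2.0.0-beta1 and the nightlies, believed fixed as of 2.0 RC 0, reportedly fine on 1.14.0), here is a minimal sketch of a version check. The helper names `parse_tf_version` and `likely_affected` are hypothetical and not part of TensorFlow. The range itself is an assumption drawn only from the comments above: it treats every 2.0.0 pre-release before rc0 as affected, including alpha and dev builds, which nobody in this thread confirmed either way.

```python
import re

# Ordering of pre-release stages; a final release sorts after every pre-release.
_STAGE_ORDER = {"dev": 0, "a": 1, "alpha": 1, "b": 2, "beta": 2, "rc": 3}
_FINAL_STAGE = 4

# Accepts forms such as "1.14.0", "2.0.0-beta1", "2.0.0.beta1" and "2.0.0rc0".
_VERSION_RE = re.compile(r"(\d+)\.(\d+)\.(\d+)[-.]?([a-z]+)?(\d+)?")


def parse_tf_version(version):
    """Convert a version string into a tuple that compares in release order."""
    match = _VERSION_RE.match(version.strip().lower())
    if match is None:
        raise ValueError("Unrecognized version string: %r" % (version,))
    major, minor, patch, tag, tag_num = match.groups()
    if tag is None:
        stage, stage_num = _FINAL_STAGE, 0
    elif tag in _STAGE_ORDER:
        stage, stage_num = _STAGE_ORDER[tag], int(tag_num or 0)
    else:
        raise ValueError("Unknown pre-release tag in %r" % (version,))
    return (int(major), int(minor), int(patch), stage, stage_num)


def likely_affected(version):
    """True for 2.0.0 pre-releases earlier than rc0 (assumed range, see above)."""
    parsed = parse_tf_version(version)
    lower = parse_tf_version("2.0.0-dev0")
    upper = parse_tf_version("2.0.0-rc0")
    return lower <= parsed < upper


if __name__ == "__main__":
    for v in ("1.14.0", "2.0.0-beta1", "2.0.0-rc0"):
        print(v, likely_affected(v))
```

In practice you would pass `tf.__version__` to `likely_affected`. A nightly built after the fix landed may already work even though this check flags it, so rerunning the test case above is the more reliable confirmation.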
