GeorgeSeif / Semantic-Segmentation-Suite

Semantic Segmentation Suite in TensorFlow. Implement, train, and test new Semantic Segmentation models easily!

Potential Batch Norm Architecture Bug for Encoder-Decoder #181

Open k22jung opened 5 years ago

k22jung commented 5 years ago


I am looking at your implementation of Encoder-Decoder, and it seems that Batch Norm is applied before the ReLU, which matches what the Batch Norm paper (https://arxiv.org/pdf/1502.03167.pdf) describes:

https://github.com/GeorgeSeif/Semantic-Segmentation-Suite/blob/master/models/Encoder_Decoder.py

Do you know whether placing Batch Norm before or after the ReLU matters at all, for both training and inference? I only noticed this because I was trying to set the is_training param of Batch Norm to False for inference, but I'm getting poor results: the entire output encoding gives very large values for the classes (~1e13). This issue may be related to what was described in https://github.com/GeorgeSeif/Semantic-Segmentation-Suite/issues/138
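
The difference between the two Batch Norm modes can be sketched in NumPy. This is a simplified illustration, not the repo's code: it uses dense `[N, C]` inputs instead of conv feature maps, assumes an epsilon of 1e-3, and `batch_norm` / `conv_block_output` are hypothetical helpers. It applies BN to the pre-activation and ReLU afterwards, the ordering discussed above, and keeps the moving statistics at their initial values (mean 0, variance 1).

```python
import numpy as np

def batch_norm(x, gamma, beta, moving_mean, moving_var, training, eps=1e-3):
    """Normalize x (shape [N, C]) per channel.

    training=True  -> use the statistics of the current batch
    training=False -> use the stored moving statistics
    """
    if training:
        mean = x.mean(axis=0)
        var = x.var(axis=0)
    else:
        mean, var = moving_mean, moving_var
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def conv_block_output(pre_activation, params, training):
    # relu(batch_norm(conv)): BN on the pre-activation, ReLU afterwards.
    return np.maximum(batch_norm(pre_activation, training=training, **params), 0.0)

rng = np.random.default_rng(0)
# Hypothetical pre-activations with a large offset and scale.
x = rng.normal(loc=50.0, scale=20.0, size=(256, 4))

# Moving statistics never updated: still at their initial values.
params = dict(gamma=np.ones(4), beta=np.zeros(4),
              moving_mean=np.zeros(4), moving_var=np.ones(4))

train_out = conv_block_output(x, params, training=True)   # batch-normalized
infer_out = conv_block_output(x, params, training=False)  # keeps the raw offset
print(train_out.mean(), infer_out.mean())
```

In this toy setting, inference mode with untrained moving statistics is close to an identity transform, so the activations are effectively un-normalized. Stacked over many layers, that could plausibly produce very large outputs like the ~1e13 values reported above, whichever side of the ReLU the BN sits on.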

k22jung commented 5 years ago

I want to add a bit to this. I just found that train.py may be missing a crucial step, which keeps Batch Norm's moving statistics from being updated. I used the snippet below to look at the moving mean and variance values, and with the default code they never changed from their initial values.

        # Print the parameters and moving statistics of one Batch Norm layer.
        # Assumes TF 1.x and an existing session `sess` from train.py.
        import tensorflow as tf

        subgraph_name = 'BatchNorm'  # alternative: 'resnet_v2_50/conv1'
        layer_node_names = ['beta', 'moving_mean', 'moving_variance']

        with sess.as_default():
            # reuse=True looks up the existing variables instead of creating new ones
            with tf.variable_scope(subgraph_name, reuse=True):
                for name in layer_node_names:
                    var = tf.get_variable(name)
                    values = sess.run(var)
                    print('\n\n' + subgraph_name + '/' + name + ':')
                    print(values)

This is an example output I got from training:

BatchNorm/beta:
[ 3.99581715e-03 -2.95882788e-03 -2.49152072e-03 -4.84126824e-04
  1.21722305e-02 -8.25406425e-03 -1.07911862e-02 -1.80888735e-03
  9.33513511e-03  1.45017356e-02 -2.84231792e-05 -7.19797099e-03
 -7.97156245e-03  2.18108855e-03  7.69345509e-03 -8.57602712e-03
  7.38091068e-03 -7.51969451e-03 -1.31652849e-02  1.20366050e-03
  1.62450515e-03 -1.98164186e-03 -1.94162584e-03  4.32464993e-03
  2.05287635e-02 -7.64263328e-03  1.11849699e-02 -7.16590881e-03
 -1.51986559e-03 -5.33675216e-03 -1.01162652e-02 -5.28684445e-03
 -2.47760070e-03  1.27251754e-02  1.13408931e-03  1.60832815e-02
  1.09339571e-02 -5.48161939e-03 -7.34117581e-03  1.20068187e-04
  3.08766356e-03 -4.66567138e-03  3.42956441e-03  8.49166978e-03
  4.71142307e-03  2.72153108e-03 -1.11686494e-02 -2.48624571e-03
 -9.10147186e-03  6.94223773e-03  2.87658046e-03  1.42992726e-02
 -3.58953816e-03  1.37312226e-02  1.24515565e-02 -1.05547924e-02
 -2.41558114e-03  1.25913359e-02  9.23583191e-03 -4.69615974e-04
  1.58485472e-02 -2.28446879e-05 -7.59366807e-03 -1.57165725e-03]

BatchNorm/moving_mean:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

BatchNorm/moving_variance:
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]

Could someone please confirm these findings? Going back to the Batch Norm docs (I'm looking at the tf.contrib version, since I couldn't find separate docs for slim's Batch Norm), it looks like we need to add something so that the moving statistics actually get updated, as noted here:

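  # Make train_op depend on the Batch Norm update ops so that
  # moving_mean / moving_variance are refreshed at every training step.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
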
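
As I understand the snippet above, the UPDATE_OPS collection holds the ops that refresh each Batch Norm layer's moving mean and variance, and if train_op does not depend on them they never run. A NumPy sketch of that exponential-moving-average update follows; the decay of 0.999, the synthetic data, and the helpers `update_moving_stats` and `train` are all assumptions for illustration.

```python
import numpy as np

def update_moving_stats(moving_mean, moving_var, batch, decay=0.999):
    """One exponential-moving-average update per channel (batch shape [N, C])."""
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    new_mean = decay * moving_mean + (1.0 - decay) * batch_mean
    new_var = decay * moving_var + (1.0 - decay) * batch_var
    return new_mean, new_var

def train(run_update_ops, steps=5000, channels=4, seed=0):
    rng = np.random.default_rng(seed)
    # Initial values of the moving statistics.
    moving_mean, moving_var = np.zeros(channels), np.ones(channels)
    for _ in range(steps):
        # Synthetic activations with mean 3 and variance 4.
        batch = rng.normal(loc=3.0, scale=2.0, size=(32, channels))
        # ... the gradient step would happen here in both cases ...
        if run_update_ops:
            moving_mean, moving_var = update_moving_stats(moving_mean, moving_var, batch)
    return moving_mean, moving_var

mean_with, var_with = train(run_update_ops=True)           # drift toward the data's statistics
mean_without, var_without = train(run_update_ops=False)    # unchanged: all zeros and all ones
print(mean_with, var_with)
print(mean_without, var_without)
```

Without the update ops, the statistics keep their initial values of 0 and 1, which is exactly what the moving_mean / moving_variance dump above shows.
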
seushengchao commented 5 years ago

I have the same problem with the batch norm in the code.