wenxinxu / resnet-in-tensorflow

Re-implement Kaiming He's deep residual networks in tensorflow. Can be trained with cifar10.
MIT License
828 stars 276 forks source link

About batch normalization #6

Open utkarshojha opened 7 years ago

utkarshojha commented 7 years ago

The batch_normalization_layer() function doesn't compute the population statistics, i.e. the population mean and variance. The implemented part only handles the training procedure (batch statistics), but at test time one needs the population statistics.
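To make the distinction concrete, here is a small NumPy sketch (not the repo's code; function names are illustrative) contrasting the two sets of statistics:

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-5):
    # Training: normalize with the statistics of the current batch.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def batch_norm_test(x, gamma, beta, pop_mean, pop_var, eps=1e-5):
    # Inference: normalize with the population statistics
    # accumulated over training (e.g. via a moving average).
    return gamma * (x - pop_mean) / np.sqrt(pop_var + eps) + beta
```

If `pop_mean`/`pop_var` are never updated during training, the test-time path normalizes with stale statistics, which is the bug being reported here.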

dongzhuoyao commented 7 years ago

Yes, the BN implementation is wrong, please take note.

jingjing-gong commented 7 years ago

Agreed, the batch_normalization is wrong.

wenxinxu commented 7 years ago

Thank you very much for pointing it out. I will fix it.

WoNiuHu commented 7 years ago

@wenxinxu @utkarsh2254 @dongzhuoyao Hi, could you please tell me how to fix this bug?

WoNiuHu commented 7 years ago

So, what's the difference? `pop_mean` and `pop_var` are never changed during training, and `pop_var` is never used.

Best regards.

At 2017-11-09 14:32:02, "Utkarsh Ojha" notifications@github.com wrote:

> In the resnet.py file, the batch_normalization_layer function can be modified like this:

```python
def batch_normalization_layer(input_layer, dimension):
    '''
    Helper function to do batch normalization
    :param input_layer: 4D tensor
    :param dimension: input_layer.get_shape().as_list()[-1]. The depth of the 4D tensor
    :return: the 4D tensor after being normalized
    '''
    mean, variance = tf.nn.moments(input_layer, axes=[0, 1, 2])
    beta = tf.get_variable('beta', dimension, tf.float32,
                           initializer=tf.constant_initializer(0.0, tf.float32), trainable=False)
    gamma = tf.get_variable('gamma', dimension, tf.float32,
                            initializer=tf.constant_initializer(1.0, tf.float32), trainable=False)
    pop_mean = tf.get_variable('pop_mean', dimension, trainable=False)
    pop_var = tf.get_variable('pop_var', dimension, trainable=False)

    bn_layer = tf.nn.batch_normalization(input_layer, pop_mean, variance, beta, gamma, BN_EPSILON)

    return bn_layer
```


dongzhuoyao commented 6 years ago

Hi @WoNiuHu, you should calculate BN's mean and variance with an exponential moving average (EMA).

Here is a practical demo you can follow: https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/models/batch_norm.py#L161
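The EMA update itself is simple; here is a plain-Python sketch of the running-statistics logic (the `decay` value and function name are assumptions for illustration, not the linked demo's code):

```python
def update_population_stats(pop_mean, pop_var, batch_mean, batch_var, decay=0.99):
    # Exponential moving average: each training step nudges the stored
    # statistics toward the current batch's statistics. After many steps,
    # pop_mean/pop_var approximate the population statistics used at test time.
    new_mean = decay * pop_mean + (1.0 - decay) * batch_mean
    new_var = decay * pop_var + (1.0 - decay) * batch_var
    return new_mean, new_var
```

In TF1-style graphs this update would be implemented with assign ops (or `tf.train.ExponentialMovingAverage`) that must actually be run during training, which is exactly what the original code is missing.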

TyrionChou commented 5 years ago

You can change tf.nn.batch_normalization to tf.layers.batch_normalization, which maintains the moving averages for you (remember to run the update ops in tf.GraphKeys.UPDATE_OPS during training).