tensorflow / tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.
https://js.tensorflow.org
Apache License 2.0
18.52k stars 1.94k forks

Wrong predictions when using BatchNormalization with training flag set #562

Closed zaidalyafeai closed 4 years ago

zaidalyafeai commented 6 years ago

To get help from the community, check out our Google group.

TensorFlow.js version

latest

Browser version

Version 66.0.3359.139

Describe the problem or feature request

BatchNormalization produces wrong predictions when training = 1 is set

Code to reproduce the bug / link to feature request

I created this simple keras model

from keras.layers import Input, BatchNormalization, Flatten, Dense
from keras.models import Model

def SimpleModel():
    x = Input(shape=(2, 2, 3))
    # training=1 forces batch statistics even outside of fitting
    y = BatchNormalization()(x, training=1)
    y = Flatten()(y)
    z = Dense(units=1)(y)
    return Model(inputs=x, outputs=z)

After training, the batch norm layer weights (gamma, beta, moving mean, moving variance) are

[array([0.99683774, 0.99683774, 0.99683774], dtype=float32),
 array([-0.00316227,  0.00316227, -0.00316228], dtype=float32),
 array([-0.08008331,  0.01483306,  0.12279604], dtype=float32),
 array([1.0677528, 1.0555032, 0.9067482], dtype=float32)]

After running model.predict(np.zeros((1, 2, 2, 3))), the output is

array([[[[-0.00316227,  0.00316227, -0.00316228],
         [-0.00316227,  0.00316227, -0.00316228]],

        [[-0.00316227,  0.00316227, -0.00316228],
         [-0.00316227,  0.00316227, -0.00316228]]]], dtype=float32)

In the browser the weights are the same, but the activations are

Tensor
    [[[[0.0740574, -0.0112231, -0.1316395],
       [0.0740574, -0.0112231, -0.1316395]],

      [[0.0740574, -0.0112231, -0.1316395],
       [0.0740574, -0.0112231, -0.1316395]]]]

Explanation

In Keras, when training = 1 is set, batch normalization normalizes with the statistics of the prediction sample:

[image]

TensorFlow.js instead normalizes with the stored moving mean and variance from training:

[image]
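The discrepancy can be reproduced with plain NumPy (a sketch, not from the thread; it assumes the Keras BatchNormalization default epsilon of 1e-3 and reuses the weights printed above):

```python
import numpy as np

# Batch norm weights in Keras get_weights() order.
gamma = np.array([0.99683774, 0.99683774, 0.99683774], dtype=np.float32)
beta = np.array([-0.00316227, 0.00316227, -0.00316228], dtype=np.float32)
moving_mean = np.array([-0.08008331, 0.01483306, 0.12279604], dtype=np.float32)
moving_var = np.array([1.0677528, 1.0555032, 0.9067482], dtype=np.float32)
eps = 1e-3  # Keras BatchNormalization default epsilon

x = np.zeros((1, 2, 2, 3), dtype=np.float32)

# training=1: normalize with the statistics of the current batch.
batch_mean = x.mean(axis=(0, 1, 2))
batch_var = x.var(axis=(0, 1, 2))
train_out = gamma * (x - batch_mean) / np.sqrt(batch_var + eps) + beta
# For an all-zeros input this collapses to beta per channel,
# which is exactly the Keras prediction shown above.

# inference mode: normalize with the stored moving statistics,
# which is what TensorFlow.js did and gives the browser activations.
infer_out = gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta
```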

zaidalyafeai commented 6 years ago

The same problem happens in TensorFlow as well.

caisq commented 6 years ago

@zaidalyafeai you mean tf.keras?

zaidalyafeai commented 6 years ago

No, this definition

https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

zaidalyafeai commented 6 years ago

@caisq This is a quote from the TensorFlow page

training: Either a Python boolean, or a TensorFlow boolean scalar tensor (e.g. a placeholder). Whether to return the output in training mode (normalized with statistics of the current batch) or in inference mode (normalized with moving statistics). NOTE: make sure to set this parameter correctly, or else your training/inference will not work properly.

zaidalyafeai commented 6 years ago

@caisq, I am trying to understand the source code. Could you please explain to me what broadcasting is?
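For context (an aside, not caisq's answer): broadcasting is the rule by which NumPy and TensorFlow apply elementwise operations to arrays of different but compatible shapes, stretching the smaller array along missing axes. It is how batch norm applies a per-channel statistic to a whole NHWC activation tensor. A minimal NumPy sketch:

```python
import numpy as np

x = np.zeros((1, 2, 2, 3))        # NHWC activations
mean = np.array([0.1, 0.2, 0.3])  # per-channel statistic, shape (3,)

# The (3,) vector is broadcast along the leading axes of x, so the
# subtraction happens per channel with no explicit tiling or loops.
y = x - mean
print(y.shape)  # (1, 2, 2, 3)
```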

nsthorat commented 6 years ago

@caisq did we ever resolve this issue? I assume tf.layers is doing the right thing in TensorFlow..

zaidalyafeai commented 5 years ago

I resolved this issue by modifying the source code and changing the definition of batch norm during inference time. My pix2pix demo is based on that!

caisq commented 5 years ago

Training with BatchNormalization should be working. See the ACGAN example under review at https://github.com/tensorflow/tfjs-examples/pull/187

I'd like to see the code you're using and the change you made in order for it to work, @zaidalyafeai , if possible.

zaidalyafeai commented 5 years ago

@caisq, I may have accidentally deleted the source code :/ but the idea is simple: I just forced the batch norm layer to use the statistics of the input sample, as if it were training. So I didn't add any code, just re-routing.

rthadur commented 4 years ago

Closing this due to lack of activity, feel free to reopen. Thank you

FinderBMap commented 3 months ago

A small value of the BatchNormalization parameter
"momentum", e.g. momentum = 0.01, may help with this (works for 1D in TensorFlow.js)
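A sketch of why that helps (hedged, not from the library source): in the Keras convention, momentum controls the exponential-moving-average update of the stored statistics, so a small momentum makes the moving mean track the most recent batch almost exactly, and inference-mode normalization then behaves close to training mode:

```python
# Keras convention: moving = momentum * moving + (1 - momentum) * batch_stat
def update_moving(moving, batch_stat, momentum):
    return momentum * moving + (1 - momentum) * batch_stat

# With the default momentum = 0.99 the moving mean barely changes per batch;
# with momentum = 0.01 it lands almost on the latest batch statistic.
print(update_moving(0.0, 1.0, 0.99))  # ~0.01
print(update_moving(0.0, 1.0, 0.01))  # ~0.99
```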