jiamings / wgan

Tensorflow Implementation of Wasserstein GAN (and Improved version in wgan_v2)

Batch Normalization #1

Closed by zijunwei 7 years ago

zijunwei commented 7 years ago

Thanks for sharing the code. It is elegant and well structured, and I'm going to use it as a starting point. However, there might be a glitch in the batch normalization. Are you using batch normalization in the same way during training and testing? Or am I missing something in the paper that suggests using batch normalization this way?

If you're using BN in the usual sense, it seems that moving_mean and moving_variance are not updated during training. See this issue: https://github.com/tensorflow/tensorflow/issues/1122 and perhaps also this post: http://r2rt.com/implementing-batch-normalization-in-tensorflow.html One way to update the variables might be the following:

            # Running update_ops forces every op collected in UPDATE_OPS to run first.
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                update_ops = tf.no_op()

Please let me know if I'm missing anything.
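
To make the concern above concrete, here is a minimal, framework-free sketch of what happens when the moving-average updates never run. It is a toy model, not the repository's code: `ToyBatchNorm`, `train_step`, `infer`, and the `run_update_ops` flag are hypothetical names chosen for illustration, and the layer has no learned scale or shift.

```python
# Toy illustration only (hypothetical names; not the repository's code).

class ToyBatchNorm:
    """1-D batch normalization with moving statistics, no learned scale/shift."""

    def __init__(self, momentum=0.9, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.moving_mean = 0.0       # initialized to zero
        self.moving_variance = 1.0   # initialized to one

    def train_step(self, batch, run_update_ops):
        mean = sum(batch) / len(batch)
        var = sum((x - mean) ** 2 for x in batch) / len(batch)
        if run_update_ops:
            # Stand-in for the ops that TensorFlow 1.x places in UPDATE_OPS.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_variance = m * self.moving_variance + (1 - m) * var
        # Training always normalizes with the current batch statistics.
        return [(x - mean) / (var + self.eps) ** 0.5 for x in batch]

    def infer(self, batch):
        # Inference normalizes with the moving statistics instead.
        std = (self.moving_variance + self.eps) ** 0.5
        return [(x - self.moving_mean) / std for x in batch]


batch = [8.0, 10.0, 12.0]  # batch mean is 10

stale = ToyBatchNorm()  # update ops never run
fresh = ToyBatchNorm()  # update ops run on every step
for _ in range(200):
    stale.train_step(batch, run_update_ops=False)
    fresh.train_step(batch, run_update_ops=True)

print(stale.moving_mean, stale.moving_variance)  # still the initial values
print(fresh.moving_mean)                         # close to the batch mean
print(stale.infer(batch))                        # not zero-centered
print(fresh.infer(batch))                        # roughly zero-centered
```

In this sketch the `stale` layer behaves identically to `fresh` during training but normalizes with its untouched initial statistics at inference time, which is the train/test mismatch the question is about.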

jiamings commented 7 years ago

Seems that it is the case. Will check it out after ICML.

sbrodehl commented 7 years ago

As stated in [1], batch normalization should not be used in the discriminator / critic:

We note that our method does work with normalization schemes which don’t introduce correlations between examples in a minibatch such as layer normalization (Ba et al., 2016), weight normalization (Salimans & Kingma, 2016) and instance normalization (Ulyanov et al., 2016). In particular, we recommend layer normalization as a drop-in replacement for batch normalization if desired.

There are a few BN layers in the code, but I haven't looked at the performance without them yet. Any thoughts?

[1] Improved Training of Wasserstein GANs https://arxiv.org/abs/1704.00028
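
The phrase "correlations between examples in a minibatch" in the quoted passage can be illustrated with a small, framework-free sketch. It is only a toy comparison under simplifying assumptions: `batch_norm` and `layer_norm` here are hypothetical helper functions without learned scale or shift, not the implementations used in the repository or in TensorFlow.

```python
# Toy comparison (hypothetical helpers, no learned parameters).

def batch_norm(batch, eps=1e-5):
    """Normalize each feature across the examples in the batch."""
    n = len(batch)
    n_features = len(batch[0])
    out = [[0.0] * n_features for _ in range(n)]
    for j in range(n_features):
        col = [row[j] for row in batch]
        mean = sum(col) / n
        var = sum((x - mean) ** 2 for x in col) / n
        for i in range(n):
            out[i][j] = (batch[i][j] - mean) / (var + eps) ** 0.5
    return out


def layer_norm(batch, eps=1e-5):
    """Normalize each example across its own features."""
    out = []
    for row in batch:
        mean = sum(row) / len(row)
        var = sum((x - mean) ** 2 for x in row) / len(row)
        out.append([(x - mean) / (var + eps) ** 0.5 for x in row])
    return out


# The two batches differ only in their third example.
batch_a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
batch_b = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [70.0, -8.0, 0.5]]

# Does the output for the FIRST example change when the third one changes?
bn_changed = batch_norm(batch_a)[0] != batch_norm(batch_b)[0]
ln_changed = layer_norm(batch_a)[0] != layer_norm(batch_b)[0]
print(bn_changed, ln_changed)  # True False
```

With batch normalization, the output for one example depends on the other examples in the batch; with layer normalization it does not, which is why the quoted passage suggests it as a replacement in the critic.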

jiamings commented 7 years ago

I am aware of that. However, if we do not run update_ops, the BN layers never update their moving statistics. In any case, just remove batch normalization from the discriminator in your code.