carpedm20 / DCGAN-tensorflow

A tensorflow implementation of "Deep Convolutional Generative Adversarial Networks"
http://carpedm20.github.io/faces/
MIT License

Batch norm update bug #289

Open · opened by po0ya 6 years ago

po0ya commented 6 years ago

https://github.com/carpedm20/DCGAN-tensorflow/blob/60aa97b6db5d3cc1b62aec8948fd4c78af2059cd/model.py#L223

Here you intend to update only the discriminator, but you are updating the generator's batch_norm parameters too.

The same holds here: https://github.com/carpedm20/DCGAN-tensorflow/blob/60aa97b6db5d3cc1b62aec8948fd4c78af2059cd/model.py#L232

You are minimizing self.g_loss, which internally goes through the discriminator instantiated with train=True, so the discriminator's batch_norm statistics get updated there as well.
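
Here is a minimal standalone sketch of what I mean (TF 1.x, not this repo's code; it assumes the batch_norm wrapper boils down to tf.contrib.layers.batch_norm with updates_collections=None, as discussed below). A "discriminator-only" step still moves the generator's moving_mean:

import numpy as np
import tensorflow as tf

def generator(z):
    with tf.variable_scope('generator'):
        h = tf.layers.dense(z, 8, name='g_fc')
        # updates_collections=None wires the moving-average updates into the forward pass
        return tf.contrib.layers.batch_norm(h, is_training=True,
                                            updates_collections=None, scope='g_bn')

def discriminator(x, reuse=False):
    with tf.variable_scope('discriminator', reuse=reuse):
        h = tf.layers.dense(x, 8, name='d_fc')
        h = tf.contrib.layers.batch_norm(h, is_training=True,
                                         updates_collections=None, scope='d_bn')
        return tf.layers.dense(h, 1, name='d_out')

z = tf.random_normal([16, 4])
x_real = tf.random_normal([16, 8])
d_real = discriminator(x_real)
d_fake = discriminator(generator(z), reuse=True)
d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_real, labels=tf.ones_like(d_real)))
d_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_fake, labels=tf.zeros_like(d_fake)))

# "Discriminator-only" update, analogous to model.py#L223
d_vars = [v for v in tf.trainable_variables() if v.name.startswith('discriminator')]
d_optim = tf.train.AdamOptimizer(2e-4).minimize(d_loss, var_list=d_vars)

g_moving = [v for v in tf.global_variables()
            if 'generator' in v.name and 'moving_mean' in v.name]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(g_moving)
    sess.run(d_optim)
    after = sess.run(g_moving)
    # Non-zero differences, even though no generator variable is in var_list
    print([float(np.abs(a - b).max()) for a, b in zip(after, before)])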

youngleec commented 6 years ago

Hi, I guess it is not a bug.

self.d_bn1 = batch_norm(name='d_bn1')
self.d_bn2 = batch_norm(name='d_bn2')
t_vars = tf.trainable_variables()
self.d_vars = [var for var in t_vars if 'd_' in var.name]
d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1).minimize(self.d_loss, var_list=self.d_vars)

The above code makes sure that the parameters of the discriminator (including its batch_norm parameters) are updated and the parameters of the generator are not.

train=True puts batch normalization in training mode, in which it uses the mean and variance of the mini-batch. When it is not in training mode (e.g. testing with only one sample), it uses the moving mean and moving variance accumulated during training.

Hope it helps.

po0ya commented 6 years ago

Right, but the generator's batch_norm moving averages are still updated, because the generator's output is used in the discriminator loss and the generator is built with train=True. Passing a var_list to the Adam optimizer does not prevent the moving means from being updated; that's why batch_norm has an explicit is_training flag.

youngleec commented 6 years ago

Oh, I see. It may be a small bug, because updates_collections=None forces the updates of the moving mean and moving variance to happen in place. We may need to set updates_collections and add the dependencies explicitly. Thanks a lot.
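
Something along these lines, perhaps. This is only an untested sketch against the repo's names (it assumes the batch_norm wrapper can be changed to expose updates_collections, that the update op names contain 'd_bn'/'g_bn', and that self.g_vars is built like self.d_vars):

# In ops.py: collect the updates instead of applying them in place, e.g.
#   tf.contrib.layers.batch_norm(x, ..., is_training=train,
#                                updates_collections=tf.GraphKeys.UPDATE_OPS, scope=self.name)

# In model.py: run only the matching update ops with each optimizer step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
d_updates = [op for op in update_ops if 'd_bn' in op.name]
g_updates = [op for op in update_ops if 'g_bn' in op.name]

with tf.control_dependencies(d_updates):
    d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
                      .minimize(self.d_loss, var_list=self.d_vars)
with tf.control_dependencies(g_updates):
    g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
                      .minimize(self.g_loss, var_list=self.g_vars)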

po0ya commented 6 years ago

This small bug might mess things up! I have implemented another version without this bug, and it behaves in a funny way. I still think there's a bug in my own implementation, since I just started working on it (it's not based on this code), but this might turn out to be another thing we have to take for granted:

updating the batch norm statistics of both the discriminator and the generator in every iteration, regardless of which weights are being trained.

MihaiDogariu commented 6 years ago

Hi @po0ya, I believe you are right about the batch normalization. This is why, when generating samples during training, you obtain losses quite different from the ones computed during the generator's training step. The discrepancy is expected to grow with the number of iterations, as the moving mean and variance stabilize around the dataset's statistics. Note that the sampler uses train=False when it calls the g_bn# layers, whereas the generator defaults to train=True.

This leads to the following problem:

I believe that preventing the network from using the moving mean and variance might solve this issue. However, batch normalization would then be expected to become less effective. It is a trade-off worth investigating.
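
For example, a quick (untested) way to see the gap, assuming a built DCGAN instance dcgan, an active session sess, numpy imported as np, and the unconditional setting (self.G is built with train=True, self.sampler with train=False):

# Same z through the generator (batch statistics) and the sampler (moving statistics);
# the difference reflects how far the moving mean/variance are from the batch statistics.
z = np.random.uniform(-1, 1, [dcgan.batch_size, dcgan.z_dim]).astype(np.float32)
g_out, s_out = sess.run([dcgan.G, dcgan.sampler], feed_dict={dcgan.z: z})
print(np.abs(g_out - s_out).mean())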

etienne-v commented 5 years ago

Hi @po0ya @youngleec. Do you have any solution or advice on the correct way to implement batch norm in this case? What I currently do is the following: I use placeholders for the training flags of the discriminator and the generator:

d_istraining_ph = tf.placeholder(...)
g_istraining_ph = tf.placeholder(...)

When I build the generator and discriminator graph I pass the correct placeholders to d and g:

gz = generator(z, is_training=g_istraining_ph)
d_real = discriminator(x, is_training=d_istraining_ph)
d_fake = discriminator(gz, is_training=d_istraining_ph, reuse=True)

... where I use tf.layers.batch_normalization(...., training=is_training, ...) inside the generator and discriminator models, and the reuse flag is set to True to reuse the same variables for the second call to the discriminator.

When building the training operations I collect the associated trainable variables:

d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')

where I've used tf.variable_scope('discriminator', reuse=reuse) and tf.variable_scope('generator') to wrap the discriminator and generator in scopes when I defined them.

The optimizers, gradient calculations and update steps are as follows:

d_opt = tf.train.AdamOptimizer(.....)
g_opt = tf.train.AdamOptimizer(.....)

d_grads_vars = d_opt.compute_gradients(d_loss, d_vars)
g_grads_vars = g_opt.compute_gradients(g_loss, g_vars)

with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')):
    d_train = d_opt.apply_gradients(d_grads_vars)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')):
    g_train = g_opt.apply_gradients(g_grads_vars)

When running d_train during training, I feed d_istraining_ph=True and g_istraining_ph=False, which forces the generator to use its running statistics while the discriminator uses batch statistics. In this case I suppose the discriminator's batch norm statistics are updated with both real data and fake data from the generator. When running g_train during training I do the opposite: d_istraining_ph=False and g_istraining_ph=True.
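
Concretely, one training iteration then looks roughly like this (a sketch; x_batch and z_batch are the sampled mini-batches, and x and z are the data/noise inputs from the code above):

# Discriminator step: D batch norm in training mode, G pinned to its running statistics
_, d_loss_val = sess.run([d_train, d_loss],
                         feed_dict={x: x_batch, z: z_batch,
                                    d_istraining_ph: True, g_istraining_ph: False})
# Generator step: the opposite
_, g_loss_val = sess.run([g_train, g_loss],
                         feed_dict={z: z_batch,
                                    d_istraining_ph: False, g_istraining_ph: True})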

Is my reasoning correct?