google / compare_gan

Compare GAN code.
Apache License 2.0

unnecessary batchnorm updates #8

Closed · po0ya closed this 5 years ago

po0ya commented 6 years ago

First of all, thank you for this thorough study! There is a slight bug carried over from the dcgan repo to here:

https://github.com/google/compare_gan/blob/615bdc6fc54e5c074adeee543b779dd504dc7e9f/compare_gan/src/gans/GAN.py#L36

Here we are passing is_training as a Python boolean which is true throughout training. This results in the batch norms getting updated when they shouldn't be:

https://github.com/google/compare_gan/blob/615bdc6fc54e5c074adeee543b779dd504dc7e9f/compare_gan/src/gans/abstract_gan.py#L252

When we are updating the discriminator weights, we want the batch norms of the generator to not get updated, and likewise the discriminator's batch norms should stay frozen during generator updates.

Feeding in a separate is_training flag for each of the generator and the discriminator seems to be the way to do it (see the sketch below).
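A minimal sketch of that fix in TF 1.x style. The toy `generator`/`discriminator` functions and the two `*_is_training` placeholders are hypothetical, just to illustrate the wiring; the real architectures live elsewhere in compare_gan:

```python
import tensorflow as tf  # TF 1.x

def generator(z, is_training):
  # Hypothetical toy generator; moving averages update only when
  # is_training evaluates to True.
  with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
    h = tf.contrib.layers.batch_norm(
        tf.layers.dense(z, 64), decay=0.9, updates_collections=None,
        is_training=is_training)
    return tf.layers.dense(tf.nn.relu(h), 2)

def discriminator(x, is_training):
  # Hypothetical toy discriminator, same batch norm wiring.
  with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
    h = tf.contrib.layers.batch_norm(
        tf.layers.dense(x, 64), decay=0.9, updates_collections=None,
        is_training=is_training)
    return tf.layers.dense(tf.nn.relu(h), 1)

z = tf.placeholder(tf.float32, [None, 8])
x_real = tf.placeholder(tf.float32, [None, 2])
# Separate per-network flags instead of one Python boolean:
g_is_training = tf.placeholder(tf.bool, [])
d_is_training = tf.placeholder(tf.bool, [])

x_fake = generator(z, g_is_training)
d_real = discriminator(x_real, d_is_training)
d_fake = discriminator(x_fake, d_is_training)

d_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(d_real), logits=d_real) +
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(d_fake), logits=d_fake))
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(d_fake), logits=d_fake))

d_optim = tf.train.AdamOptimizer(2e-4).minimize(
    d_loss, var_list=tf.trainable_variables("discriminator"))
g_optim = tf.train.AdamOptimizer(2e-4).minimize(
    g_loss, var_list=tf.trainable_variables("generator"))

# Discriminator step: generator batch norm statistics stay frozen.
#   sess.run(d_optim, {z: ..., x_real: ..., g_is_training: False, d_is_training: True})
# Generator step: discriminator batch norm statistics stay frozen.
#   sess.run(g_optim, {z: ..., x_real: ..., g_is_training: True, d_is_training: False})
```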

Marvin182 commented 6 years ago

Good catch!

We are aware of this bug (there is a test case intentionally ignoring it). We didn't fix it in this release in order to match the code that was used for the experiments in [1]. Going forward we probably want to fix it and rerun a few experiments. While not identical, the "bug" is similar to having a lower decay rate in batch norm.

[1] "The GAN Landscape: Losses, Architectures, Regularization, and Normalization" (https://arxiv.org/abs/1807.04720)

po0ya commented 6 years ago

Good to know you are aware of it! Yes, given how flexible this codebase is, it's worth checking how much this bug matters.

Also, the decay-rate argument gets stronger as training progresses ;)

techmatt commented 6 years ago

Along these lines, I'm a bit confused as to how this code applies the updates. resnet_architecture.py line 52 does:

```python
return tf.contrib.layers.batch_norm(
    input, decay=0.9, updates_collections=None, ...)
```

overriding the default `updates_collections=tf.GraphKeys.UPDATE_OPS`. Then, when defining the optimizers, you do the expected

```python
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
  self.d_optim = tf.train.AdamOptimizer(...)
```

So how is tf.GraphKeys.UPDATE_OPS getting modified? The same is true of the batch norm defined in ops.py.

Strangely, there doesn't seem to be an updates_collections argument for the non-contrib BN layer: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

(After some poking around: see https://github.com/tensorflow/tensorflow/issues/21229 for the last point. It's strange that something as core as BN would be missing that option.)
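For reference, with the non-contrib layer the standard idiom is to run the collected update ops yourself; a minimal TF 1.x sketch:

```python
import tensorflow as tf  # TF 1.x

x = tf.random_normal([16, 8])
y = tf.layers.batch_normalization(x, momentum=0.9, training=True)
loss = tf.reduce_mean(tf.square(y))

# tf.layers.batch_normalization always puts its moving-average updates
# into tf.GraphKeys.UPDATE_OPS; the caller is expected to run them:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
```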

Marvin182 commented 6 years ago

tf.GraphKeys.UPDATE_OPS is not getting modified for batch_norm [1]. Setting updates_collections=None adds nothing to any collection and instead adds the control dependencies inside the batch norm method itself [2].

[1] In fact, tf.GraphKeys.UPDATE_OPS might be empty for some or all models, but this might change in the future, so I think it's good practice to keep the control_dependencies().

[2] Somewhere around here: https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/contrib/layers/python/layers/layers.py#L791
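A quick way to see the difference between the two modes (a sketch assuming TF 1.x with tf.contrib available):

```python
import tensorflow as tf  # TF 1.x

x = tf.random_normal([16, 8])

# updates_collections=None (what compare_gan uses): the moving-average
# updates are attached as control dependencies inside batch_norm itself,
# so nothing is added to tf.GraphKeys.UPDATE_OPS.
y1 = tf.contrib.layers.batch_norm(
    x, decay=0.9, updates_collections=None, is_training=True, scope="bn1")

# The default, updates_collections=tf.GraphKeys.UPDATE_OPS: the update ops
# are only collected, and the caller must run them (e.g. via the
# control_dependencies idiom quoted above).
y2 = tf.contrib.layers.batch_norm(
    x, decay=0.9, is_training=True, scope="bn2")

# Two ops (moving mean and moving variance), both from "bn2"; "bn1"
# contributed none.
print(len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))  # -> 2
```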