Newmu / dcgan_code

Deep Convolutional Generative Adversarial Networks
MIT License

Batch normalization and inference in the DCGAN model #24

Closed nikjetchev closed 8 years ago

nikjetchev commented 8 years ago

I am using the DCGAN code and am pretty happy with the results. However, I am curious: shouldn't the batch normalization operation be treated in a special way at inference time (after training is completed)?

i) the original BatchNorm paper says that the mean and variance should be frozen when doing inference with the model (https://arxiv.org/pdf/1502.03167v3.pdf, Algorithm 2)

ii) the DCGAN code does not fix the batch statistics in this way, so when we generate new samples with the _gen function it seems the batch norm statistics are calculated on the fly. To my surprise, this still works and produces nice images

iii) here is a case where it does not work: start with a black image X and optimize it with respect to the discriminator function to make it close to the "true" images. After a few iterations of gradient descent I get an image X that is predicted as 1 (true), but it still looks almost entirely black. So the discriminator seems to be pretty bad in that case, even though the images I can generate are quite good. My guess is that batch normalization fails here, since the statistics of a single black image are totally different from the statistics of a proper random minibatch.

iv) has anyone implemented fixed minibatch statistics for inference, as advocated in the original paper? This might be a useful option for the DCGAN code.

v) as a next experiment, I will try removing batch normalization and training without it, and then see whether my black image experiment works correctly
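The guess in iii) can be illustrated with a minimal NumPy sketch. This is not the repository's Theano code: the function names, the (N, C, H, W) layout, the `eps` value, and the synthetic "training" activations are all assumptions made for illustration. It shows that when a layer normalizes a single constant input with its own statistics, every constant input maps to the same output, whereas fixed statistics keep them distinguishable.

```python
import numpy as np

def batchnorm_batch_stats(x, gamma, beta, eps=1e-8):
    """Normalize per channel using the statistics of the current minibatch."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def batchnorm_fixed_stats(x, gamma, beta, mean, var, eps=1e-8):
    """Normalize per channel using statistics frozen ahead of time."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
C = 4  # hypothetical number of channels
gamma = np.ones((1, C, 1, 1))
beta = np.zeros((1, C, 1, 1))

# Synthetic stand-in for activations produced by a proper random minibatch.
train_acts = rng.normal(loc=1.0, scale=2.0, size=(64, C, 8, 8))
fixed_mean = train_acts.mean(axis=(0, 2, 3), keepdims=True)
fixed_var = train_acts.var(axis=(0, 2, 3), keepdims=True)

# Two different constant inputs, each passed through the layer on its own.
black = np.zeros((1, C, 8, 8))
gray = np.full((1, C, 8, 8), 0.5)

# With on-the-fly statistics the mean equals the constant and the variance
# is zero, so both inputs collapse to beta (all zeros here).
out_black_batch = batchnorm_batch_stats(black, gamma, beta)
out_gray_batch = batchnorm_batch_stats(gray, gamma, beta)

# With fixed statistics the two inputs produce different outputs.
out_black_fixed = batchnorm_fixed_stats(black, gamma, beta, fixed_mean, fixed_var)
out_gray_fixed = batchnorm_fixed_stats(gray, gamma, beta, fixed_mean, fixed_var)

print(np.allclose(out_black_batch, out_gray_batch))  # identical outputs
print(np.allclose(out_black_fixed, out_gray_fixed))  # distinguishable outputs
```

In a real discriminator the input to each batch norm layer is a convolution output rather than the raw image, so the collapse would not be this exact, but the same mismatch between single-image statistics and minibatch statistics applies.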

If anyone has more insights about the use of batch normalization in the DCGAN, it would be really helpful to discuss them, or to get the code for a simple modification of DCGAN that uses a fixed batch normalization operation at inference time.

Thanks a lot, Nikolay

Newmu commented 8 years ago

You're correct that generation just uses minibatch statistics (as long as the minibatch passed to gen is of a decent size, it works fine). Inference is used for the semi-supervised results - an example of inferring/calculating these values can be found in https://github.com/Newmu/dcgan_code/blob/master/svhn/svhn_semisup_analysis.py
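One way to calculate fixed values after training, following Algorithm 2 of the BatchNorm paper, is to average the per-minibatch means and variances over many minibatches and apply the unbiased-variance correction. The NumPy sketch below is an illustration under stated assumptions and is not taken from the linked script: the helper name `estimate_population_stats`, the (N, C, H, W) layout, and the synthetic activations are hypothetical.

```python
import numpy as np

def estimate_population_stats(batches):
    """Average per-minibatch statistics over equally sized minibatches.

    batches: iterable of arrays shaped (N, C, H, W).
    Returns per-channel (mean, var), each of shape (C,).
    """
    means, variances = [], []
    m = None
    for x in batches:
        # Effective number of values per channel in one minibatch.
        m = x.shape[0] * x.shape[2] * x.shape[3]
        means.append(x.mean(axis=(0, 2, 3)))
        variances.append(x.var(axis=(0, 2, 3)))  # biased per-batch variance
    if m is None or m < 2:
        raise ValueError("need at least one minibatch with more than one value")
    pop_mean = np.mean(means, axis=0)
    # Unbiased correction m / (m - 1) applied to the averaged variance.
    pop_var = (m / (m - 1.0)) * np.mean(variances, axis=0)
    return pop_mean, pop_var

rng = np.random.default_rng(0)
# Synthetic stand-in for one layer's pre-normalization activations.
batches = [rng.normal(loc=1.0, scale=2.0, size=(32, 3, 8, 8)) for _ in range(50)]
pop_mean, pop_var = estimate_population_stats(batches)

# Normalize a single input with the frozen statistics (gamma=1, beta=0 here).
single = rng.normal(loc=1.0, scale=2.0, size=(1, 3, 8, 8))
single_out = (single - pop_mean.reshape(1, -1, 1, 1)) / np.sqrt(
    pop_var.reshape(1, -1, 1, 1) + 1e-8
)
print(pop_mean, pop_var)
```

In a multi-layer network the statistics of each layer depend on how the preceding layers were normalized, so in practice they would be collected layer by layer or with running averages during training.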