keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Big problems with GAN and keras 2.0 because of new batchnorm #5892

Closed engharat closed 7 years ago

engharat commented 7 years ago

Hi, I work with Conditional GANs, and in the last few days I have been working on a Conditional Wasserstein GAN implementation in Keras, starting from this code: https://github.com/tdeboissiere/DeepLearningImplementations/tree/master/WassersteinGAN The problem is that the Conditional WGAN (and GANs and Conditional GANs too) does not produce correct output when batchnorm is not set to batchnorm_mode=2. I tried mode=0 before Keras 2.0, and the images produced by the generator become independent of the noise, so I do not get any variability. So in Keras 2.0 I cannot work with WGANs/GANs anymore, since mode=2 was removed in the new batchnorm implementation. I think this is a very big issue, and I hope the old batchnorm modes will be supported again; otherwise Keras will not work with most recent generative models.


bstriner commented 7 years ago

Lots of issues around this. The real question is, what should BN do for a GAN?

There are a lot of ways I could imagine using BN in a GAN, and I can't be sure which is best until someone tests them out and writes a paper.

Cheers

engharat commented 7 years ago

I think the several options you listed apply to the discriminator, and in fact we still don't know the appropriate way to use BN there - I agree with your considerations. Anyway, the problem here concerns the generator - its BN heavily influences the generated images. And beside the theoretical analyses, the main issue still remains: Keras 2.0 removed a BN mode that is pivotal for the correctness of GAN models, and this should be addressed as soon as possible if we want to see broad adoption of Keras 2.0.

bstriner commented 7 years ago

So you're having issues using BN in the generator? Using BN in the discriminator has tons of considerations, but in the generator it should be fine. Do you have something simple to demo the problem?

engharat commented 7 years ago

I have done some experiments on this issue.

Steps to easily reproduce the problem:

- repeat the experiments with bn_mode=0 (python main.py bn_mode=0) and you will see that the output is garbage: in particular, the 8 output images are always the same (apart from a very, very small variation), while each of them should be different - perhaps this can be explained as a total GAN collapse.

adamcavendish commented 7 years ago

@engharat Do you mean the feature-wise normalization?

    2: feature-wise normalization, like mode 0, but
       using per-batch statistics to normalize the data during both
       testing and training.
engharat commented 7 years ago

Exactly! Feature-wise normalization seems to be crucial for correct GAN output at test time
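To make the difference concrete, here is a minimal NumPy sketch (function names and the epsilon value are mine, not from Keras): the old mode 0 normalizes at test time with the moving averages accumulated during training, whereas mode 2 keeps using the statistics of the current batch.

import numpy as np

def bn_mode0_inference(x, moving_mean, moving_var, gamma, beta, eps=1e-3):
    # Old mode 0 at test time: normalize with the moving averages
    # accumulated during training.
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

def bn_mode2(x, gamma, beta, eps=1e-3):
    # Old mode 2: always normalize with the statistics of the current
    # batch, during both training and testing.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta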

stale[bot] commented 7 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

mznyc commented 6 years ago

Has this issue been resolved? I ran into exactly the same problem after upgrading to Keras 2.0.

engharat commented 6 years ago

I suppose batchnorm has not received any update, so the problem is not solved. Anyway, it can be solved fairly easily by cloning the BatchNormalization class into a new class, e.g. BatchNormGAN or BatchNormMode2, and changing the following code: return K.in_train_phase(normed, normalize_inference, training=training) to: return K.in_train_phase(normed, normalize_inference, training=True) so that the batchnorm behaviour matches the old BN with mode=2. At least, it worked for me!
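For reference, a minimal sketch of the same idea, subclassing instead of editing a copied class (the name BatchNormGAN is just a placeholder; assumes the Keras 2.x BatchNormalization API):

from keras.layers import BatchNormalization

class BatchNormGAN(BatchNormalization):
    """Drop-in replacement that mimics the old mode=2: always normalize
    with the statistics of the current batch, even at inference time."""
    def call(self, inputs, training=None):
        # Ignore the learning-phase flag and force per-batch statistics.
        return super(BatchNormGAN, self).call(inputs, training=True)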

mznyc commented 6 years ago

Thank you for the quick response. After some digging, it seems the issue is with sharing the discriminator in the GAN, and model.trainable not working as expected. My loss curve also suggests that the generator loss does not decrease. I tried several ways suggested elsewhere, but they didn't work.

engharat commented 6 years ago

I solved the problem of .trainable with the following function:

def make_trainable(net, value):
    # Toggle trainability on the model itself and on every layer inside it.
    net.trainable = value
    for l in net.layers:
        l.trainable = value

Where net is the model you want to freeze or unfreeze, and value is True/False to make it trainable or not.
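For context, a self-contained sketch of how such a helper is typically used when stacking the generator and discriminator (the tiny Dense models, layer sizes, losses, and all names are hypothetical placeholders, not from the issue; uses the make_trainable helper above):

from keras.models import Sequential
from keras.layers import Dense

# Tiny placeholder models, just to illustrate the freezing pattern.
generator = Sequential([Dense(16, activation='relu', input_dim=8), Dense(8)])
discriminator = Sequential([Dense(16, activation='relu', input_dim=8),
                            Dense(1, activation='sigmoid')])

# Compile the discriminator on its own while its weights are trainable.
make_trainable(discriminator, True)
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

# Freeze the discriminator before compiling the stacked model, so that
# only the generator is updated when training on the combined output.
make_trainable(discriminator, False)
combined = Sequential([generator, discriminator])
combined.compile(loss='binary_crossentropy', optimizer='adam')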

mznyc commented 6 years ago

[image: loss curve]

Here is my loss curve. I have tried pretty much everything and it is still not able to converge. This is a WGAN example.

engharat commented 6 years ago

We are going off-topic, but anyway: I tried my proposed batchnorm and my proposed .trainable function on a WGAN and it worked flawlessly. The loss does not have to converge to zero; I think 0.70 and 0.60 are quite reasonable. I invite you to try again, look at the produced output, and if it still fails, look for errors somewhere else in your code, because this approach is actually working for me (and for several GAN researchers, I suppose ;)

mznyc commented 6 years ago

Appreciate your response. My problem here is that during pre-training the discriminator does not function as expected: the accuracy is only roughly 50%, so there is something wrong with the architecture.

mznyc commented 6 years ago

Finally identified the problem as model insufficiency and unnecessary BatchNormalization. I removed BatchNormalization, which turned out to be unnecessary anyway. Here is the result from WGAN-GP. I did not use make_trainable; I simply set/reset the trainable property. Very nice and fast convergence with GP.

[image: WGAN-GP result]
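For completeness, a minimal sketch of the "set/reset trainable property" pattern mentioned above (model names are placeholders, defined as in the earlier sketch; it relies on the trainable flag being captured when a model is compiled in Keras 2.x):

# Compile the discriminator while it is trainable; this compiled model
# keeps updating the discriminator weights when trained on its own.
discriminator.trainable = True
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

# Freeze it, then compile the stacked model; inside `combined` the
# discriminator weights stay fixed and only the generator is updated.
discriminator.trainable = False
combined.compile(loss='binary_crossentropy', optimizer='adam')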