hwalsuklee / tensorflow-generative-model-collections

Collection of generative models in Tensorflow
Apache License 2.0

WGAN_GP - tensorflow #14

Closed: Naxter closed this issue 6 years ago

Naxter commented 6 years ago

Hey, I appreciate your work! You make my life better.

I think I found a small bug in your WGAN_GP code. When calculating the gradient penalty, you write:

D_inter, _, _ = self.discriminator(interpolates, is_training=True, reuse=True)
gradients = tf.gradients(D_inter, [interpolates])[0]

Here D_inter is the sigmoid output of the Discriminator, not its logits, so the gradient penalty is computed on the sigmoid output.
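A short chain-rule sketch of why this matters (the notation is ours: D(x) is the Discriminator's logit and σ is the sigmoid):

$$\nabla_x \sigma\bigl(D(x)\bigr) = \sigma\bigl(D(x)\bigr)\,\Bigl(1-\sigma\bigl(D(x)\bigr)\Bigr)\,\nabla_x D(x), \qquad \bigl\|\nabla_x \sigma\bigl(D(x)\bigr)\bigr\|_2 \le \tfrac{1}{4}\,\bigl\|\nabla_x D(x)\bigr\|_2$$

Because σ(z)(1 − σ(z)) never exceeds 1/4, pushing the sigmoid-output gradient norm toward 1 would require logit gradients of norm at least 4. That is not the 1-Lipschitz constraint the gradient penalty is meant to enforce on the critic.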

In the original implementation (https://github.com/igul222/improved_wgan_training/blob/master/gan_mnist.py), they write this:

gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]

There, the Discriminator returns only its logits, so the authors compute the gradient penalty on the logits.
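To make the difference concrete, here is a minimal NumPy sketch. It is not the repository's code, and it assumes a hypothetical linear critic whose logit is x @ w + b, so the input gradients can be written analytically instead of using tf.gradients. The interpolation, the slopes, and the λ = 10 penalty follow the shape of the penalty in the original gan_mnist.py. The helper names (critic_logit, grad_wrt_input, gradient_penalty) are illustrative only.

```python
import numpy as np

def critic_logit(x, w, b):
    """Hypothetical linear critic standing in for the Discriminator's logits."""
    return x @ w + b

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def grad_wrt_input(x, w, b, use_sigmoid):
    """Per-sample gradient of the critic output with respect to its input.

    For the linear critic, d(logit)/dx = w for every sample. Going through
    the sigmoid multiplies that by s * (1 - s), following the chain rule.
    """
    g = np.tile(w, (x.shape[0], 1))  # (batch, dim): gradient of the logit
    if use_sigmoid:
        s = sigmoid(critic_logit(x, w, b))  # (batch,)
        g = g * (s * (1.0 - s))[:, None]
    return g

def gradient_penalty(real, fake, w, b, use_sigmoid, lam=10.0, rng=None):
    """WGAN-GP style penalty: lam * mean((||grad|| - 1)^2) on random interpolates."""
    rng = np.random.default_rng(0) if rng is None else rng
    alpha = rng.uniform(size=(real.shape[0], 1))  # one mixing weight per sample
    interpolates = real + alpha * (fake - real)
    g = grad_wrt_input(interpolates, w, b, use_sigmoid)
    slopes = np.sqrt(np.sum(g ** 2, axis=1))  # L2 norm of each sample's gradient
    return lam * np.mean((slopes - 1.0) ** 2), slopes

if __name__ == "__main__":
    data_rng = np.random.default_rng(1)
    real = data_rng.normal(size=(8, 2))
    fake = data_rng.normal(size=(8, 2))
    w, b = np.array([0.6, 0.8]), 0.0  # ||w|| = 1, so the logits are already 1-Lipschitz
    print("penalty on logits: ", gradient_penalty(real, fake, w, b, use_sigmoid=False)[0])
    print("penalty on sigmoid:", gradient_penalty(real, fake, w, b, use_sigmoid=True)[0])
```

With ||w|| = 1, the penalty on the logits is zero up to floating-point error. The penalty on the sigmoid output is at least 10 × 0.75² = 5.625, because every sigmoid-output slope is at most 0.25. The penalty therefore keeps pushing a critic that already meets the constraint.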

Did you do this on purpose?

Greetings!

hwalsuklee commented 6 years ago

Hi. I really appreciate your comment.

You are right. I made a mistake by using the sigmoid output of the Discriminator; it must be the logits of the Discriminator.

I fixed the code.

Thanks again.