igul222 / improved_wgan_training

Code for reproducing experiments in "Improved Training of Wasserstein GANs"
MIT License

Clarifications on code #1

rafaelvalle closed this issue 7 years ago

rafaelvalle commented 7 years ago

What is the output of `Discriminator(interpolates)[0]` in the code below from gan_language.py? Since `gen_cost = -tf.reduce_mean(Discriminator(fake_inputs))`, I assume `Discriminator` returns one score per example, so `Discriminator(interpolates)[0]` would be the discriminator's score for only the first example in `interpolates`. That doesn't seem to make sense here.

```python
gradients = tf.gradients(Discriminator(interpolates)[0], [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)
disc_cost += LAMBDA*gradient_penalty
```

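For reference, here is a minimal NumPy sketch of what this penalty is meant to compute, following the usual WGAN-GP recipe. Each interpolate lies at a random point between a real and a fake sample. `slopes` is the L2 norm of the critic's gradient for that example, and the penalty is the mean squared deviation of those norms from 1. The reduction over axes 1 and 2 assumes 3-D inputs, which I'm taking to be `[batch, seq_len, vocab]` one-hot sequences as in gan_language.py. The linear `toy_critic`, its analytic gradient `toy_critic_grad`, and `gradient_penalty` are hypothetical stand-ins, not code from the repository. `lam=10.0` matches the λ = 10 used in the paper.

```python
import numpy as np

# Hypothetical toy critic: D(x_i) = sum(W * x_i) for each example i.
# x has shape [batch, seq_len, vocab]; the result has shape [batch].
def toy_critic(x, W):
    return np.sum(x * W, axis=(1, 2))

# Analytic per-example gradient dD(x_i)/dx_i of the toy critic, which is
# just W for every example; shape [batch, seq_len, vocab].
def toy_critic_grad(x, W):
    return np.broadcast_to(W, x.shape)

def gradient_penalty(real, fake, W, rng, lam=10.0):
    batch = real.shape[0]
    # One random interpolation coefficient per example, broadcast over
    # the sequence and vocabulary axes.
    alpha = rng.uniform(size=(batch, 1, 1))
    interpolates = alpha * real + (1.0 - alpha) * fake
    gradients = toy_critic_grad(interpolates, W)
    # L2 norm of each example's gradient, reducing over axes 1 and 2.
    slopes = np.sqrt(np.sum(np.square(gradients), axis=(1, 2)))
    # Penalize each example's deviation from unit gradient norm.
    return lam * np.mean((slopes - 1.0) ** 2), slopes
```

With a critic whose gradient has norm exactly 1 the penalty is 0. Scaling that critic by 2 gives slopes of 2 and a penalty of `lam * (2 - 1)**2 = 10`.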
igul222 commented 7 years ago

Good catch! We missed this while cleaning up the code for release; it should just read `gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]`.
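The `[0]` matters because `tf.gradients(ys, xs)` differentiates the sum of `ys`. Passing the whole vector of per-example scores therefore gives each row of `interpolates` the gradient of its own score. This holds only if the critic treats examples independently, which is why the WGAN-GP paper avoids batch normalization in the critic. Indexing `[0]` keeps only the first example's score, so every other row gets a zero gradient and a slope of 0. Each such row then adds a constant (0 - 1)**2 = 1 to the penalty, which carries no training signal.

The NumPy sketch below imitates this with a hypothetical `tanh` toy critic. It also uses a hypothetical boolean `mask` that selects which scores are summed before differentiating. It is an illustration of the effect, not TensorFlow's implementation.

```python
import numpy as np

# Hypothetical toy critic: one score per example, shape [batch].
def critic(x, W):
    return np.tanh(np.sum(x * W, axis=(1, 2)))

def grad_wrt_inputs(x, W, mask):
    """Gradient of sum(critic(x)[mask]) w.r.t. x, mimicking how
    tf.gradients differentiates the sum of its ys. mask: bool [batch]."""
    s = np.sum(x * W, axis=(1, 2))
    # dD_i/dx_i for each example: (1 - tanh(s_i)^2) * W
    per_example = (1.0 - np.tanh(s) ** 2)[:, None, None] * W
    # Examples whose score is not in ys get zero gradient.
    return per_example * mask[:, None, None]

def slopes_from(grads):
    # Per-example L2 norm over the non-batch axes.
    return np.sqrt(np.sum(np.square(grads), axis=(1, 2)))

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 3, 5))            # stand-in for interpolates
W = 0.1 * rng.normal(size=(3, 5))
batch = x.shape[0]

first_only = np.zeros(batch, dtype=bool)  # like Discriminator(interpolates)[0]
first_only[0] = True
all_rows = np.ones(batch, dtype=bool)     # like Discriminator(interpolates)

buggy_slopes = slopes_from(grad_wrt_inputs(x, W, first_only))
fixed_slopes = slopes_from(grad_wrt_inputs(x, W, all_rows))
```

With the full score vector, every example contributes its own gradient norm to the penalty. With `[0]`, the mean penalty reduces to `((slope_0 - 1)**2 + (batch - 1)) / batch`, so only the first example is actually regularized.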