igul222 / improved_wgan_training

Code for reproducing experiments in "Improved Training of Wasserstein GANs"
MIT License

Help with convergence #20

Closed · anishpdoshi closed this issue 7 years ago

anishpdoshi commented 7 years ago

Hi all,

My team is running WGAN with gradient penalty, using a ResNet generator/discriminator as in this code. We're trying to train the GAN to generate image data: 64x64 black-and-white spectrograms (i.e., one channel instead of three). However, our training runs keep suffering from what looks like mode collapse. We get samples like:

[attached image: samples_9999]

both early on and after many iterations. We've played around with learning rates, the gradient-penalty coefficient LAMBDA, the number of residual blocks, the number of critic iterations per generator iteration, and even tried essentially pretraining the critic to ensure it was near-optimal (at least at distinguishing noise from valid samples). We've also double-checked our data-feeding pipeline to make sure the real images are what they should be and are normalized correctly, in the same way as the ImageNet example (roughly the sketch below).
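For reference, our preprocessing looks roughly like this (a sketch, not our exact code; normalize_batch is our own helper, and it assumes uint8 pixel input scaled the way the repo's ImageNet example does it):

import numpy as np

def normalize_batch(batch):
    # Map uint8 pixels in [0, 255] to float32 values in [-1.0, 1.0].
    batch = batch.astype(np.float32)
    return 2.0 * (batch / 255.0) - 1.0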

Any suggestions? We've been stuck for a while :(

anishpdoshi commented 7 years ago

Never mind, we found our issue! We were pre-normalizing our data to floats between -1.0 and 1.0, but the placeholder defined at the top of the graph:

all_real_data_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE, 1, 64, 64])

was casting our normalized floats to integers, which destroyed essentially all of the information. Changing tf.int32 to tf.float32 in the line above solves this, of course.
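For anyone who hits the same thing, here is a minimal sketch of the fix (TF 1.x API, as in this repo; the BATCH_SIZE value is just a stand-in for whatever your config uses):

import tensorflow as tf

BATCH_SIZE = 64  # stand-in value; use your own batch size

# float32 keeps pre-normalized values in [-1.0, 1.0] intact. With
# tf.int32, feed_dict coerced the floats to integers, so nearly every
# pixel collapsed to 0 (or +/-1) and the real inputs were destroyed.
all_real_data_conv = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 1, 64, 64])

The alternative, if I'm reading the repo's 64x64 script right, is to keep the int32 placeholder, feed raw integer pixels, and do the scaling to [-1, 1] inside the graph instead; mixing pre-normalized floats with an integer placeholder is what bit us.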