christiancosgrove / pytorch-spectral-normalization-gan

Paper by Miyato et al. https://openreview.net/forum?id=B1QRgziT-
MIT License

sn-wgan results #9

Open ferrine opened 5 years ago

ferrine commented 5 years ago

Hi, I'm trying to apply spectral normalization to Wasserstein GANs. I failed to make it work in my own project, so I tried your repository to get more intuition about how to train them. However, I saw no training progress after about a day of training.
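For reference, here is a minimal sketch of what I mean by "spectral normalization + WGAN loss": the critic's weight layers are wrapped with PyTorch's built-in `torch.nn.utils.spectral_norm` and trained with the plain Wasserstein objective (no weight clipping, no gradient penalty). The architecture and sizes below are just placeholders, not the exact ones from this repo:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Toy DCGAN-style critic for 3x32x32 images, spectral norm on every weight layer.
critic = nn.Sequential(
    spectral_norm(nn.Conv2d(3, 64, 4, stride=2, padding=1)),
    nn.LeakyReLU(0.1),
    spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1)),
    nn.LeakyReLU(0.1),
    nn.Flatten(),
    spectral_norm(nn.Linear(128 * 8 * 8, 1)),
)

def critic_loss(critic, real, fake):
    # Plain WGAN critic objective: maximize E[D(real)] - E[D(fake)],
    # i.e. minimize the negative of that difference.
    return critic(fake.detach()).mean() - critic(real).mean()

def generator_loss(critic, fake):
    # Generator tries to maximize the critic score of its samples.
    return -critic(fake).mean()
```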

In the original WGAN paper they seem to train for 25 epochs. I've waited for 130 so far and got the following results (with your code): [image]

Some introspection into the discriminator gave me interesting insights. Indeed, the WGAN critic is Lipschitz with constant ~2.4 according to the gradient norm histograms. However, the SN-GAN and regular GAN critics seem to have gradients with a larger average norm. Below are gradient norms for SN-GAN, SN-WGAN, and a regular GAN, each trained for 100 epochs on CIFAR (MNIST for the regular GAN); I used 5 discriminator iterations per generator update. [images]

Judging by the histograms, SN-GAN seems to have better gradients for the generator. For the regular GAN I see small gradients even for images coming from the generator.
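For reproducibility, this is roughly how I estimate the critic's Lipschitz behaviour: take the gradient of the critic output with respect to the input images and look at the per-sample gradient norms (a histogram of these is what the plots above show). A sketch, where `critic` and `batch` stand in for the actual model and data:

```python
import torch

def input_gradient_norms(critic, batch):
    """Per-sample norm of d critic(x) / d x; each value is a lower bound on the Lipschitz constant."""
    x = batch.clone().requires_grad_(True)
    scores = critic(x)
    grads, = torch.autograd.grad(scores.sum(), x)
    return grads.flatten(1).norm(dim=1)  # shape: (batch_size,)
```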

Did you manage to get satisfactory results for SN-WGAN, and if so, how?

My current intuition tells me the devil is in the gradients or their bias (I don't think I have that bias, since I use a batch size of 1024). Convergence might be too slow because of these gradients.

christiancosgrove commented 5 years ago

One question... are you using the resnet model or the dcgan model?

This is very interesting...

A few months ago someone reported problems with sn-WGAN.

Have you tried using smaller batch sizes? I've seen that paper and am aware that WGAN has biased sample gradients... maybe it's not understood how this interacts with spectral norm?

ferrine commented 5 years ago

I used the resnet model for the first picture and my own implementation (without resnet) for the histograms. In my setup I always did 1 epoch of discriminator pretraining. I tried small batch sizes and then moved to larger ones after getting poor results and finding a paper about gradient bias. The resnet model produced bad-looking samples and never seemed to converge (using RMSprop for D and Adam with betas (0.5, 0.999) for G).

The loss function behaviour for the resnet model was bad (that's why I tried to run your repo). I observed unstable, high-variance learning curves for the fake and real critic scores (loss = fake_loss + real_loss, with the correct signs inside). There were no such effects for the non-resnet model. However, in both cases I also observed a divergence (|fake_loss - real_loss|) between the critic scores for real and fake images. This divergence only grew over time (the same symptoms as in the link you provided). Since the critic is Lipschitz, I would expect it to only decrease over time (if I always have a perfect critic): if the two distributions get closer, the critic score can't get higher.
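To be explicit about the quantities I'm tracking (the names below are mine, not from this repo, and the interpretation of the divergence as the drift of the raw scores away from zero is my reading of it):

```python
import torch

def critic_diagnostics(critic, real, fake):
    real_loss = -critic(real).mean()            # sign inside, as described above
    fake_loss = critic(fake.detach()).mean()
    loss = real_loss + fake_loss                 # what the critic minimizes
    wasserstein_estimate = -loss                 # E[D(real)] - E[D(fake)]
    divergence = (fake_loss - real_loss).abs()   # |E[D(fake)] + E[D(real)]|:
                                                 # how far the raw scores drift from zero
    return loss, wasserstein_estimate, divergence
```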

[image] (source)

At the barycenter it seems to be zero, but when the domains overlap there are many solutions that achieve the minimum loss and zero Wasserstein distance. (Maybe we can somehow help our critic?)

In toy problems like fitting a 1D Gaussian, the behaviour was a generator loss that first grows and then decreases (as does the divergence I mentioned). BUT! When I get too close to the optimum, I see oscillations in the discriminator scores that affect the divergence but not the Wasserstein distance estimate.
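For the record, the 1D Gaussian toy setup I mean is roughly the following (tiny MLPs, plain SN-WGAN objective; all sizes, learning rates, and the target distribution are arbitrary placeholders):

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

generator = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
critic = nn.Sequential(
    spectral_norm(nn.Linear(1, 64)), nn.ReLU(),
    spectral_norm(nn.Linear(64, 1)),
)
g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
d_opt = torch.optim.RMSprop(critic.parameters(), lr=1e-4)

target = torch.distributions.Normal(3.0, 0.5)  # the 1D Gaussian to fit

for step in range(10000):
    for _ in range(5):  # 5 critic iterations per generator update
        real = target.sample((1024, 1))
        fake = generator(torch.randn(1024, 8)).detach()
        d_loss = critic(fake).mean() - critic(real).mean()
        d_opt.zero_grad(); d_loss.backward(); d_opt.step()

    fake = generator(torch.randn(1024, 8))
    g_loss = -critic(fake).mean()
    g_opt.zero_grad(); g_loss.backward(); g_opt.step()
```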

The symptoms I get in GANs make me think either I

HiddenMachine3 commented 1 month ago

I tried implementing a vanilla WGAN with spectral norm. Observing the gradient flowing from the critic to the generator, it seems that just as the generator starts getting better, the critic becomes worse and worsens the generator in the process.

I can confirm this because the gradients coming from the critic at the start of training have the general shape of my target images, but once the generator learns from this and gets a little better, the critic can no longer distinguish properly, learns bogus filters, and feeds bogus gradients to the generator. The critic can no longer guide the generator properly through its gradients.
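In case it helps, this is roughly how I inspect the gradients the critic sends back to the generator: keep the gradient on the generated images themselves and look at it as an image. A sketch, where `generator`, `critic`, and `z_dim` are placeholders for my setup:

```python
import torch

def generator_gradient_images(generator, critic, z_dim=128, batch_size=16):
    z = torch.randn(batch_size, z_dim)
    fake = generator(z)
    fake.retain_grad()                 # keep the gradient on the images themselves
    g_loss = -critic(fake).mean()      # standard WGAN generator objective
    g_loss.backward()
    return fake.grad                   # same shape as the images; visualize this
```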

I think the authors of the paper used spectral norm alongside a gradient penalty? Is that it?
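If you want to try combining the two (as far as I can tell, the paper mainly compares spectral norm against WGAN-GP rather than stacking them, so this is just an experiment, not the paper's method), one option is to keep spectral norm on the critic's layers and add the WGAN-GP penalty on top. A sketch, with `lambda_gp` and the interpolation scheme taken from the WGAN-GP recipe:

```python
import torch

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    # Penalize the critic's gradient norm on points interpolated between real and fake
    # (assumes 4D image batches of shape (N, C, H, W)).
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(interp)
    grads, = torch.autograd.grad(scores.sum(), interp, create_graph=True)
    grad_norms = grads.flatten(1).norm(dim=1)
    return lambda_gp * ((grad_norms - 1.0) ** 2).mean()

# Critic update would then be:
# d_loss = critic(fake).mean() - critic(real).mean() + gradient_penalty(critic, real, fake)
```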