igul222 / improved_wgan_training

Code for reproducing experiments in "Improved Training of Wasserstein GANs"
MIT License

Interpreting Critic Loss To Improve Convergence #26

Open NickShahML opened 7 years ago

NickShahML commented 7 years ago

Hey @igul222, thanks again for your help with comments earlier.

I've been rereading both WGAN papers, and I understand that the critic loss is supposed to estimate the Wasserstein distance. Below are a few ideas I had to improve convergence.

Suppose you have a training run where the critic loss starts at -170.0 and then tapers down to -30.0 by generator iteration 200k. Despite further training and lowering the learning rate, it stays at -30.0.

Is it correct to interpret this finding as a problem with the generator? Shouldn't the generator architecture be designed in such a way that this -30.0 approaches 0?

Another way of asking this is: to improve convergence, isn't it clearly the generator's fault? The only way I could see it being the critic's fault is if the Wasserstein distance approaches 0; in that case, the critic probably isn't fully capturing the Wasserstein distance.

From this line of thinking, isn't it wise to overpower the generator (increase the generator's number of layers or dimensionality) until you hit a Wasserstein distance of 0? Once you hit a distance of roughly 0, it then becomes justifiable to increase the critic's capacity.
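For concreteness, here is a minimal sketch (PyTorch; the names are illustrative, not this repo's code) of the sign convention in play: the critic's training loss is the negative of the quantity it maximises, so a loss of -30 corresponds to a Wasserstein-distance estimate of about 30.

```python
import torch

def critic_loss_and_w_estimate(critic, real, fake):
    """Unregularised WGAN critic loss and the implied Wasserstein estimate.

    The critic is trained to maximise E[D(real)] - E[D(fake)], so the loss
    being minimised is the negative of that quantity, and -loss is the
    running estimate of the Wasserstein distance between the distributions.
    """
    d_real = critic(real).mean()
    d_fake = critic(fake).mean()
    loss = d_fake - d_real       # what the critic optimiser minimises
    w_estimate = -loss.item()    # E[D(real)] - E[D(fake)], ~30 in the run above
    return loss, w_estimate
```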

alex-lew commented 7 years ago

This makes sense! But I'm not sure you can draw that conclusion. Think of the generator as a student and the critic as a teacher; the student/generator does some work and the teacher/critic points out its flaws so that the generator can improve. The architectures of the generator and critic each limit the functions they can compute; in our analogy, the student is only capable of learning some things and the teacher is only capable of teaching some things. The critic loss (or estimated Wasserstein distance) is a measure of how bad the teacher thinks the student's work is.

Now suppose the student, by modifying its weights, could produce (say) better spelling but not better punctuation on its essays. And suppose the teacher, by modifying its weights, could detect punctuation flaws but not spelling flaws. Then there's a mismatch, and the estimated Wasserstein distance flattens out (because no matter what the student does, it cannot improve its work in the eyes of the teacher, who keeps telling it to improve its punctuation to no avail). In this case, improving either the student or the teacher will help. If we improve the teacher's capacity so it's capable of noticing spelling flaws, the estimated Wasserstein distance will be greater than 30, but the gradients will actually help the student learn something it is capable of learning: it will modify its weights to produce better spelling. Similarly, if we improve the student's capacity so that it can change its weights to punctuate better, the gradients from the teacher will suddenly be useful and the Wasserstein distance will go down.

tl;dr -- if the estimated Wasserstein distance is close to 0, it's safe to say you should improve the critic architecture before the generator architecture. But if not, my intuition is that gains might be made by improving either architecture, and the estimated Wasserstein distance alone may not be enough to tell you where to focus your efforts.

NickShahML commented 7 years ago

Thanks @alex-lew for the insightful commentary. In many of my natural language generation tasks, I have used very large generators and critics, but the critic's loss always seems to converge to a fixed negative constant (e.g., -30).

However, my biggest issue with WGAN is that I know the generator's architecture can do better. By training it with maximum likelihood, you can get good grammar and punctuation, which is a struggle to see with regular WGAN.

From this finding, you would think it would have to be the critic's fault. But I have confirmed that, by training the discriminator with sigmoid cross-entropy, it can easily distinguish between the two.
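One way to run that sanity check (a sketch in PyTorch; the function and names are hypothetical, not from this repo) is to train the same critic architecture as an ordinary binary classifier with sigmoid cross-entropy and measure whether it separates real from generated samples:

```python
import torch
import torch.nn.functional as F

def bce_capacity_step(discriminator, optimizer, real_batch, fake_batch):
    """One step of a hypothetical capacity check: train the critic
    architecture as a plain real-vs-fake classifier. High accuracy here
    suggests raw capacity is not what is capping the WGAN critic."""
    logits_real = discriminator(real_batch)
    logits_fake = discriminator(fake_batch)
    loss = (F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real))
            + F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Fraction of samples classified correctly at a decision threshold of 0
    acc = ((logits_real > 0).float().mean() + (logits_fake <= 0).float().mean()) / 2
    return loss.item(), acc.item()
```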

So this leads me to my final conclusion: both the generator and discriminator architectures are sufficient; it is the WGAN design itself that is flawed. Don't get me wrong, WGAN is an amazing breakthrough, but there is still something crucially wrong. Cramer GAN and other papers suggest different alternatives, but something is still not right.

You would also think that you could generate really good images of faces with huge generators and critics, but you can't. It may be too little training data, but I think there is something inherently wrong with WGAN.

li-zemin commented 6 years ago

@NickShahML @alex-lew Hello, I ran into the same issue. In my case, I trained a WGAN-GP model on a 3D voxel dataset, training the generator once and the discriminator five times per iteration. After training for 20k iterations, I found that the gap between the fake_loss and real_loss stabilises at a fixed value, about 30. Although the d_loss, which consists of the fake_loss, real_loss, and gradient_penalty_loss, is still decreasing at a much slower rate, the gap between the fake_loss and real_loss is almost invariant; only the gradient_penalty_loss keeps decreasing, to less than 10. And since the gradient penalty uses lambda = 10 and the ideal gradient norm is 1, I think the desired gradient_penalty_loss is 10. So I think the training persistently fails to converge. What do you think?
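For reference, here is a minimal sketch (PyTorch; illustrative names, not this repo's TensorFlow code) of the standard two-sided WGAN-GP penalty being described. Note that with this form the penalty is zero, not lambda = 10, when the interpolates' gradient norm is exactly 1, so a healthy run drives gradient_penalty_loss toward 0:

```python
import torch

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """Standard WGAN-GP penalty: lambda * E[(||grad D(x_hat)|| - 1)^2],
    which is 0 (not lambda) when the gradient norm is exactly 1."""
    batch = real.size(0)
    # One interpolation coefficient per sample, broadcast over the other dims
    alpha = torch.rand(batch, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=scores, inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True)[0]
    grad_norm = grads.reshape(batch, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()
```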

ukuleleplayer commented 5 years ago

@li-zemin Hey! I'm in the exact same situation as you - the loss stabilises at -30. Did you manage to improve this?

Cheers!

priyarana commented 3 years ago

Hi. My loss dropped down to -0.17; if I train it further, the loss starts increasing. Shall I consider -0.17 as the convergence point then? Any input, please.

Blurryface0814 commented 1 year ago

> @li-zemin Hey! I'm in the exact same situation as you - the loss stabilises at -30. Did you manage to improve this? Cheers!

Hi! I've met the same issue. Have you managed to fix this problem?

Thank you !

priyarana commented 1 year ago

Actually, this is not an issue! This is how WGAN gets trained. During training, the loss value keeps dropping up to a certain point, after which it starts rising. That point is the convergence point, meaning the model is trained, and training should be stopped there (just before the value starts to rise).
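A minimal sketch of that stopping rule (plain Python; the loss sequence and patience value are assumptions for illustration):

```python
def find_convergence_point(critic_losses, patience=5):
    """Return the index of the lowest loss before a sustained rise.

    critic_losses: iterable of (preferably smoothed) per-check loss values.
    Training would be stopped, and weights kept, at the returned index.
    """
    best_loss = float("inf")
    best_step = 0
    rises = 0
    for step, loss in enumerate(critic_losses):
        if loss < best_loss:
            best_loss, best_step, rises = loss, step, 0
        else:
            rises += 1
            if rises >= patience:
                break  # the loss has turned upward: convergence point passed
    return best_step
```

In practice the raw per-batch critic loss is noisy, so it is worth smoothing it (e.g. with a moving average) before applying a rule like this.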

priyarana commented 1 year ago

The following paper implements WGAN-div and compares it with other WGANs: https://www.nature.com/articles/s41598-022-22882-x. Refer to the supplementary material as well. The paper also explains the implementation of WGAN.