keras-team / keras-contrib

Keras community contributions
MIT License

Is The WGAN Wasserstein Loss Function Correct? #280

Closed justinessert closed 6 years ago

justinessert commented 6 years ago

I'm struggling to understand why the Wasserstein Loss function is correct, as seen below:

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)

The comment in the code says that this works if your labels are -1 for generated images and 1 for real images. If we assume that disc_real and disc_fake are the outputs of the discriminator with only real and fake image, respectively, then:

y_true * y_pred = 1 * disc_real - 1 * disc_fake

Therefore, the loss functions that this defines are

loss_discriminator = 1 * disc_real - 1 * disc_fake
loss_generator = -1 * disc_fake (since the generator only updates on fake images)
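As a quick sanity check (a sketch using plain NumPy instead of the Keras backend, with made-up critic outputs), the loss with these labels reduces to the real/fake difference on a balanced batch:

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    # NumPy stand-in for K.mean(y_true * y_pred)
    return np.mean(y_true * y_pred)

# toy critic outputs: first half of the batch real, second half fake
disc_real = np.array([0.8, 0.5])
disc_fake = np.array([-0.3, -0.6])

y_pred = np.concatenate([disc_real, disc_fake])
y_true = np.array([1.0, 1.0, -1.0, -1.0])  # +1 for real, -1 for fake

loss = wasserstein_loss(y_true, y_pred)
# each half contributes with weight 1/4 per sample in a batch of 4,
# so the loss equals 0.5 * (mean(disc_real) - mean(disc_fake))
expected = 0.5 * (disc_real.mean() - disc_fake.mean())
```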

The reason that I'm skeptical of this is that in the tensorflow implementation from the Improved Training of Wasserstein GANs the authors define the following losses:

# Standard WGAN loss
gen_cost = -tf.reduce_mean(disc_fake)
disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

So in other words, it would be:

loss_discriminator = -1 * disc_real + 1 * disc_fake
loss_generator = -1 * disc_fake (since the generator only updates on fake images)

With the sign flipped on the discriminator loss, you would need to maximize instead of minimize the loss, right? Have the authors of this implementation seen success with it? @the-moliver @Rocketknight1 @DavidS3141

demmerichs commented 6 years ago

Your conclusion is not quite correct for the Keras implementation. As you can see from the definition of the Wasserstein loss, it depends on the labels we feed into the model. For the discriminator we feed +1 as the label for real images and -1 for fake images, so here your conclusion was correct. This is indeed opposite to the TF implementation, but the sign actually does not matter. As you noted, what matters is that the generator uses the opposite sign on the fake images. That is the case in the current implementation, because we feed +1 as the label for the fake images (the generator step trains on no real images, so no label is needed for them). So we have

loss_generator = + 1 * disc_fake
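A tiny numeric illustration (NumPy, with hypothetical critic outputs) of why the overall sign convention is irrelevant: swapping the label used for the generator's fake images just negates its loss, so minimizing one is equivalent to minimizing the other once the critic's sign flips too:

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    return np.mean(y_true * y_pred)

disc_fake = np.array([0.2, -0.4])  # toy critic outputs on fake images

# keras-contrib convention: the generator uses +1 labels on fake images
gen_loss_keras = wasserstein_loss(np.ones(2), disc_fake)

# TF reference convention: gen_cost = -mean(disc_fake), i.e. -1 labels
gen_loss_tf = wasserstein_loss(-np.ones(2), disc_fake)

# the two conventions are exact negatives of each other
```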

I hope this cleared things up for you!

justinessert commented 6 years ago

It did, thanks for pointing that out!

JiamingJiao commented 6 years ago

Sorry but I still do not understand this:

This is indeed opposite to the TF implementation, but the sign actually does not matter.

I don't know why the sign does not matter. Let's think about only the training of the discriminator. Assume the labels for two images are (-1, 1), meaning (fake, real). If the discriminator gives the perfectly correct result, which is (-1, 1), the average loss will be 1. If the discriminator cannot distinguish the inputs, so its prediction is (0, 0), the loss will be 0. Since training tries to minimize the loss, shouldn't the better discriminator have the smaller loss? But in the case I described it does not.

demmerichs commented 6 years ago

The paper calls this second network not a discriminator but a critic. The important difference is that the network does not try to match the labels y_true, but instead tries to minimize the loss K.mean(y_true * y_pred).

If the discriminator gives perfectly correct result, which is (-1, 1), the average loss will be 1.

I don't know how you reasoned this, but essentially the critic tries to minimize the output y_pred for real images (this contributes +1 * y_pred to the loss) and to maximize the output for fake images (as it tries to minimize -1 * y_pred). If you change the sign in this definition, the critic simply has to learn a sign flip compared to the previous setting, which should not be a problem. Also, if the critic can't distinguish, the loss is indeed 0: the expectation value of y_pred is then independent of y_true, and as half of your data is real (y_true = 1) and half is fake (y_true = -1), the contributions cancel. If the critic is able to distinguish, it will reach a loss smaller than 0 as it minimizes.
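A small numeric check of the two cases described above (NumPy, with toy critic outputs): an indistinguishable critic lands at exactly 0, and a critic that separates real from fake (which, under this sign convention, means pushing real outputs down and fake outputs up) goes below 0:

```python
import numpy as np

def critic_loss(disc_real, disc_fake):
    # keras-contrib labels for the critic: +1 for real, -1 for fake
    y_true = np.concatenate([np.ones_like(disc_real),
                             -np.ones_like(disc_fake)])
    y_pred = np.concatenate([disc_real, disc_fake])
    return np.mean(y_true * y_pred)

# critic that can't distinguish: identical outputs for real and fake,
# so the +1 and -1 labelled terms cancel exactly
same = np.array([0.3, -0.1])
loss_confused = critic_loss(same, same)

# critic that separates the two: real outputs pushed negative,
# fake outputs pushed positive, so the loss drops below zero
loss_trained = critic_loss(np.array([-0.9, -0.7]), np.array([0.8, 0.6]))
```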

Maybe take another look into the original WGAN paper.

JiamingJiao commented 6 years ago

Thanks a lot! So, in your implementation, y_pred_on_real < 0 and y_pred_on_fake > 0 if the critic is well trained. And a better critic has a larger absolute(y_pred). Is that right?

demmerichs commented 6 years ago

I think you got the right idea, yes! But to be precise, the actual answers to your questions would be "no". The reason is that the critic can have an arbitrary offset: if you have a critic C(x) with loss L, then another critic C'(x) = C(x) + K, whose output is just shifted by K, still has loss L, because the +K term cancels between the fake and the real data. Because of this, a perfectly trained critic can output only positive or only negative values. And then there is also the gradient penalty, which makes sure that y_pred can't get arbitrarily large. So your second statement is true in general, but the scale of absolute(y_pred) at which you have a good critic depends heavily on the dataset used.
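The offset argument is easy to verify numerically (a NumPy sketch with invented critic outputs): shifting every output by the same constant K leaves the loss untouched, even though the sign of every individual output can change:

```python
import numpy as np

def critic_loss(disc_real, disc_fake):
    y_true = np.concatenate([np.ones_like(disc_real),
                             -np.ones_like(disc_fake)])
    y_pred = np.concatenate([disc_real, disc_fake])
    return np.mean(y_true * y_pred)

disc_real = np.array([-0.9, -0.7])
disc_fake = np.array([0.8, 0.6])

K = 5.0  # arbitrary constant offset added to every critic output
loss = critic_loss(disc_real, disc_fake)
loss_shifted = critic_loss(disc_real + K, disc_fake + K)
# the +K terms cancel (equal numbers of +1 and -1 labels), so the
# loss is identical even though all shifted outputs are now positive
```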

justinessert commented 6 years ago

So I understand that we can define the losses with either sign, which is why it's okay to define the discriminator loss as disc_real - disc_fake instead of the authors' disc_fake - disc_real, provided you also flip the sign on the generator's loss, because essentially disc_fake - disc_real = -(disc_real - disc_fake).

But considering that the actual loss function is disc_fake - disc_real + grad_pen and the gradient penalty is defined the same way in both implementations, shouldn't we have to subtract the gradient penalty loss in this implementation (ie, multiply it by -1) in order to achieve an equivalent function? @DavidS3141

demmerichs commented 6 years ago

Sorry for coming back to you so late, I was on vacation.

The plus sign on the gradient penalty is still correct. You have to realize the following: if you change the sign of disc_fake - disc_real in the loss, then the critic/discriminator network essentially incorporates this sign flip into its weights. That is why the loss without the gradient penalty is always negative (independent of the choice of sign). However, if the network flips its sign, the absolute value of its gradient is still the same everywhere, because we penalize the norm of the gradient, which is of course non-negative. The penalty pushes the gradient norm towards 1, so it is itself always non-negative and has to contribute with a plus sign to the total loss.
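This can be sketched with the two-sided WGAN-GP penalty (||grad|| - 1)^2 (NumPy, toy per-sample gradients; the weight 10.0 is the paper's default, used here as an assumption): negating the critic negates every gradient, but the norms, and therefore the penalty, are unchanged and non-negative, so the penalty always adds to the loss:

```python
import numpy as np

def gradient_penalty(grad_norms, weight=10.0):
    # two-sided WGAN-GP penalty: weight * mean((||grad|| - 1)^2);
    # each term is a square, so the penalty is always >= 0
    return weight * np.mean((grad_norms - 1.0) ** 2)

grads = np.array([[0.6, 0.8], [1.5, 2.0]])  # toy per-sample gradients
norms = np.linalg.norm(grads, axis=1)

# flipping the critic's sign negates every gradient but leaves the
# norms, and hence the penalty, exactly the same
norms_flipped = np.linalg.norm(-grads, axis=1)
```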