keras-team / keras-contrib

Keras community contributions
MIT License

What are the reported losses for the improved WGAN? #323

Closed AmirAlavi closed 5 years ago

AmirAlavi commented 5 years ago

The improved WGAN example code does not print or plot the loss. I'm trying to add this, but I'm confused by the output of train_on_batch for the critic.

When I print discriminator_model.metrics_names, I see this:

['loss', 'sequential_2_loss', 'sequential_2_loss', 'sequential_2_loss']

What does each of these represent? They don't line up with the model.compile line:

discriminator_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),
                            loss=[wasserstein_loss,
                                  wasserstein_loss,
                                  partial_gp_loss])

SanderGielisse commented 5 years ago

The first value is the sum of the next three losses, and those three appear in the order they were passed to compile: wasserstein_loss, wasserstein_loss, partial_gp_loss.

For example, I get -5.6893063, -109.37327, 102.22689, 1.4570715 returned. Then -109.37327 (wasserstein) + 102.22689 (wasserstein) + 1.4570715 (gradient penalty) = -5.6893 (total loss), which matches the first returned value, -5.6893063.
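
If it helps when logging, here is a minimal sketch of the same check in plain Python, using only the four numbers quoted above. The label names are hypothetical (Keras itself reports 'sequential_2_loss' three times), and the tolerance is an assumption to absorb float32 rounding.

```python
import math

# Values quoted in this thread, in the order train_on_batch returned them.
reported = [-5.6893063, -109.37327, 102.22689, 1.4570715]

# Hypothetical, more readable labels following the order of the loss list
# passed to compile: wasserstein_loss, wasserstein_loss, partial_gp_loss.
labels = ["total", "wasserstein_1", "wasserstein_2", "gradient_penalty"]
named = dict(zip(labels, reported))

# The first entry should equal the sum of the remaining three,
# up to floating-point rounding (tolerance chosen as an assumption).
component_sum = sum(reported[1:])
matches = math.isclose(named["total"], component_sum, abs_tol=1e-4)

for name, value in named.items():
    print(f"{name}: {value}")
print("sum of components:", component_sum, "matches total:", matches)
```

In a training loop, the same zip could be applied to whatever list train_on_batch returns for the critic, so each logged number carries a distinct name.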

AmirAlavi commented 5 years ago

Thanks! Makes sense. The first one being the sum of all the losses was the key piece I was missing.

ali-ehsan commented 5 years ago

@AmirAlavi What kind of results did you get on the MNIST dataset? For me it shows black images after the 9th epoch and never recovers. Please note that I have not changed anything in the source code.