eriklindernoren / PyTorch-GAN

PyTorch implementations of Generative Adversarial Networks.
MIT License

Problem in ClusterGAN #75

Open zhiqi-li opened 5 years ago

zhiqi-li commented 5 years ago

The problem is in clustergan.py:

443:     ge_loss = torch.mean(D_gen) + betan * zn_loss + betac * zc_loss
465:     d_loss = torch.mean(D_real) - torch.mean(D_gen) + grad_penalty

I think it should be:

443:     ge_loss = -torch.mean(D_gen) + betan * zn_loss + betac * zc_loss
465:     d_loss = -torch.mean(D_real) + torch.mean(D_gen) + grad_penalty

yangjunting100 commented 4 years ago

I think you're right, because that is what the WGAN formula is supposed to be. Have you seen any difference between the two in practice? Thank you!
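
For reference, here is a sketch of the WGAN-GP objectives in their usual form, with the two ClusterGAN terms from the snippet appended to the generator/encoder loss. The symbol names are my own mapping onto the variables in the issue: D is the critic, G the generator, x̂ an interpolate between real and generated samples, and λ the penalty weight (which may already be folded into `grad_penalty` in the script).

```latex
\begin{aligned}
L_D &= \mathbb{E}_{z}\big[D(G(z))\big] - \mathbb{E}_{x}\big[D(x)\big]
       + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big] \\
L_{GE} &= -\,\mathbb{E}_{z}\big[D(G(z))\big]
       + \beta_n \cdot \mathtt{zn\_loss} + \beta_c \cdot \mathtt{zc\_loss}
\end{aligned}
```

Both losses are minimized, so the critic pushes D(x) up on real data and D(G(z)) down on generated data, which matches the signs proposed above.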

yangjunting100 commented 4 years ago

I ran an experiment and found that the loss converges to almost the same value with both formulas, and the quality of the generated images is similar. I wonder whether that is because MNIST is too simple, or because of the formula itself?
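
One possible explanation, offered as a sketch rather than a confirmed diagnosis: if every critic output is negated (D replaced by -D), the losses as currently written in the script become exactly the standard WGAN losses. This assumes the gradient penalty depends only on the norm of the critic's gradient, as in standard WGAN-GP, so that it is unchanged when the critic is negated. Under that assumption the two versions would train critics that differ only in sign. The helper names and numbers below are made up for illustration, and the zn/zc terms are left out because they do not involve the critic.

```python
import math
import random


def mean(xs):
    return sum(xs) / len(xs)


def losses_as_in_script(d_real, d_gen, grad_penalty):
    """Signs as quoted from lines 443 and 465 (adversarial parts only)."""
    ge_adv = mean(d_gen)
    d_loss = mean(d_real) - mean(d_gen) + grad_penalty
    return ge_adv, d_loss


def losses_standard_wgan(d_real, d_gen, grad_penalty):
    """Signs as proposed in this issue (adversarial parts only)."""
    ge_adv = -mean(d_gen)
    d_loss = -mean(d_real) + mean(d_gen) + grad_penalty
    return ge_adv, d_loss


random.seed(0)
# Made-up critic outputs on a batch of real and generated samples.
d_real = [random.gauss(0.0, 1.0) for _ in range(8)]
d_gen = [random.gauss(0.0, 1.0) for _ in range(8)]
# Made-up penalty value; assumed identical for D and -D because
# the norm of grad(-D) equals the norm of grad(D).
grad_penalty = 0.3

negated_real = [-v for v in d_real]
negated_gen = [-v for v in d_gen]

script = losses_as_in_script(d_real, d_gen, grad_penalty)
standard_on_negated = losses_standard_wgan(negated_real, negated_gen, grad_penalty)

# The script's losses for critic D equal the standard losses for critic -D.
same = all(math.isclose(a, b) for a, b in zip(script, standard_on_negated))
print(same)
```

If this reasoning applies to the script, similar convergence and image quality would be expected on any dataset, not only on MNIST. The sign of the critic output would matter only if it is interpreted directly, for example when reading a score as "more real".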