eriklindernoren / PyTorch-GAN

PyTorch implementations of Generative Adversarial Networks.
MIT License

A bug in the AAE implementation? #152

Open Shentao-YANG opened 3 years ago

Shentao-YANG commented 3 years ago

Hi,

Thanks for this repo!

The following code from aae.py, lines 197-199, confuses me a bit:

real_loss = adversarial_loss(discriminator(z), valid)     
fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)        
d_loss = 0.5 * (real_loss + fake_loss)
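To make the label convention concrete, here is a minimal pure-Python sketch of what those three lines compute. It assumes `adversarial_loss` is binary cross-entropy averaged over the batch (what `torch.nn.BCELoss` computes by default) and that `valid` is all ones and `fake` all zeros. In place of `discriminator(...)` it uses hypothetical lists of output probabilities. `bce` and `aae_d_loss` are illustrative names, not functions from aae.py.

```python
import math


def bce(probs, targets):
    """Mean binary cross-entropy, mirroring torch.nn.BCELoss(reduction='mean')."""
    total = 0.0
    for p, y in zip(probs, targets):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)


def aae_d_loss(d_prior, d_encoded):
    """Discriminator loss with the repository's convention:
    prior samples z are labelled valid (1), encoder outputs are labelled fake (0)."""
    valid = [1.0] * len(d_prior)
    fake = [0.0] * len(d_encoded)
    real_loss = bce(d_prior, valid)   # pushes D toward 1 on samples from the prior
    fake_loss = bce(d_encoded, fake)  # pushes D toward 0 on encoded_imgs
    return 0.5 * (real_loss + fake_loss)


# A discriminator that scores prior samples high and encoded codes low gets a low loss...
print(aae_d_loss([0.9, 0.8], [0.1, 0.2]))
# ...while one that has the two sources backwards gets a high loss.
print(aae_d_loss([0.1, 0.2], [0.9, 0.8]))
```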

AFAIK, discriminator(z) should constitute the fake examples, while encoded_imgs.detach() should constitute the true examples. The goal of the discriminator is to correctly classify the noise z as a fake embedding and encoded_imgs.detach() as valid. Hence, I suggest the following modification to this block of code:

real_loss = adversarial_loss(discriminator(encoded_imgs.detach()), valid)        
fake_loss = adversarial_loss(discriminator(z), fake)     
d_loss = 0.5 * (real_loss + fake_loss)

Please let me know if I am misunderstanding anything.
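You can check directly whether the proposed swap is a bug fix or just a relabelling. With binary cross-entropy, the swapped objective evaluated on discriminator outputs D is identical to the original objective evaluated on 1 - D. The swap by itself therefore only changes what a high discriminator score means. The minimal sketch below verifies this numerically. It assumes mean BCE (as `torch.nn.BCELoss` computes), `valid` = 1 and `fake` = 0, and uses hypothetical output lists in place of `discriminator(...)`; the function names are illustrative.

What matters in practice is consistency with the encoder's adversarial term. Assuming that term pushes `discriminator(encoded_imgs)` toward `valid`, changing only the discriminator's labels would make the encoder and discriminator pull in the same direction instead of competing.

```python
import math


def bce(probs, targets):
    """Mean binary cross-entropy, mirroring torch.nn.BCELoss(reduction='mean')."""
    return sum(-(y * math.log(p) + (1 - y) * math.log(1 - p))
               for p, y in zip(probs, targets)) / len(probs)


def d_loss_original(d_prior, d_encoded):
    # Repository convention: prior -> valid (1), encoded -> fake (0)
    return 0.5 * (bce(d_prior, [1.0] * len(d_prior))
                  + bce(d_encoded, [0.0] * len(d_encoded)))


def d_loss_swapped(d_prior, d_encoded):
    # Proposed convention: encoded -> valid (1), prior -> fake (0)
    return 0.5 * (bce(d_encoded, [1.0] * len(d_encoded))
                  + bce(d_prior, [0.0] * len(d_prior)))


def flip(probs):
    # Hypothetical helper: the complementary discriminator output 1 - D
    return [1.0 - p for p in probs]


d_prior = [0.7, 0.4, 0.9]    # hypothetical D outputs on z ~ N(0, I)
d_encoded = [0.2, 0.6, 0.3]  # hypothetical D outputs on encoded_imgs

# The swapped objective on D matches the original objective on 1 - D.
print(d_loss_swapped(d_prior, d_encoded))
print(d_loss_original(flip(d_prior), flip(d_encoded)))
```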

sanjumsanthosh commented 2 years ago

Hi, this is what I understood. Based on the paper, the role of the discriminator is to predict whether a sample comes from the hidden latent code of the autoencoder (encoded_imgs.detach()) or from a sampled distribution (z). So the training criterion here is to match the posterior distribution of the autoencoder's latent space to an arbitrary prior distribution (here, z drawn from a normal distribution).
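In equation form, here is a sketch of the standard GAN value function with prior samples in the role of real data and encoder outputs in the role of generated data. E denotes the encoder and p(z) the prior; these symbols are introduced here for illustration. The last line assumes the encoder's adversarial term uses target `valid`, which gives the usual non-saturating encoder loss.

```latex
\begin{aligned}
V(E, D) &= \mathbb{E}_{z \sim p(z)}\big[\log D(z)\big]
         + \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log\big(1 - D(E(x))\big)\big],
         \qquad \min_{E}\max_{D} V(E, D) \\
\mathcal{L}_D &= -\tfrac{1}{2}\, V(E, D) \\
\mathcal{L}_E^{\mathrm{adv}} &= -\mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(E(x))\big]
\end{aligned}
```

Assuming `adversarial_loss` is binary cross-entropy, the batch version of the second line matches `d_loss` in the code below. BCE with target `valid` (1) contributes -log D(z), and BCE with target `fake` (0) contributes -log(1 - D(E(x))).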

So by training like this:

real_loss = adversarial_loss(discriminator(z), valid) 
fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
d_loss = 0.5 * (real_loss + fake_loss)

the encoder will learn to generate latent vectors close to the required prior (i.e., closer to z). So after training, decoding a sample drawn from any part of the prior z will produce a meaningful output.
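A toy calculation can check the intuition that the encoder pushes its codes toward the prior. The sketch below assumes a 1-D standard normal prior and models the encoder's code distribution as a hypothetical Gaussian N(enc_mean, enc_std²). It uses the standard GAN result that the best discriminator for "prior vs. encoded" is D*(z) = p(z) / (p(z) + q(z)).

When the codes match the prior, D* is 0.5 everywhere and the lowest reachable `d_loss` is ln 2 ≈ 0.693, its largest possible value. The further the codes drift from the prior, the lower the loss the discriminator can reach, which is the signal the encoder is trained against. All function names here are illustrative.

```python
import math


def normal_pdf(z, mean=0.0, std=1.0):
    """Density of N(mean, std^2) at z."""
    return math.exp(-0.5 * ((z - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))


def optimal_d(z, enc_mean, enc_std=1.0):
    """Best possible discriminator for 'prior sample (valid) vs encoded code (fake)':
    D*(z) = p(z) / (p(z) + q(z)), with prior p = N(0, 1) and a hypothetical
    Gaussian q = N(enc_mean, enc_std^2) standing in for the encoder's codes."""
    p = normal_pdf(z)
    q = normal_pdf(z, enc_mean, enc_std)
    return p / (p + q)


def expected_d_loss(enc_mean, enc_std=1.0, lo=-12.0, hi=12.0, n=20001):
    """0.5 * (E_p[-log D*(z)] + E_q[-log(1 - D*(z))]), i.e. the d_loss the best
    discriminator can reach, approximated with a Riemann sum over [lo, hi].
    For widely separated distributions D* can underflow to exactly 0 or 1 at
    the grid edges, so keep enc_mean moderate in this toy setting."""
    dz = (hi - lo) / (n - 1)
    real_term = fake_term = 0.0
    for i in range(n):
        z = lo + i * dz
        d = optimal_d(z, enc_mean, enc_std)
        real_term += normal_pdf(z) * -math.log(d) * dz
        fake_term += normal_pdf(z, enc_mean, enc_std) * -math.log(1.0 - d) * dz
    return 0.5 * (real_term + fake_term)


print(expected_d_loss(0.0), math.log(2))  # codes match the prior: D* = 0.5 everywhere
print(expected_d_loss(2.0))               # codes shifted away from the prior
```

For example, `expected_d_loss(0.0)` comes out at ln 2 up to integration error, while `expected_d_loss(2.0)` is noticeably lower because the discriminator can now tell the two sources apart.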

zhangslab commented 1 year ago

Thank you.