bstriner / keras-adversarial

Keras Generative Adversarial Networks
MIT License

KL divergence in AAE #15

Open arashmh opened 7 years ago

arashmh commented 7 years ago

Regarding example_aae.py: can anyone explain how this code works without the KL divergence included in the encoder loss?

Cheers

bstriner commented 7 years ago

Hi @arashmh Let's say you have some x and you want to autoencode it through some hidden representation z. A Variational Autoencoder lets you specify a prior over z, and you analytically approximate a bound that includes the KL divergence between the encoder's distribution and that prior. This means you can then sample new x and calculate probabilities of x.
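For contrast, here is a minimal sketch of the closed-form KL term a Gaussian VAE adds to its loss (the names `mu` and `log_var` are illustrative, not from this repo):

```python
from keras import backend as K

def gaussian_kl(mu, log_var):
    # Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ),
    # summed over the latent dimensions.
    return -0.5 * K.sum(1.0 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
```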

With an Adversarial Autoencoder, you don't specify the prior over z analytically; you just sample from it. The autoencoder tries to make its hidden representations look like the samples from your prior. With some types of networks and some scaling it might look like a KL loss, but it really is a different animal entirely. With a WGAN it is hypothetically the Wasserstein distance instead of KL, but it doesn't work out perfectly in practice yet.
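Concretely, the adversarial matching works on samples alone. A hedged sketch (the `encoder` name and latent size are assumptions, not taken from example_aae.py):

```python
import numpy as np

latent_dim = 100  # illustrative latent size

def sample_prior(n):
    # Draw "real" z samples straight from the prior; no density needed.
    return np.random.normal(0.0, 1.0, size=(n, latent_dim))

# z_fake = encoder.predict(x_batch)    # encoder outputs, labeled fake
# z_real = sample_prior(len(x_batch))  # prior samples, labeled real
# The discriminator learns to tell them apart, and the encoder learns to
# fool it, pushing its z distribution toward the prior with no KL term.
```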

If you read through the WGAN paper, it makes some really interesting arguments for why KL is a bad idea. https://arxiv.org/abs/1701.07875

Anyways, the outputs of the AAE model are the reconstruction mse, the discriminator value for fakes, and the discriminator value for reals.

The encoder/decoder part has two losses. It wants to maximize the discriminator value for fakes and minimize the reconstruction mse.

The discriminator part has two losses. It tries to maximize the discriminator value for reals and minimize the discriminator value for fakes.

In terms of how it is built, there are three unique losses. Each of the two players is coded to have all three losses, but the gradient is 0 in some cases. I build a base model with the three losses, then call AdversarialModel to split it into two different players. I then train it with 6 targets (3*2), but some will have 0 gradients so they don't really matter.
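Roughly, the split looks like this with the library's API (a sketch; the optimizers and the `generative_params`/`discriminative_params` groupings are illustrative, see example_aae.py for the real wiring):

```python
from keras.optimizers import Adam
from keras_adversarial import AdversarialModel, AdversarialOptimizerSimultaneous

# `aae` is the base model with outputs [xpred, yfake, yreal].
model = AdversarialModel(base_model=aae,
                         player_params=[generative_params, discriminative_params],
                         player_names=["generator", "discriminator"])
model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(),
                          player_optimizers=[Adam(3e-4), Adam(1e-3)],
                          loss={"xpred": "mean_squared_error",
                                "yfake": "binary_crossentropy",
                                "yreal": "binary_crossentropy"})

# fit() then takes 3 targets per player (6 total): e.g. the generator's
# yfake target says "real" while the discriminator's says "fake", and any
# target whose gradient is 0 for a player just doesn't matter.
```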

Hope that all makes sense. If there are any actual lines of code you don't understand just let me know.

Cheers

arashmh commented 7 years ago

Hello @bstriner ! Thanks for your explanation, it makes sense now. I didn't know about that article, cool bro! I just wanted to know if the 'generated images' figure in the readme for this AAE is generated by the same code? I ran it for 100 epochs and I couldn't get anything better than this:

(attached image: generated-epoch-099)

bstriner commented 7 years ago

The CIFAR example in the readme is a GAN. The AAE image in the readme is MNIST. I'll have to rerun and post the output from the CIFAR AAE. Haven't run it recently, so it might need some hyperparameter tweaking.

The important thing with AAEs (and also VAEs) is that you need to look at both the generated and the autoencoded images to diagnose the problem. Two common failures:

- If the autoencoded images look good but the generated images look bad, the encoder's z distribution isn't matching the prior, so the adversarial part needs more weight or training.
- If the autoencoded images themselves look bad, the network isn't reconstructing in the first place, so the generated images can't be any better; the autoencoder part needs more capacity or training.

Other hyperparameters might need tweaking but those are two of the biggest issues.
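For reference, a quick way to produce both diagnostic views (a sketch; `encoder`, `decoder`, `x_test`, and `latent_dim` are assumed names, not the exact ones in example_aae.py):

```python
import numpy as np

# Autoencoded view: encode -> decode on held-out data.
x_autoencoded = decoder.predict(encoder.predict(x_test[:100]))

# Generated view: decode samples drawn straight from the prior.
z = np.random.normal(0.0, 1.0, size=(100, latent_dim))
x_generated = decoder.predict(z)

# Good autoencoded but bad generated images -> z doesn't match the prior.
# Bad autoencoded images -> reconstruction itself is failing.
```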

I'll let you know if I get a chance to tweak it and rerun. End of semester, so I'm going to be busy.

BTW, are those results on TF or Theano? Shouldn't make a difference, but worth checking.

Cheers

arashmh commented 7 years ago

Hey, thanks for the explanation. I ran it on TF, but yeah, I'll give Theano a try, although I don't think it can really make a difference, idk. The autoencoded images are also very poor. You can see a ghost of the object, but nothing close to an identifiable image. What can cause the overall poor performance? Is it the second case you mentioned above?

(attached image: autoencoded-epoch-099)