AntixK / PyTorch-VAE

A Collection of Variational Autoencoders (VAE) in PyTorch.
Apache License 2.0

Vanilla-VAE and usage of kld_weight #11

Closed. ksachdeva closed this issue 3 years ago.

ksachdeva commented 3 years ago

Hi @AntixK

Many thanks for this great effort.

Based on my understanding so far, the original VAE paper does not discuss weighting the KL divergence loss. Later, beta-VAE and many other papers made the case for weighting the KL divergence (essentially treating it as a hyper-parameter).

In your implementations, I see that you consistently use kld_weight = kwargs['M_N'] = batch_size / num_of_images.

Is it a common practice to set the KL divergence weight to the ratio of the batch size to the number of training images?

Since no weighting was done in the original VAE paper, is it okay to use it in vanilla_vae.py?

Regards, Kapil
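For context, the weight in question enters the loss function. Below is a minimal sketch of a weighted VAE loss in the style of this repo's vanilla_vae.py; the tensor names and the use of a kld_weight argument mirror the question above, but this is an illustration, not a verbatim copy of the repo's code:

```python
import torch
import torch.nn.functional as F

def vae_loss(recons, inputs, mu, log_var, kld_weight):
    # Reconstruction term: pixel-wise MSE between input and reconstruction
    recons_loss = F.mse_loss(recons, inputs)

    # KL divergence between q(z|x) = N(mu, diag(sigma^2)) and the prior N(0, I),
    # summed over latent dimensions and averaged over the batch
    kld_loss = torch.mean(
        -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0
    )

    # The weight discussed in this issue (passed as kwargs['M_N'] in the repo)
    return recons_loss + kld_weight * kld_loss
```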

AntixK commented 3 years ago

It is just a bias-correction term that accounts for the minibatch. When small batch sizes are used, the KLD value can have large variance. But it should work without that kld_weight term too.
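To make the rescaling concrete, here is a hedged sketch of how the weight would be chosen at training time; the dataset size, batch size, and latent dimension are made-up values for illustration, not the repo's defaults:

```python
import torch

# Assumed sizes, for illustration only
num_train_images = 50_000   # N: size of the training set
batch_size = 144            # M: minibatch size

# kld_weight = M / N rescales the batch-mean KLD so its contribution matches
# the minibatch's share of the full-dataset objective, as described above
kld_weight = batch_size / num_train_images

# Setting kld_weight = 1.0 recovers the unweighted ELBO of the original VAE
# paper; training should still work, possibly with a noisier KLD for small batches.
# kld_weight = 1.0

# Example: batch-mean KLD for one minibatch of 128-dimensional latents
mu = torch.randn(batch_size, 128)
log_var = torch.randn(batch_size, 128)
kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1))
weighted_kld_term = kld_weight * kld
```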

simonhessner commented 2 years ago

Related question: https://github.com/AntixK/PyTorch-VAE/issues/40