zalandoresearch / pytorch-vq-vae

PyTorch implementation of VQ-VAE by Aäron van den Oord et al.
MIT License

EMA update before quantization #5

Closed: stangelid closed this issue 5 years ago

stangelid commented 5 years ago

Hi and thanks for providing a nice and clean implementation of VQ-VAEs :)

While playing around with your code, I noticed that in VectorQuantizerEMA you first perform the EMA update of the codebook counts and embeddings, and then use the updated codebook embeddings as the quantized vectors (and for computing e_latent_loss).

In particular, the order in which you perform operations is:
1) Nearest neighbour search
2) EMA updates
3) Quantization
4) e_latent_loss computation

Is there a reason why you do the EMA updates before steps 3 and 4? My intuition says that the order should be:
1) Nearest neighbour search
2) Quantization
3) e_latent_loss computation
4) EMA updates
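A minimal PyTorch sketch of one forward step in the proposed order, for illustration only. It is not the repository's VectorQuantizerEMA: the function name vq_ema_step, its arguments, the defaults (decay=0.99, epsilon=1e-5, commitment_cost=0.25), and the use of torch.cdist for the distance computation are all assumptions.

```python
import torch
import torch.nn.functional as F


def vq_ema_step(inputs, embedding, ema_cluster_size, ema_w,
                decay=0.99, epsilon=1e-5, commitment_cost=0.25):
    """Hypothetical EMA vector-quantizer step in the order proposed above.

    inputs:           (N, D) flattened encoder outputs
    embedding:        (K, D) codebook, overwritten in place in step 4
    ema_cluster_size: (K,)   running EMA of code usage counts, updated in place
    ema_w:            (K, D) running EMA of summed assigned inputs, updated in place
    """
    num_codes = embedding.shape[0]

    # 1) Nearest neighbour search: squared L2 distance to every code.
    distances = torch.cdist(inputs, embedding) ** 2
    indices = distances.argmin(dim=1)
    encodings = F.one_hot(indices, num_codes).to(inputs.dtype)

    # 2) Quantization, using the codebook as it was *before* this step's update.
    quantized = encodings @ embedding

    # 3) Commitment (e_latent) loss against the same pre-update codes.
    e_latent_loss = F.mse_loss(quantized.detach(), inputs)
    loss = commitment_cost * e_latent_loss

    # 4) EMA updates, done last and outside autograd.
    with torch.no_grad():
        ema_cluster_size.mul_(decay).add_((1 - decay) * encodings.sum(dim=0))
        # Laplace-smoothed counts (kept in a local here) avoid dividing by zero.
        n = ema_cluster_size.sum()
        smoothed = (ema_cluster_size + epsilon) / (n + num_codes * epsilon) * n
        ema_w.mul_(decay).add_((1 - decay) * (encodings.t() @ inputs))
        embedding.copy_(ema_w / smoothed.unsqueeze(1))

    # Straight-through estimator: forward value is `quantized`,
    # gradients pass unchanged to `inputs`.
    quantized = inputs + (quantized - inputs).detach()
    return quantized, loss, indices
```

Because step 2 reads the codebook before step 4 overwrites it, the returned quantized vectors and the e_latent_loss refer to the same codes selected in step 1; the EMA update only takes effect for the next batch. In the order described in the issue, by contrast, the vectors used for quantization have already moved after the nearest-neighbour search picked them.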

Looking forward to hearing your thoughts!

Many thanks, Stefanos

kashif commented 5 years ago

@stangelid thanks for the issue! I think your intuition is right! I don't remember exactly what my notes said when I implemented this, but thinking it through, I think your way is correct. If you can send a PR I'll be happy to merge it, otherwise I will put it on my TODO list!

yassouali commented 5 years ago

I have a question in the same context: can you please explain why you apply Laplace smoothing to the cluster sizes _ema_cluster_size? I am having a hard time understanding why (to my knowledge it was not mentioned in the paper). Thanks.

stangelid commented 4 years ago

@yassouali As I understand it, the Laplace smoothing makes sure that no element of _ema_cluster_size is ever exactly zero. If that happened, it would result in a division by zero when _ema_w is divided by the cluster sizes to update the codebook embeddings.
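A minimal sketch of the smoothing step as a standalone function, to show why it matters for codes that are never selected. The name laplace_smooth, the epsilon default of 1e-5, and the toy tensors are assumptions for illustration; only the roles of _ema_cluster_size (counts) and _ema_w (summed inputs) come from the thread.

```python
import torch


def laplace_smooth(cluster_size: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor:
    """Hypothetical additive (Laplace) smoothing of EMA cluster counts.

    Adds epsilon to every count, then rescales so the smoothed counts
    still sum to the original total n.
    """
    n = cluster_size.sum()
    k = cluster_size.numel()
    return (cluster_size + epsilon) / (n + k * epsilon) * n


# Toy running statistics for a 3-entry codebook; code 1 has never been used.
counts = torch.tensor([3.0, 0.0, 1.0])
ema_w = torch.tensor([[3.0, 3.0], [0.0, 0.0], [1.0, 2.0]])

raw = ema_w / counts.unsqueeze(1)                   # row 1 is 0/0, i.e. nan
safe = ema_w / laplace_smooth(counts).unsqueeze(1)  # every row is finite
print(raw)
print(safe)
```

With smoothing, the unused code's embedding becomes a zero vector instead of NaN, and the rescaling keeps the total count unchanged. Note that if every count were zero, n would be zero and the smoothed counts would be zero as well, so this relies on at least some codes having been used.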

I know this is late, but hope it helps :)

kashif commented 4 years ago

Thanks @stangelid, perhaps I'll add this explanation to the notebook.

yassouali commented 4 years ago

Thank you.