ritheshkumar95 / pytorch-vqvae

Vector Quantized VAEs - PyTorch Implementation

Bug in VQ function? #7

Closed: pfriesch closed this issue 6 years ago.

pfriesch commented 6 years ago

Nice implementation of the VQ-Straight through function!

However, when looking at the autograd graph, I found an edge that breaks the separation between the gradients of the reconstruction loss and those of the VQ loss. As a result, the reconstruction loss also updates the embedding, which should not happen. I tried to figure out why that happens, but my understanding of PyTorch isn't that thorough. Do you have an idea?

I marked the edge here.
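One way to confirm this kind of leak without reading the autograd graph is to backpropagate the reconstruction loss alone and check whether the codebook receives a gradient. This is a toy sketch under assumptions: `z_e` is a random tensor standing in for the encoder output, the decoder is a single `torch.nn.Linear`, and `leaky` / `separated` are hypothetical straight-through variants written for illustration. Neither is the repository's `vq_st` code.

```python
import torch
import torch.nn.functional as F


def recon_reaches_codebook(straight_through):
    """Return True if the reconstruction loss alone produces a nonzero codebook gradient."""
    torch.manual_seed(0)
    z_e = torch.randn(8, 4, requires_grad=True)        # stand-in for encoder output (N, D)
    codebook = torch.nn.Parameter(torch.randn(16, 4))  # K codewords of dimension D
    decoder = torch.nn.Linear(4, 4)                     # stand-in decoder

    # Nearest-codeword lookup on detached tensors (argmin has no useful gradient anyway).
    indices = torch.cdist(z_e.detach(), codebook.detach()).argmin(dim=1)
    z_q = codebook[indices]

    z_q_st = straight_through(z_e, z_q)
    loss_recons = F.mse_loss(decoder(z_q_st), torch.zeros(8, 4))
    loss_recons.backward()
    return codebook.grad is not None and codebook.grad.abs().sum().item() > 0


def leaky(z_e, z_q):
    # Forward value is z_q, but the gradient reaches both z_e AND the codebook (via z_q).
    return z_q + (z_e - z_e.detach())


def separated(z_e, z_q):
    # Forward value is z_q, but the gradient reaches z_e only; the codebook term is detached.
    return z_e + (z_q - z_e).detach()


print(recon_reaches_codebook(leaky))      # True
print(recon_reaches_codebook(separated))  # False
```

Both variants produce the same forward value, so the bug is invisible in the outputs and only shows up in where the gradients go.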

ritheshkumar95 commented 6 years ago

Hi @pfriesch,

Thanks for pointing it out. In the previous version of the code, which did not use the VQ-ST function (https://github.com/ritheshkumar95/pytorch-vqvae/blob/cde142670f701e783f29e9c815f390fc502532e8/vqvae.py#L75), we explicitly made sure that the reconstruction loss didn't update the codebook.

We missed this when implementing the VQ-ST function. We're fixing it now.

However, I'd like to point out that it didn't make much of a difference (in terms of reconstructions or ELBO) whether or not the codebook was updated by the reconstruction loss.

tristandeleu commented 6 years ago

This is something I did (incorrectly) on purpose, but you are absolutely right: the paper explicitly says that the reconstruction loss does not update the embeddings. I fixed this. Here is what the graph looks like now. This is a simplified version, where the left path is loss_commit, the middle path is loss_recons and the right path is loss_vq. (image) Here is the full graph.
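The intended gradient routing can be sketched as follows. The VQ-VAE paper's objective is log p(x | z_q(x)) + ||sg[z_e(x)] - e||² + β ||z_e(x) - sg[e]||², where sg is stop-gradient; here the three terms are named loss_recons, loss_vq and loss_commit to match the comment above. Assumptions: MSE stands in for the reconstruction log-likelihood, `vq_forward` is a hypothetical helper, β = 0.25 is an assumed commitment weight, single `torch.nn.Linear` layers stand in for the encoder and decoder, and the straight-through step uses the common `z_e + (z_q - z_e).detach()` trick rather than a custom `torch.autograd.Function`. This is not necessarily how the repository implements it.

```python
import torch
import torch.nn.functional as F


def vq_forward(z_e, codebook, beta=0.25):
    """Hypothetical helper: quantize z_e (N, D) against codebook (K, D)."""
    # Nearest-codeword lookup; argmin is not differentiable, so detach both inputs.
    indices = torch.cdist(z_e.detach(), codebook.detach()).argmin(dim=1)
    z_q = codebook[indices]  # differentiable w.r.t. the selected codebook rows only

    # Straight-through estimator: forward value is z_q, gradient flows to z_e only.
    z_q_st = z_e + (z_q - z_e).detach()

    # ||sg[z_e] - e||^2: moves the selected codewords toward the encoder outputs.
    loss_vq = F.mse_loss(z_q, z_e.detach())
    # beta * ||z_e - sg[e]||^2: keeps the encoder outputs close to their codewords.
    loss_commit = beta * F.mse_loss(z_e, z_q.detach())
    return z_q_st, loss_vq, loss_commit


torch.manual_seed(0)
encoder = torch.nn.Linear(6, 4)
decoder = torch.nn.Linear(4, 6)
codebook = torch.nn.Parameter(torch.randn(16, 4))
x = torch.randn(8, 6)

z_q_st, loss_vq, loss_commit = vq_forward(encoder(x), codebook)
loss_recons = F.mse_loss(decoder(z_q_st), x)


def reached(loss, params):
    """For each parameter, report whether `loss` has a gradient path to it."""
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return [g is not None for g in grads]


params = [encoder.weight, codebook, decoder.weight]
print(reached(loss_recons, params))  # [True, False, True]
print(reached(loss_vq, params))      # [False, True, False]
print(reached(loss_commit, params))  # [True, False, False]
```

Each loss reaches exactly the parameters the paper assigns to it: only loss_vq touches the codebook, and the decoder is trained only by loss_recons.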

pfriesch commented 6 years ago

Cool! Thanks!

> However, i'd like to point out that it didn't make much of a difference (reconstructions-wise, ELBO-wise), to do / do not update codebook with the reconstruction loss.

Interesting. How does that work, though? I am trying to wrap my head around it. It would mean that the selected (closest) embeddings get moved, based on the reconstruction loss, toward embeddings that fit the decoder better. That by itself would not hurt the training of the vector-quantized embeddings, and it might even speed up training. Would you agree with that intuition?
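The mechanics behind this intuition can be checked with a toy example. The lookup `codebook[indices]` routes gradient only to the rows that were actually selected, so when the reconstruction loss is allowed to reach the codebook, it nudges only the chosen (closest) embeddings, in whatever direction lowers the decoder's error. Assumptions: random tensors stand in for the encoder output and the targets, a single `torch.nn.Linear` stands in for the decoder, and the decoder is fed `z_q` directly to mimic the unfixed, leaky path. This shows where the gradient goes, not whether it helps training.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z_e = torch.randn(8, 4)                            # stand-in encoder output (no grad needed)
codebook = torch.nn.Parameter(torch.randn(16, 4))  # 16 codewords, at most 8 get selected
decoder = torch.nn.Linear(4, 4)                     # stand-in decoder
target = torch.randn(8, 4)                          # stand-in reconstruction target

indices = torch.cdist(z_e, codebook.detach()).argmin(dim=1)
z_q = codebook[indices]

loss_vq = F.mse_loss(z_q, z_e)
# Leaky path: the decoder sees z_q itself, so loss_recons also reaches the codebook.
loss_recons = F.mse_loss(decoder(z_q), target)

# retain_graph=True because both losses share the codebook[indices] node.
(grad_vq,) = torch.autograd.grad(loss_vq, codebook, retain_graph=True)
(grad_recons,) = torch.autograd.grad(loss_recons, codebook)

used = torch.zeros(16, dtype=torch.bool)
used[indices] = True

# Rows that no vector selected receive exactly zero gradient from either loss.
print(grad_vq[~used].abs().sum().item(), grad_recons[~used].abs().sum().item())  # 0.0 0.0
```

So the extra reconstruction gradient only adjusts embeddings that are already in use, which is consistent with it not disrupting the codebook as a whole.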

ritheshkumar95 commented 6 years ago

Yes, that is my understanding as well.