zalandoresearch / pytorch-vq-vae

PyTorch implementation of VQ-VAE by Aäron van den Oord et al.
MIT License
534 stars 101 forks source link

How does the reconstruction loss update the encoder? #6

Closed eecshope closed 4 years ago

eecshope commented 4 years ago

Thanks for your implementation of VQ-VAE, but I have a question. In the original paper, the gradients at the decoder's input are copied straight to the encoder's output, because the index-selection (nearest-codebook lookup) op is non-differentiable. However, I couldn't find the corresponding implementation in your code. I'm new to PyTorch and not familiar with the autograd system, so a little explanation would be much appreciated. Thanks!

kashif commented 4 years ago

The quantized.detach() operation essentially does the stop gradient if I remember correctly. Hope that clears it up?

eecshope commented 4 years ago

I understand. The line

quantized = inputs + (quantized - inputs).detach()

copies the gradient from Z_q to Z_e by treating the latter term as a constant. Awesome! Thank you so much!
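The effect of that one-liner can be checked in isolation. Below is a minimal sketch (not from the repo) that uses `torch.round` as a stand-in for the non-differentiable codebook lookup: the forward pass carries the quantized value, while the backward pass delivers the gradient to the encoder output unchanged.

```python
import torch

# z_e stands in for the encoder output; requires_grad so we can inspect its gradient.
z_e = torch.tensor([0.3, 1.7], requires_grad=True)

# z_q stands in for the nearest-codebook vector; round() is non-differentiable
# in the sense that it would pass zero gradient on its own.
z_q = torch.round(z_e)

# Straight-through trick: forward value equals z_q, but (z_q - z_e) is detached,
# so in the backward pass the expression behaves like the identity on z_e.
quantized = z_e + (z_q - z_e).detach()

loss = quantized.sum()
loss.backward()

print(quantized)  # carries the rounded values: tensor([0., 2.], ...)
print(z_e.grad)   # gradient 1.0 per element, as if quantization were identity
```

Without the `.detach()` rewrite, `loss.backward()` would propagate zero gradient to `z_e` through `round()`, and the encoder would never be updated by the reconstruction loss.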