
Understanding Encoder Update Mechanism in Structure VQ-VAE #38

Closed · JeremieDona closed this 1 month ago

JeremieDona commented 1 month ago

Hi ESM3 Team,

First of all, congratulations on your outstanding research work. I am particularly excited by the structure VQ-VAE proposed in your model.

Upon examining your code and the detailed appendices, I observed that you use a Euclidean codebook to compute the quantized version of your codes. You also extract the encoding indices, i.e. the token indices, which the decoder then consumes through an embedding layer.

My question pertains to the gradient flow in your model. Given that the argmin operation that extracts the token indices is non-differentiable, how is the encoder updated with respect to the reconstruction loss? As far as I know, in a vanilla VQ-VAE a straight-through estimator (STE) allows gradients to bypass the quantization step. In your implementation, however, it does not seem that the non-quantized output is used in this manner. Could you please explain how the encoder receives gradients and is updated in your setup?

thayes427 commented 1 month ago

Thank you for your question!

You are correct that the argmin operation is non-differentiable and that an STE allows gradients to bypass the quantization step and flow back to the encoder. We apply this same technique when training the VQ-VAE. Please see this line in particular, where gradients flow back to the encoder outputs, bypassing the quantization: https://github.com/evolutionaryscale/esm/blob/17d48878a9cfad388fdf5ff4d3fe4ea0f0d24839/esm/layers/codebook.py#L79
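
For readers unfamiliar with the trick, below is a minimal, self-contained PyTorch sketch of Euclidean vector quantization with a straight-through estimator. The class name, shapes, and hyperparameters are illustrative assumptions, not the ESM implementation, and auxiliary terms such as the commitment loss and codebook updates (e.g. EMA) are omitted for brevity:

```python
import torch
import torch.nn as nn

class EuclideanCodebookSketch(nn.Module):
    """Illustrative Euclidean vector quantizer with a straight-through
    estimator (STE). A sketch, not the ESM code; commitment and
    codebook losses are omitted."""

    def __init__(self, num_codes: int, dim: int):
        super().__init__()
        self.embeddings = nn.Embedding(num_codes, dim)

    def forward(self, z: torch.Tensor):
        # z: (batch, seq, dim) encoder outputs.
        flat = z.reshape(-1, z.size(-1))                   # (N, dim)
        # Euclidean distance from each output to every code vector.
        dists = torch.cdist(flat, self.embeddings.weight)  # (N, num_codes)
        indices = dists.argmin(dim=-1)                     # non-differentiable
        z_q = self.embeddings(indices).view_as(z)          # quantized codes

        # Straight-through estimator: the forward pass returns z_q, but
        # (z_q - z).detach() carries no gradient, so in the backward pass
        # quantization acts as the identity and the reconstruction-loss
        # gradient flows through z into the encoder.
        z_q = z + (z_q - z).detach()
        return z_q, indices.view(z.shape[:-1])
```

A quick check that gradients do reach the encoder side despite the argmin:

```python
codebook = EuclideanCodebookSketch(num_codes=4096, dim=128)
z = torch.randn(2, 16, 128, requires_grad=True)
z_q, idx = codebook(z)
z_q.sum().backward()        # stands in for a reconstruction loss
assert z.grad is not None   # gradients flowed through the STE to z
```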

Please let me know if you have any further questions.