facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

Token dropout is set to True during inference? #267

Closed zhenyuhe00 closed 2 years ago

zhenyuhe00 commented 2 years ago

Hi, congrats on your great work! In the ESM-2 model code, I noticed that `token_dropout` is set to True. I wonder why that is. Thanks in advance!

nikitos9000 commented 2 years ago

Hi @zhenyuhe00, that's just an implementation of masked LM training: as in the paper, 15% of amino acids are masked. During inference no tokens are masked, so the embeddings are scaled accordingly here: https://github.com/facebookresearch/esm/blob/main/esm/model/esm2.py#L86. I hope that answers your question.
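
For context, the logic at that line (inside the model's forward pass, so `tokens`, `padding_mask`, and `self.mask_idx` come from the surrounding code) does roughly the following. This is a paraphrased sketch, not a verbatim copy; see the linked source for the exact code:

```python
# Paraphrased sketch of the token-dropout scaling in esm/model/esm2.py.
# x: token embeddings of shape (B, T, C); tokens: input ids of shape (B, T).
if self.token_dropout:
    # Fully zero the embeddings at <mask> positions.
    x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
    # During training ~15% of positions are selected and ~80% of those become <mask>,
    # so on average ~12% of embeddings are zeroed.
    mask_ratio_train = 0.15 * 0.8
    src_lengths = (~padding_mask).sum(-1)
    mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
    # Rescale so the expected magnitude matches training. At inference there are
    # typically no <mask> tokens, so mask_ratio_observed is 0 and this is a constant factor.
    x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
```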

dan-sprague commented 2 years ago

I haven't come across scaling embeddings during inference for these types of models before. Is there a reference for doing that, or is it something you guys did?

tomsercu commented 2 years ago

Token drop was introduced in our LMs in the ESM-1b paper (see the supplement of the PNAS paper, Rives et al.). The idea behind it is that we fully zero out the embeddings instead of using a non-zero embedding for the `<mask>` token. The scaling logic is then by direct analogy to standard dropout, which is motivated by keeping the same expected vector norm whether dropout is turned on or off:

> Furthermore, the outputs are scaled by a factor of 1/(1−p) during training. This means that during evaluation the module simply computes an identity function.
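
For concreteness, here is how standard inverted dropout behaves in PyTorch (a small illustration of the 1/(1−p) scaling, not code from the ESM repo):

```python
import torch
import torch.nn as nn

# Inverted dropout: activations surviving the drop are scaled by 1/(1 - p) in
# training so their expected value matches evaluation, where dropout is an identity.
p = 0.15
drop = nn.Dropout(p=p)
x = torch.ones(4)

drop.train()
print(drop(x))  # surviving entries are 1 / (1 - 0.15) ≈ 1.1765, dropped entries are 0
drop.eval()
print(drop(x))  # identity: all ones
```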

Overall, token drop had a very minor effect. And as @nikitos9000 pointed out, during inference there's no rescaling happening.

zhenyuhe00 commented 2 years ago

Thanks a lot!

jasperhyp commented 1 year ago

Hi, to follow up on this issue: I noticed in the appendix that not all of the 15% of selected tokens are replaced with `mask_idx`. Some are replaced with other tokens, and some are kept in their original form. So those tokens are not dropped out, correct?

Also, wouldn't the dropped-out tokens then be the same as padding tokens, so that their final embeddings depend mainly on the bias terms in each operation?
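
For reference, what I have in mind is the standard BERT-style corruption, roughly like this sketch (the 80/10/10 split is my assumption carried over from BERT, not something I'm taking from the ESM paper itself):

```python
import torch

def bert_style_corrupt(tokens, mask_idx, vocab_size, mask_prob=0.15):
    """Sketch of BERT-style corruption: of the ~15% of selected positions,
    ~80% become <mask>, ~10% become a random token, ~10% stay unchanged.
    Only the <mask> positions would have their embeddings zeroed by token dropout."""
    tokens = tokens.clone()
    selected = torch.rand_like(tokens, dtype=torch.float) < mask_prob  # loss targets
    roll = torch.rand_like(tokens, dtype=torch.float)
    to_mask = selected & (roll < 0.8)                    # replaced by <mask>
    to_random = selected & (roll >= 0.8) & (roll < 0.9)  # replaced by a random token
    # the remaining ~10% of selected positions are left as-is
    tokens[to_mask] = mask_idx
    tokens[to_random] = torch.randint(vocab_size, tokens.shape)[to_random]
    return tokens, selected
```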