facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

fine-tune model with language model head #15

Closed · wjs20 closed this issue 3 years ago

wjs20 commented 3 years ago

Hi

I would like to fine-tune esm with a language model head. I tried

import torch
model = torch.hub.load("facebookresearch/esm", "modelWithLMHead", "esm1_t34_670M_UR50S")  

but I got RuntimeError: Cannot find callable modelWithLMHead in hubconf

Is there a simple way to do this? Thanks

joshim5 commented 3 years ago

Hi @wjs20, you don't need to specify modelWithLMHead. Following our README, simply do:

import torch
model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t34_670M_UR50S")

Then, when you call result = model(...) (or result = model.forward(...)), you'll get a dictionary where result['logits'] contains the output from the language modeling head. To fine-tune ESM with a language modeling head, set up your loss with respect to that output.
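
For anyone landing here later, a minimal sketch of that flow (the example sequence is a placeholder; the loading and forward calls follow the README usage):

import torch

# The hub entrypoint returns the model together with its alphabet.
model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t34_670M_UR50S")
batch_converter = alphabet.get_batch_converter()

# Placeholder batch; for fine-tuning you would iterate over your own sequences.
data = [("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQ")]
_, _, tokens = batch_converter(data)

results = model(tokens)
logits = results["logits"]  # (batch, sequence length, alphabet size), unnormalized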

brucejwittmann commented 3 years ago

Thanks for making these models available. They are very useful!

I'm also trying to further fine-tune the provided models using masked language modeling for a specific protein family and have just a few questions on the training procedure (commenting here because they pertain to your previous answer):

  1. I noticed that <eos> is an available token in the provided alphabet, but calling batch_converter(data) returns tokens with only <cls> and <pad>. Should the <eos> token be added to the end before feeding sequences to the model, or was it an unused token for language modeling during the initial training?
  2. The accompanying publication states "Our model was pre-trained using a context size of 1024 tokens". Should <pad> tokens be appended to sequences that would otherwise be shorter than 1024 tokens? The batch_converter only adds padding tokens up to the length of the longest sequence in the batch.
  3. Based on the output, I believe this is the case, but just to confirm: are the outputs given by results["logits"] unnormalized scores? In other words, can they be passed directly into an instance of torch.nn.CrossEntropyLoss() without further modification?

joshim5 commented 3 years ago

Hi @brucejwittmann, thanks for the great questions!

  1. We did not use the <eos> token during pretraining for the ESM-1 models.
  2. No, you don't need to append <pad> tokens. The batch_converter only adds these tokens so that sequences of different lengths can be included in the same batch. When we pre-train the models, the loss function ignores any <pad> positions. This means that we get the same loss for any arbitrary number of <pad> tokens. For sequences longer than 1023 tokens, we used a random crop of 1023 tokens and then prepended a <cls> token, for a total of 1024 tokens.
  3. Yes! They are unnormalized logits that can be passed to torch.nn.CrossEntropyLoss(). Just make sure to ignore pad tokens with ignore_index so that they don't contribute to the loss.
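
A minimal sketch of that loss setup (model, alphabet, and tokens as in the snippet above; targets is a placeholder tensor holding the original token ids, same shape as tokens):

import torch

# ignore_index drops <pad> positions from the loss. For masked-LM fine-tuning you
# would also restrict the loss to the masked positions (see the discussion below).
criterion = torch.nn.CrossEntropyLoss(ignore_index=alphabet.padding_idx)

logits = model(tokens)["logits"]          # (batch, seq_len, alphabet size)
loss = criterion(
    logits.reshape(-1, logits.size(-1)),  # (batch * seq_len, alphabet size)
    targets.reshape(-1),                  # (batch * seq_len,)
)
loss.backward()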

brucejwittmann commented 3 years ago

Hi @joshim5, thanks for your quick and helpful response! Just to clarify on the use of ignore_index: my understanding from your paper is that loss was calculated only for predictions at masked tokens. Does this mean that <pad> tokens were sometimes the ones that were masked? I was planning to design my masking function so that it never masks a padding token (in other words, it knows the length of each protein and only masks amino-acid tokens). If I were to do that, my understanding is that ignore_index wouldn't be needed, since <pad> could never be a target. I suppose I have a few follow-up questions, then:

  1. Was loss calculated against more than the masked tokens in your original work?
  2. Were <pad> tokens masked in the original work? If so, is this because there is a downside to restricting masking to amino-acid tokens only?

Thanks again!

joshim5 commented 3 years ago

Hi @brucejwittmann, to quickly answer your questions:

  1. No, not for pretraining.
  2. No. You asked about <pad> tokens in the batch_converter, so I brought up ignore_index as a warning just in case you implement masking in a way where masks could be introduced on <pad> tokens. However, the design you just described sounds great, and ignore_index won't be needed in that case.
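
A minimal sketch of a masking helper along those lines, i.e. one that can only ever mask amino-acid positions (the helper name, the 15% rate, and the simple uniform masking are illustrative choices, not the original pretraining recipe):

import torch

def mask_amino_acids(tokens, alphabet, mask_prob=0.15):
    """Randomly mask amino-acid positions; special tokens are never candidates."""
    special = torch.tensor(
        [alphabet.cls_idx, alphabet.padding_idx, alphabet.eos_idx, alphabet.mask_idx]
    )
    is_special = (tokens.unsqueeze(-1) == special).any(-1)

    # Sample mask positions only where the token is a real amino acid.
    mask = (torch.rand(tokens.shape) < mask_prob) & ~is_special

    masked_tokens = tokens.clone()
    masked_tokens[mask] = alphabet.mask_idx
    return masked_tokens, mask

With the mask in hand, the loss can be computed on the masked positions only, so ignore_index is indeed unnecessary:

masked_tokens, mask = mask_amino_acids(tokens, alphabet)
logits = model(masked_tokens)["logits"]
loss = torch.nn.functional.cross_entropy(logits[mask], tokens[mask])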