facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

Pretrained model unable to predict correctly in the test example? #16

Closed · tueboesen closed this issue 3 years ago

tueboesen commented 3 years ago

I'm assuming I'm doing something wrong in the following, but I can't see what it is, so hopefully you can help point it out!

I'm hoping to use your model in my research, but first I wanted to do some validation of how it works. So rather than looking at the embedding layer, I'm looking at the actual output of the model, which I assume can be passed through a softmax function to obtain token probabilities, from which I can get the predicted amino acids by taking the argmax.

However, when I do that, I find that the returned probabilities make no sense. What I'm doing is given in the code below. What am I doing wrong?

import torch
import esm

# Load the 34-layer model
model, alphabet = esm.pretrained.esm1_t34_670M_UR50S()
model.eval()  # disable dropout for deterministic inference
batch_converter = alphabet.get_batch_converter()

# Prepare data (two protein sequences)
data = [("protein1", "MYLYQKIKN"), ("protein2", "MNAKYD")]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

aa = alphabet.all_toks

# Run the model and turn the output logits into per-position token predictions (on CPU)
with torch.no_grad():
    results = model(batch_tokens)
    logits = results['logits']                           # (batch, seq_len, vocab_size)
    prob = torch.softmax(logits, dim=-1)                 # token probabilities at each position
    pred = torch.argmax(prob, dim=-1)                    # most likely token index at each position
    pred_str = "".join(aa[i] for i in pred[0].tolist())  # predicted tokens for protein1
joshim5 commented 3 years ago

Hi @tueboesen, this is a great question! You're doing everything correctly. The results aren't terribly surprising, because these are just made-up protein sequences for illustration purposes :)

If you use a real protein sequence, the predictions will make more sense.

Relatedly, note that ESM-1 was only trained to predict amino acids at masked positions. If you introduce a mask at a position you'd like to predict, results will improve.
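For example, here is a minimal sketch of that (reusing the model, alphabet, and batch_tokens objects from your snippet; the position index is arbitrary and counts the BOS token at index 0):

masked_tokens = batch_tokens.clone()
target_pos = 3                                    # arbitrary position to predict
masked_tokens[0, target_pos] = alphabet.mask_idx  # replace it with the mask token
with torch.no_grad():
    masked_logits = model(masked_tokens)['logits']
masked_probs = torch.softmax(masked_logits[0, target_pos], dim=-1)
print(alphabet.all_toks[masked_probs.argmax().item()])  # most likely amino acid at the masked position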

Lastly, I would highly encourage you to use the new model that we're announcing later today -- ESM-1b -- as it performs much better on this type of task, even without masking any input sequences. In your example, I swapped esm1_t34_670M_UR50S for esm1b_t33_650M_UR50S and got the following outputs:

>>> pred
tensor([[ 0, 20, 19,  4, 19, 16, 15, 12, 15,  2],
        [ 0, 20, 17,  5, 15, 19,  2,  2,  2,  2]])
>>> batch_tokens
tensor([[32, 20, 19,  4, 19, 16, 15, 12, 15, 17],
        [32, 20, 17,  5, 15, 19, 13,  1,  1,  1]])
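To compare these side by side as tokens rather than raw indices, a small decoding sketch (assuming the same alphabet object as above) is:

for seq_pred, seq_in in zip(pred, batch_tokens):
    print("pred :", " ".join(alphabet.get_tok(i) for i in seq_pred.tolist()))
    print("input:", " ".join(alphabet.get_tok(i) for i in seq_in.tolist()))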

Feel free to reopen the issue if you have any additional questions.

tueboesen commented 3 years ago

Thank you for the quick reply!

I am a bit confused as to why your model isn't trained to predict the unmasked amino acids. I can see in your paper that you only train on the masked parts of a protein, but I have to wonder why that is. My understanding is that BERT's masked language model normally trains on the whole sequence, no? And wouldn't this result in an underlying embedding that is hard to trust at the unmasked positions, since they are effectively predicting the wrong amino acid?

In either case this is great, thanks for sharing.

joshim5 commented 3 years ago

Hi @tueboesen, these are great questions. The masked language modeling scheme in BERT does not train on the entire protein either. 15% of positions are selected and predicted. The other positions are not predicted.

To give some intuition on why ESM-1b is better at predicting the masked tokens, let's consider what happens at those 15% of selected positions. In ESM-1, they are simply masked and predicted. In ESM-1b, 80% are masked, 10% are mutated to a random amino acid, and 10% are left unchanged. We speculate that the latter two operations allow ESM-1b to learn how to pass the input amino acids through to the output.
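As an illustration only (this is not the training code from this repo), the 80/10/10 corruption described above can be sketched roughly like this, where standard_aa_idx is assumed to be a 1-D tensor holding the indices of the 20 standard amino-acid tokens:

import torch

def corrupt_tokens(tokens, mask_idx, standard_aa_idx, mask_prob=0.15):
    # BERT-style corruption: select 15% of positions, then mask 80% of them,
    # mutate 10% to a random amino acid, and leave 10% unchanged.
    tokens = tokens.clone()
    selected = torch.rand(tokens.shape) < mask_prob       # 15% of positions become prediction targets
    action = torch.rand(tokens.shape)
    tokens[selected & (action < 0.8)] = mask_idx          # 80% of targets: replace with the mask token
    mutate = selected & (action >= 0.8) & (action < 0.9)  # 10% of targets: mutate to a random amino acid
    random_aa = standard_aa_idx[torch.randint(len(standard_aa_idx), tokens.shape)]
    tokens[mutate] = random_aa[mutate]
    return tokens, selected                               # remaining 10% of targets stay unchanged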

In the update to the ESM paper we posted this morning, there is a controlled analysis in the appendix that keeps everything constant and only varies the masking pattern. We find that the ESM-1b masking pattern performs better than ESM-1's, but not by much.

That raises the question: why does ESM-1 perform well if it can't even pass through the input? Here's my speculation: ESM-1 still learns interaction patterns between the amino acids, because they are necessary to predict the 15% of masked tokens. Even without masking, the model still applies those interactions and gives embeddings that perform well on downstream tasks.
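For instance, per-residue representations (rather than logits) can be pulled out as follows; this is a sketch assuming the esm1b_t33_650M_UR50S model and batch from your example, whose final layer is 33:

with torch.no_grad():
    out = model(batch_tokens, repr_layers=[33])
token_embeddings = out['representations'][33]  # (batch, seq_len, hidden_dim)
# Mean-pool over the residues of protein1, skipping the BOS token at position 0
protein1_embedding = token_embeddings[0, 1:len(data[0][1]) + 1].mean(dim=0)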

The discrepancy between masking during pre-training and not masking during fine-tuning occurs throughout the literature, including in NLP. Whether it actually affects final performance is still an open question.