facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

Demo (ESM-MSA-1b for variant prediction) yields inconsistent mutant positions between original sequence and processed MSA #470

Open kerrding opened 1 year ago

kerrding commented 1 year ago


Bug description

Hi there! I am using the pretrained esm-msa-1b model for variant prediction via examples/variant-prediction/predict.py. It seems that the MSA Transformer takes a processed MSA as input (insertions removed), but variant prediction still indexes the mutant column by its position in the original sequence. In other words, some residues of the original sequence are dropped during MSA preprocessing, yet when computing the masked marginal probabilities for user-defined mutants, examples/variant-prediction/predict.py still uses the mutant positions from the original sequence, which may produce incorrect scores.

For example, suppose we want to score 'A23B' in the original sequence. After preprocessing, that residue sits at position 15, so the mutation should be scored as 'A15B'. The current code still indexes column 23 of the processed MSA, which corresponds to a different residue of the original sequence (position 36, for example). Single-mutant predictions can therefore be computed at the wrong position.
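The index shift described above can be illustrated with a small, self-contained sketch. The toy sequence and the `remove_insertions` helper are hypothetical, written for illustration only; the helper assumes that preprocessing drops lowercase letters (a3m insertions) plus "." and "*", which may not match predict.py exactly.

```python
import string

# Assumption: preprocessing deletes lowercase letters (a3m insertion states)
# and the characters "." and "*". This is a sketch, not the repository's code.
_DELETE_TABLE = str.maketrans("", "", string.ascii_lowercase + ".*")


def remove_insertions(seq: str) -> str:
    """Return seq with insertion characters removed."""
    return seq.translate(_DELETE_TABLE)


# Hypothetical raw query row: "t" and "a" are insertions.
raw_query = "MKtaLVAGHW"
processed = remove_insertions(raw_query)

raw_pos = 5                           # 1-based position of "L" in the raw row
wild_type = raw_query[raw_pos - 1]    # residue the user intends to mutate
wrong_hit = processed[raw_pos - 1]    # residue found when reusing the raw index

# Correct 1-based position: count the residues kept up to and including raw_pos.
correct_pos = len(remove_insertions(raw_query[:raw_pos]))
right_hit = processed[correct_pos - 1]

print(processed, wild_type, wrong_hit, correct_pos, right_hit)
```

Reusing the raw index on the processed row lands on a different residue, while the shifted index recovers the intended one.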

Reproduction steps

This can be reproduced with any input. I am not allowed to share our data, so I use the example from examples/variant-prediction/README.md to illustrate.

python predict.py \
    --model-location esm_msa1b_t12_100M_UR50S \
    --sequence HPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW \
    --dms-input ./data/BLAT_ECOLX_Ranganathan2015.csv \
    --mutation-col mutant \
    --dms-output ./data/BLAT_ECOLX_Ranganathan2015_labeled.csv \
    --offset-idx 24 \
    --scoring-strategy masked-marginals \
    --msa-path ./data/BLAT_ECOLX_1_b0.5.a3m

Expected behavior

I expect the original single-mutant positions to be translated to positions in the processed MSA before prediction. I also think a note somewhere should remind users that this needs to be handled. For now, I have worked around the problem with my own code, though I am sure a more efficient version is possible.

In short, I create a mapping between the original sequence and the reference sequence of the processed MSA, so that the variant prediction is computed at the exact position the user cares about. For validation, at line 181 of examples/variant-prediction/predict.py, I use str(msa[0][1]), the reference sequence of the processed MSA, instead of args.sequence, the original sequence.

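# sequence: the original (unprocessed) sequence
# alphabet: set of uppercase one-letter residue codes, e.g. {'A', 'F', ...}
# new_pos_map: 1-based original position -> 1-based position in the processed
#              MSA, or -1 if the residue is dropped during preprocessing
new_pos_map = {}
cnt = 1
for i, s in enumerate(sequence, start=1):
    if s in alphabet:
        new_pos_map[i] = cnt
        cnt += 1
    else:
        new_pos_map[i] = -1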
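As a rough sketch of how such a mapping might be applied, the snippet below rewrites a mutation string such as "L5P" into processed-MSA coordinates. The function names `build_position_map` and `translate_mutation`, the toy sequence, and the 20-letter alphabet are assumptions made for illustration and are not part of predict.py. Positions are treated as 1-based with no extra offset, so any `--offset-idx` adjustment would have to be applied beforehand.

```python
import re


def build_position_map(sequence, alphabet):
    """Map 1-based original positions to 1-based processed positions (-1 if dropped)."""
    mapping = {}
    next_pos = 1
    for i, residue in enumerate(sequence, start=1):
        if residue in alphabet:
            mapping[i] = next_pos
            next_pos += 1
        else:
            mapping[i] = -1
    return mapping


def translate_mutation(mutation, sequence, alphabet):
    """Rewrite e.g. 'L5P' so its position refers to the processed sequence."""
    match = re.fullmatch(r"([A-Za-z])(\d+)([A-Za-z])", mutation)
    if match is None:
        raise ValueError(f"Unrecognized mutation format: {mutation!r}")
    wt, pos, mt = match.group(1), int(match.group(2)), match.group(3)
    if not 1 <= pos <= len(sequence) or sequence[pos - 1] != wt:
        raise ValueError(f"Wild type in {mutation!r} does not match the sequence")
    new_pos = build_position_map(sequence, alphabet)[pos]
    if new_pos == -1:
        raise ValueError(f"Position {pos} is dropped during preprocessing")
    return f"{wt}{new_pos}{mt}"


# Hypothetical inputs: "t" and "a" stand for residues dropped in preprocessing.
alphabet = set("ACDEFGHIKLMNPQRSTVWY")
sequence = "MKtaLVAGHW"
translated = translate_mutation("L5P", sequence, alphabet)
print(translated)
```

Raising an error for dropped positions, rather than silently scoring them, makes it obvious which user-supplied mutants cannot be evaluated against the processed MSA.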

Additional context

I believe this is only a problem in the demo; in real applications, users can be expected to write their own code that avoids it. I am also happy to contribute a fix if this is indeed the problem I take it to be.

Amelie-Schreiber commented 11 months ago

I implemented the scoring functions from the paper on zero-shot variant prediction using ESM-2, which does not require an MSA. In case it helps, please see this repo. There is an example of how to use the scoring functions in this notebook.