facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins

Inconsistent results between multi-MSA and single-MSA inputs #228

Open huangtinglin opened 2 years ago

huangtinglin commented 2 years ago

Bug description I am running the pretrained MSA transformer (esm_msa1b_t12_100M_UR50S) on MSAs with different numbers of sequences and different sequence lengths to generate representations. Following the example in the README, I apply batch_converter to process the MSAs and obtain a padded token tensor. However, the representations the transformer produces for the batch don't match the results I get when the MSAs are fed into the model one at a time.

Reproduction steps Here is a simple example.

import esm
import torch

model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
batch_converter = alphabet.get_batch_converter()
model.eval()

MSAs1 = [
    ("protein1 test", "MKTVRQERLKSIVRILERSKEPVSGAQLAE"),
    ("protein2 test", "KALTARQQEVFDLIRDHISQTGMPPTRAEI"),
    ("protein3 test", "KALTARQQEVFDLIRDBISQTGMPPTRAEI"),
]
MSAs2 = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAAA"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPCDC"),
]

MSAs_group1 = [MSAs1, MSAs2]  # batch of two MSAs, padded to the larger one
MSAs_group2 = [MSAs2]         # the second MSA on its own
_, _, batch_tokens1 = batch_converter(MSAs_group1)
_, _, batch_tokens2 = batch_converter(MSAs_group2)

with torch.no_grad():
    all_info = model(batch_tokens1, repr_layers=[12], need_head_weights=True)
    all_info1 = model(batch_tokens2, repr_layers=[12], need_head_weights=True)

repres1 = all_info["representations"][12]  # [2, 3, 31, 768]
repres2 = all_info1["representations"][12]  # [1, 2, 29, 768]

repres2_shape = repres2.shape
# slice out the MSAs2 entry from the padded batch and drop the padding
repres1 = repres1[1:, :repres2_shape[1], :repres2_shape[2]]  # [1, 2, 29, 768]
print("the difference between representations of MSAs2 generated with MSAs_group1 and MSAs_group2: ", (repres1 - repres2).sum())

Expected behavior The representation of a target MSA should be identical whether it is fed to the MSA transformer on its own or batched together with other MSAs.

Logs

the difference between representations of MSAs2 generated with MSAs_group1 and MSAs_group2:  tensor(3.8772)


brucejwittmann commented 1 year ago

@huangtinglin, in testing some code I've been writing recently with esm1_t6, I've noticed that different batch sizes can sometimes give slightly different results. I'm extracting embeddings (CPU bound), and I observe the behavior both when running ESM's extract.py script and when using my own code that wraps the model. My best guess right now is that it has something to do with which algorithms PyTorch selects based on the workload (see here). What are some summary stats on your differences other than the sum (mean, median, min, max)? If they're small, I'd be curious whether your observations are just floating point error from different algorithms being used for the different workloads. The differences I typically observe are around 1e-5 or smaller. MSA1b is much larger than esm1_t6, though, so I wouldn't be surprised if slightly larger differences occurred with MSA1b, since it performs more operations to compute the representations.
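A minimal way to get those statistics from the reproduction snippet above (a sketch; it assumes repres1 and repres2 from that snippet are still in scope):

# Elementwise difference statistics between the two representations.
diff = (repres1 - repres2).abs()
print(f"mean:   {diff.mean().item():.3e}")
print(f"median: {diff.median().item():.3e}")
print(f"min:    {diff.min().item():.3e}")
print(f"max:    {diff.max().item():.3e}")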

huangtinglin commented 1 year ago

Thanks, @brucejwittmann. It turns out the discrepancy comes from a scaling factor that depends on the number of rows in the MSA. I have opened a new issue about this at https://github.com/facebookresearch/esm/issues/491.
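For illustration, here is a toy sketch (not the actual MSA Transformer code; the function, names, and shapes are made up) of how a scaling factor tied to the number of rows can break batch invariance when padding rows are added; see #491 for the real details:

import math
import torch

# Toy sketch: tied row attention sums attention logits over rows and scales
# them by 1/sqrt(num_rows * head_dim). If num_rows is taken from the padded
# tensor shape, padding rows change the scale even though they contribute
# nothing else, so a padded batch (3 rows) and the same MSA on its own
# (2 rows) give different logits.
head_dim = 4
q = torch.randn(2, 5, head_dim)  # 2 real rows, 5 columns
k = torch.randn(2, 5, head_dim)

def tied_row_logits(q, k, num_rows):
    scale = 1.0 / math.sqrt(num_rows * head_dim)
    return torch.einsum("rid,rjd->ij", q, k) * scale

print(tied_row_logits(q, k, num_rows=2))  # MSA fed on its own
print(tied_row_logits(q, k, num_rows=3))  # same MSA inside a padded batch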