facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

Padding in MSA can lead to inconsistent results compared to MSA without padding #491

Open: huangtinglin opened this issue 1 year ago

huangtinglin commented 1 year ago

MSABatchConverter can align a batch of MSAs with different depths (numbers of sequences) and lengths by padding them, so that they can be fed to the MSA Transformer together. During row attention, however, the model computes a scaling factor that depends on the number of rows in the input:

https://github.com/facebookresearch/esm/blob/a01925e7508064e08a5b14e09f2284ef26418b81/esm/axial_attention.py#L36
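The effect can be reproduced without the model. Below is a simplified, single-head sketch of tied row attention. It assumes that the per-row logits are summed over rows and scaled by head_dim ** -0.5 / sqrt(num_rows), and that queries at padded rows are zeroed. The function tied_row_logits is hypothetical and is not ESM's RowSelfAttention.

```python
import math
import torch


def tied_row_logits(q, k, row_mask, num_rows):
    """Hypothetical single-head tied row-attention logits (not ESM's code).

    q, k: [rows, cols, head_dim]; row_mask: [rows], 1.0 for real rows, 0.0 for padding.
    num_rows: row count used in the (assumed) scaling head_dim**-0.5 / sqrt(num_rows).
    """
    scaling = q.size(-1) ** -0.5 / math.sqrt(num_rows)
    # Zero the queries of padded rows so they contribute nothing to the sum over rows
    q = q * scaling * row_mask[:, None, None]
    # Sum the per-row dot products over the row axis -> [cols, cols]
    return torch.einsum("rid,rjd->ij", q, k)


torch.manual_seed(0)
real_depth, padded_depth, cols, head_dim = 2, 6, 5, 8
q = torch.randn(real_depth, cols, head_dim)
k = torch.randn(real_depth, cols, head_dim)

# Unpadded MSA of depth 2
unpadded = tied_row_logits(q, k, torch.ones(real_depth), num_rows=real_depth)

# Same MSA padded to depth 6; padded rows get arbitrary content
extra = padded_depth - real_depth
q_pad = torch.cat([q, torch.randn(extra, cols, head_dim)])
k_pad = torch.cat([k, torch.randn(extra, cols, head_dim)])
mask = torch.cat([torch.ones(real_depth), torch.zeros(extra)])
padded_logits = tied_row_logits(q_pad, k_pad, mask, num_rows=padded_depth)

# The masked sum is unchanged, but the scaling counts 6 rows instead of 2,
# so every logit shrinks by sqrt(2 / 6) and the softmax over columns gets flatter.
print(torch.allclose(padded_logits, unpadded * math.sqrt(real_depth / padded_depth), atol=1e-6))  # True
print(torch.allclose(padded_logits, unpadded, atol=1e-6))  # False
```

Under these assumptions, masking alone cannot compensate. The padded rows add nothing to the sum, but they still count toward the divisor.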

However, the MSAs in a batch may have different depths, so shallower MSAs are padded with extra rows. The scaling factor is then computed from the padded depth of the batch rather than from each MSA's own depth. As a result, the representation of a padded MSA differs from the one obtained without padding. The more the depths vary across the MSAs in a batch, the larger this inconsistency becomes.
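To make the mismatch concrete, each MSA's true depth can be recovered from a [batch, rows, cols] padding mask, such as one built with tokens.eq(alphabet.padding_idx). The helper per_msa_row_scaling below is hypothetical and not part of ESM. It assumes the scaling has the form head_dim ** -0.5 / sqrt(depth) and that padded rows consist entirely of padding tokens. It also assumes a head size of 64.

```python
import math
import torch


def per_msa_row_scaling(padding_mask: torch.Tensor, head_dim: int) -> torch.Tensor:
    """Hypothetical helper: row-attention scaling from each MSA's true depth.

    padding_mask: bool tensor [batch, rows, cols], True at padding positions.
    Assumes the scaling has the form head_dim**-0.5 / sqrt(depth).
    """
    # A row is real if at least one of its positions is not padding
    true_depth = (~padding_mask).any(dim=-1).sum(dim=-1)  # [batch]
    return head_dim ** -0.5 / true_depth.to(torch.float32).sqrt()


# Toy token batch shaped like batch_tokens1 in the example below: [2, 6, 31]
PAD = 1  # assumed padding index for this toy tensor
tokens = torch.full((2, 6, 31), PAD)
tokens[0, :6, :31] = 5  # depth-6 MSA (5 is just some non-padding token id)
tokens[1, :2, :29] = 5  # depth-2 MSA, padded in both depth and length
head_dim = 64  # assumed head size

per_msa = per_msa_row_scaling(tokens.eq(PAD), head_dim)
batch_level = head_dim ** -0.5 / math.sqrt(tokens.size(1))  # counts the padded depth 6
print(per_msa)      # tensor([0.0510, 0.0884])
print(batch_level)  # approx 0.0510, applied to both MSAs
```

The second MSA should be scaled as a depth-2 MSA, but the batch-level factor treats it as depth 6. Using a per-MSA factor inside the model would also require broadcasting it over the batch axis of the query tensor. That wiring is not shown here, so this is only a possible direction, not a tested fix.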

Here is an example:

import esm
import torch

# Load the MSA Transformer; for this model get_batch_converter() returns an MSA-aware converter
model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
batch_converter = alphabet.get_batch_converter()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.eval().to(device)

# Depth 6, length 30
MSAs1 = [
    ("protein1 test", "MKTVRQERLKSIVRILERSKEPVSGAQLAE"),
    ("protein2 test", "KALTARQQEVFDLIRDHISQTGMPPTRAEI"),
    ("protein3 test", "KALTARQQEVFDLIRDBISQTGMPPTRAEI"),
    ("protein3 test", "KALTARQQEVFDLIRDBISQTGMPPTRAEI"),
    ("protein3 test", "KALTARQQEVFDLIRDBISQTGMPPTRAEI"),
    ("protein3 test", "KALTARQQEVFDLIRDBISQTGMPPTRAEI"),
]
# Depth 2, length 30
MSAs2 = [
    ("protein1 test", "MKTVRQERLKSIVRILERSKEPVSGAQLAE"),
    ("protein2 test", "KALTARQQEVFDLIRDHISQTGMPPTRAEI"),
]
# Depth 2, length 28: the MSA whose representations are compared
target_MSAs = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAAA"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPCDC"),
]

MSAs_group1 = [MSAs1, target_MSAs]  # target_MSAs padded in depth (2 -> 6) and length (28 -> 30)
MSAs_group2 = [MSAs2, target_MSAs]  # target_MSAs padded in length only (28 -> 30)
gt_MSAs_group = [target_MSAs]  # no padding (reference)

_, _, batch_tokens1 = batch_converter(MSAs_group1)
_, _, batch_tokens2 = batch_converter(MSAs_group2)
_, _, gt_batch_tokens = batch_converter(gt_MSAs_group)
batch_tokens1 = batch_tokens1.to(device)
batch_tokens2 = batch_tokens2.to(device)
gt_batch_tokens = gt_batch_tokens.to(device)

with torch.no_grad():
    all_info1 = model(batch_tokens1, repr_layers=[12], need_head_weights=False)
    all_info2 = model(batch_tokens2, repr_layers=[12], need_head_weights=False)
    gt_all_info = model(gt_batch_tokens, repr_layers=[12], need_head_weights=False)

# Shapes are [batch, depth, length + 1 (prepended start token), embed_dim]
repres1 = all_info1["representations"][12]  # [2, 6, 31, 768]
repres2 = all_info2["representations"][12]  # [2, 2, 31, 768]
gt_repres = gt_all_info["representations"][12]  # [1, 2, 29, 768]

print(repres1.shape, repres2.shape, gt_repres.shape)
# Keep only target_MSAs (batch index 1) and strip its depth and length padding
gt_shape = gt_repres.shape
repres1 = repres1[1:, :gt_shape[1], :gt_shape[2]]  # [1, 2, 29, 768]
repres2 = repres2[1:, :gt_shape[1], :gt_shape[2]]  # [1, 2, 29, 768]
print("the difference between representations of target_MSAs generated with MSAs_group1 and gt_MSA: ", (repres1 - gt_repres).abs().mean())
print("the difference between representations of target_MSAs generated with MSAs_group2 and gt_MSA: ", (repres2 - gt_repres).abs().mean())

The output is:

the difference between representations of target_MSAs generated with MSAs_group1 and gt_MSA:  tensor(0.2972, device='cuda:0')
the difference between representations of target_MSAs generated with MSAs_group2 and gt_MSA:  tensor(0.0004, device='cuda:0')

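Padding along the MSA depth (MSAs_group1, where target_MSAs is padded from 2 to 6 rows) clearly causes a much larger error than padding along the sequence length alone (MSAs_group2). I reported this problem earlier: https://github.com/facebookresearch/esm/issues/228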
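Until the scaling accounts for padding, one workaround is to avoid depth padding entirely by batching only MSAs of equal depth. The sketch below groups MSA indices by depth. The helper batches_by_depth and its max_batch_size parameter are hypothetical. Each resulting group can then be passed to batch_converter as in the example above. Length padding still occurs within a group, but in the run above it produced a much smaller difference (0.0004).

```python
from collections import defaultdict
from typing import Dict, List, Tuple

MSA = List[Tuple[str, str]]  # (label, aligned sequence) pairs, as in the example above


def batches_by_depth(msas: List[MSA], max_batch_size: int) -> List[List[int]]:
    """Hypothetical helper: group MSA indices so every batch has a single depth."""
    by_depth: Dict[int, List[int]] = defaultdict(list)
    for idx, msa in enumerate(msas):
        by_depth[len(msa)].append(idx)  # depth = number of sequences in the MSA
    batches = []
    for depth in sorted(by_depth):
        idxs = by_depth[depth]
        # Split each depth bucket into chunks of at most max_batch_size MSAs
        for start in range(0, len(idxs), max_batch_size):
            batches.append(idxs[start:start + max_batch_size])
    return batches


msas = [
    [("s", "MKTV")] * 6,  # depth 6, like MSAs1
    [("s", "MKTV")] * 2,  # depth 2, like MSAs2
    [("s", "MKT")] * 2,   # depth 2, like target_MSAs
]
print(batches_by_depth(msas, max_batch_size=8))  # [[1, 2], [0]]
```

Each index group would then be converted with batch_converter([msas[i] for i in group]). This trades some throughput for depth-consistent scaling.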