`MSA_batch_converter` can align a batch of MSAs with different depths and lengths so that they can be fed to the MSA Transformer together. During row attention, the model computes a scaling factor that depends on the number of rows: https://github.com/facebookresearch/esm/blob/a01925e7508064e08a5b14e09f2284ef26418b81/esm/axial_attention.py#L36

However, the instances in the aligned batch may have different depths (numbers of aligned sequences), so some of them get padded with extra rows. Because the row count used in this scaling includes the padded rows, the resulting representation is inconsistent with the one computed without padding, and the discrepancy grows as the depths vary more across the instances in a batch.
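For reference, the scaling at the linked line is derived from the row dimension of the (already padded) input tensor; the snippet below is a paraphrased sketch of that computation, not a verbatim copy of `RowSelfAttention`.

```python
import math
import torch

# Paraphrased sketch of the row-attention scaling in esm/axial_attention.py.
# In the library this lives inside RowSelfAttention; the names and signature
# here are simplified for illustration.
def align_scaling(q: torch.Tensor, head_dim: int) -> float:
    # q is laid out as (num_rows, num_cols, batch_size, embed_dim), so
    # q.size(0) is the padded MSA depth; padding an MSA to a larger depth
    # therefore changes the scaling applied to the queries.
    num_rows = q.size(0)
    return (head_dim ** -0.5) / math.sqrt(num_rows)
```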
Here is an example:
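Below is a minimal sketch of the comparison, assuming the `esm_msa1b_t12_100M_UR50S` checkpoint and `alphabet.get_batch_converter()` from the esm package; the toy sequences, the composition of MSAs_group1/MSAs_group2, and the mean-absolute-difference metric are illustrative choices rather than the exact script behind the numbers below.

```python
import torch
import esm

# Load the MSA Transformer and its batch converter.
model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
model = model.eval().cuda()
batch_converter = alphabet.get_batch_converter()

# Target MSA whose representation we compare (depth 2; illustrative sequences,
# all of equal length so that only depth padding comes into play).
target_msa = [("target_0", "MKTAYIAKQR"), ("target_1", "MKTAYIARQR")]
# MSAs_group1: the target batched with a deeper MSA, so the target gets padded in depth.
deep_msa = [(f"deep_{i}", "MKTAYIAKQR") for i in range(8)]
# MSAs_group2: the target batched with an MSA of equal depth, so no depth padding.
same_depth_msa = [("same_0", "MKTAYIAKQR"), ("same_1", "MKTAYIAKQR")]

def target_rep(batch):
    """Return the layer-12 representation of the first MSA in the batch,
    restricted to its real (unpadded) rows."""
    _, _, tokens = batch_converter(batch)
    with torch.no_grad():
        out = model(tokens.cuda(), repr_layers=[12])
    # representations: (batch, rows, columns, embed_dim)
    return out["representations"][12][0, : len(batch[0])]

gt = target_rep([target_msa])                    # gt_MSA: target alone, no padding
rep1 = target_rep([target_msa, deep_msa])        # MSAs_group1: depth padding required
rep2 = target_rep([target_msa, same_depth_msa])  # MSAs_group2: depths already match

print("the difference between representations of target_MSAs generated with "
      "MSAs_group1 and gt_MSA:", (rep1 - gt).abs().mean())
print("the difference between representations of target_MSAs generated with "
      "MSAs_group2 and gt_MSA:", (rep2 - gt).abs().mean())
```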
The output is:

the difference between representations of target_MSAs generated with MSAs_group1 and gt_MSA: tensor(0.2972, device='cuda:0')
the difference between representations of target_MSAs generated with MSAs_group2 and gt_MSA: tensor(0.0004, device='cuda:0')
The padding along the depth dimension of the MSAs clearly causes the larger error. I actually raised this issue before: https://github.com/facebookresearch/esm/issues/228