facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

MSA batch converter does not accept pre-masked MSAs #234

Open seanrjohnson opened 1 year ago

seanrjohnson commented 1 year ago

Bug description

Reproduction steps

import esm
model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
batch_converter = alphabet.get_batch_converter()

batch_size = 5
padded_msa = [('0', 'AAA<mask><mask>'), ('1', 'ACC<mask><mask>'), ('2', 'ACDE<mask>')]
labels, strs, tokens = batch_converter([padded_msa] * batch_size)

Expected behavior

It should execute without crashing.

Logs

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/sean/miniconda3/envs/protein_gibbs_sampler/lib/python3.9/site-packages/esm/data.py", line 324, in __call__
    raise RuntimeError(
RuntimeError: Received unaligned sequences for input to MSA, all sequence lengths must be equal.

Additional context

I think the function here needs to be augmented: https://github.com/facebookresearch/esm/blob/e5e7b06b9a093706607c229ab1c5c9821806814d/esm/data.py#L322

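def rawbatchlen(raw_batch: str):  # ADDED
    """Count positions in a sequence, treating each <...> token (e.g. <mask>) as one."""
    count = 0
    counting = True
    for ch in raw_batch:
        if ch == "<":
            # Entering a special token: stop counting its characters.
            counting = False
        if ch == ">":
            # Closing bracket: resume counting. The ">" itself is counted below,
            # so the whole token contributes exactly 1.
            counting = True
        if counting:
            count += 1
    return count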

class MSABatchConverter(esm.data.BatchConverter):
    def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
        if isinstance(inputs[0][0], str):
            # Input is a single MSA
            raw_batch: Sequence[RawMSA] = [inputs]  # type: ignore
        else:
            raw_batch = inputs  # type: ignore

        batch_size = len(raw_batch)
        max_alignments = max(len(msa) for msa in raw_batch)
        max_seqlen = max(rawbatchlen(msa[0][1]) for msa in raw_batch)  ### CHANGED
        tokens = torch.empty(
            (
                batch_size,
                max_alignments,
                max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []
        for i, msa in enumerate(raw_batch):
            msa_seqlens = set(rawbatchlen(seq) for _, seq in msa) ### CHANGED
            if not len(msa_seqlens) == 1:
                raise RuntimeError(
                    "Received unaligned sequences for input to MSA, all sequence "
                    "lengths must be equal."
                )
            msa_labels, msa_strs, msa_tokens = super().__call__(msa)
            labels.append(msa_labels)
            strs.append(msa_strs)
            tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens
        return labels, strs, tokens

Admittedly, the use cases for this kind of thing are pretty limited. I'm using it to support sequence generation beyond the bounds of the seed MSA. (That, by the way, does not work very well, but I have a bunch of test cases that assume such functionality, and it's easier for me to monkeypatch your batch converter than to remove that feature from my package and adjust the test cases.)
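The counting done by rawbatchlen can also be written with a regular expression. This is only a sketch of the same idea, not code from the esm package: `token_length` and `_TOKEN_RE` are hypothetical names, and the pattern assumes every special token has the form `<name>`.

```python
import re

# Hypothetical pattern: one special token such as "<mask>", or any single character.
_TOKEN_RE = re.compile(r"<[^<>]+>|.")


def token_length(seq: str) -> int:
    """Count positions, treating each <...> special token as a single position."""
    return len(_TOKEN_RE.findall(seq))


# The pre-masked MSA from the reproduction steps above.
msa = [("0", "AAA<mask><mask>"), ("1", "ACC<mask><mask>"), ("2", "ACDE<mask>")]
lengths = {token_length(seq) for _, seq in msa}
print(lengths)  # {5}: all rows have five positions once <mask> counts as one
```

With plain `len(seq)` the same three rows have string lengths 15, 15 and 10, which is why the existing check rejects them.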

tomsercu commented 1 year ago

Ah, that's a good catch. The reason is that we added support for tokenizing special tokens such as <mask> later. Checking that the sequence lengths match should happen only after self.alphabet.tokenize(seq), which also happens inside super().__call__. It may be best to do those length checks, and coalesce into the tokens batch tensor, only after we get msa_tokens back from there.

italobale commented 1 year ago

Hi, I don't know if this is a related issue, but I get the following error when running the example code for the Zero-shot variant prediction with MSA transformer:

python predict.py \   
    --model-location esm_msa1b_t12_100M_UR50S \
    --sequence HPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW \
    --dms-input ./data/BLAT_ECOLX_Ranganathan2015.csv \
    --mutation-col mutant \
    --dms-output ./data/BLAT_ECOLX_Ranganathan2015_labeled.csv \
    --offset-idx 24 \
    --scoring-strategy masked-marginals \
    --msa-path ./data/BLAT_ECOLX_1_b0.5.a3m
Traceback (most recent call last):
  File "/Users/ibalestra/Data/esm/examples/variant-prediction/predict.py", line 241, in <module>
    main(args)
  File "/Users/ibalestra/Data/esm/examples/variant-prediction/predict.py", line 167, in main
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
  File "/Users/ibalestra/opt/miniforge3/envs/esm/lib/python3.9/site-packages/esm/data.py", line 327, in __call__
    raise RuntimeError(
RuntimeError: Received unaligned sequences for input to MSA, all sequence lengths must be equal.

acforvs commented 1 year ago

Hi @tomsercu! I would love to address this issue if possible. However, I do not think that this suggestion would work:

Maybe best to do those length-checks and coalesce into the tokens batch-tensor only after we get msa_tokens back from there

The reason is that super().__call__ returns token rows of equal length regardless of the input, because it pads every sequence to the longest one in the batch: https://github.com/facebookresearch/esm/blob/main/esm/data.py#L269

For example,

import esm
model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
data = [
    ("protein1", "MKTVRQG"),
    ("protein4", "AAAAAA<mask><mask>"),
    ]
esm.data.BatchConverter(alphabet)(data)

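returns two token rows of length 9 each, whereas MKTVRQG on its own should have length 8 (the prepended start token plus 7 residues). The trailing 1 in the first row is padding:

tensor([[ 0, 20, 15, 11,  7, 10, 16,  6,  1],
        [ 0,  5,  5,  5,  5,  5,  5, 32, 32]])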

All in all, I think that we should either use something like rawbatchlen, as described in the original message, or add a special flag that checks that all lengths are the same here: https://github.com/facebookresearch/esm/blob/main/esm/data.py#L266.
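To make the argument concrete, here is a toy sketch that does not use esm at all. `tokenize`, `pad_batch` and `check_aligned` are hypothetical stand-ins, the `<name>` token pattern is an assumption, and no start token is prepended. It shows that a length check on the padded rows cannot see the mismatch, while a check on the per-sequence token counts before padding can.

```python
import re
from typing import List, Sequence, Tuple

TOKEN_RE = re.compile(r"<[^<>]+>|.")  # assumption: special tokens look like <name>


def tokenize(seq: str) -> List[str]:
    """Split a sequence into residues and <...> special tokens."""
    return TOKEN_RE.findall(seq)


def pad_batch(seqs: Sequence[str], pad: str = "<pad>") -> List[List[str]]:
    """Toy stand-in for batch padding: right-pad every row to the longest row."""
    rows = [tokenize(s) for s in seqs]
    width = max(len(r) for r in rows)
    return [r + [pad] * (width - len(r)) for r in rows]


def check_aligned(msa: Sequence[Tuple[str, str]]) -> None:
    """Raise if the rows differ in token count, measured before any padding."""
    lengths = {len(tokenize(seq)) for _, seq in msa}
    if len(lengths) != 1:
        raise RuntimeError(f"Unaligned sequences, token lengths: {sorted(lengths)}")


data = [("protein1", "MKTVRQG"), ("protein4", "AAAAAA<mask><mask>")]

# After padding, every row has the same width, so the mismatch is hidden.
padded = pad_batch([seq for _, seq in data])
print([len(row) for row in padded])  # [8, 8]

# Before padding, the token counts differ (7 vs 8) and the check fires.
try:
    check_aligned(data)
except RuntimeError as err:
    print(err)
```

The same pre-padding check accepts the pre-masked MSA from the original report, because each of its rows has five tokens.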

What do you think would be a better option? Thanks!

outongyiLv commented 1 year ago

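Hi, I don't know if this is a related issue, but I get the following error when running the example code for the Zero-shot variant prediction with MSA transformer: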

python predict.py \   
    --model-location esm_msa1b_t12_100M_UR50S \
    --sequence HPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW \
    --dms-input ./data/BLAT_ECOLX_Ranganathan2015.csv \
    --mutation-col mutant \
    --dms-output ./data/BLAT_ECOLX_Ranganathan2015_labeled.csv \
    --offset-idx 24 \
    --scoring-strategy masked-marginals \
    --msa-path ./data/BLAT_ECOLX_1_b0.5.a3m
Traceback (most recent call last):
  File "/Users/ibalestra/Data/esm/examples/variant-prediction/predict.py", line 241, in <module>
    main(args)
  File "/Users/ibalestra/Data/esm/examples/variant-prediction/predict.py", line 167, in main
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
  File "/Users/ibalestra/opt/miniforge3/envs/esm/lib/python3.9/site-packages/esm/data.py", line 327, in __call__
    raise RuntimeError(
RuntimeError: Received unaligned sequences for input to MSA, all sequence lengths must be equal.

I ran into this problem too.

Jeffuuuu commented 5 months ago

Hi, I don't know if this is a related issue, but I get the following error when running the example code for the Zero-shot variant prediction with MSA transformer:

python predict.py \   
    --model-location esm_msa1b_t12_100M_UR50S \
    --sequence HPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW \
    --dms-input ./data/BLAT_ECOLX_Ranganathan2015.csv \
    --mutation-col mutant \
    --dms-output ./data/BLAT_ECOLX_Ranganathan2015_labeled.csv \
    --offset-idx 24 \
    --scoring-strategy masked-marginals \
    --msa-path ./data/BLAT_ECOLX_1_b0.5.a3m
Traceback (most recent call last):
  File "/Users/ibalestra/Data/esm/examples/variant-prediction/predict.py", line 241, in <module>
    main(args)
  File "/Users/ibalestra/Data/esm/examples/variant-prediction/predict.py", line 167, in main
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
  File "/Users/ibalestra/opt/miniforge3/envs/esm/lib/python3.9/site-packages/esm/data.py", line 327, in __call__
    raise RuntimeError(
RuntimeError: Received unaligned sequences for input to MSA, all sequence lengths must be equal.

Hello everyone,

I'm currently working on zero-shot variant prediction with the MSA Transformer and hit the same error as described in this message when running the same command. Has the source of this error been pinpointed?

Thank you in advance for your response, Jeffuuuu
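One way to narrow this down is to group the MSA rows by length before handing them to the batch converter. The sketch below assumes the failing check compares raw string lengths, as the proposed patch in the original report implies. `group_labels_by_length` is a hypothetical helper, and the toy rows are made up rather than taken from BLAT_ECOLX_1_b0.5.a3m.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def group_labels_by_length(msa: Sequence[Tuple[str, str]]) -> Dict[int, List[str]]:
    """Map each raw string length to the labels of the rows that have it."""
    groups: Dict[int, List[str]] = defaultdict(list)
    for label, seq in msa:
        groups[len(seq)].append(label)
    return dict(groups)


# Toy MSA with one row that is a character longer than the others.
toy_msa = [("query", "MKT-VR"), ("hit1", "MRT-VK"), ("hit2", "MKTaVR-")]
print(group_labels_by_length(toy_msa))
# {6: ['query', 'hit1'], 7: ['hit2']}
```

More than one key in the result means the rows are not aligned as strings, and the labels under the minority lengths point at the rows to inspect.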