facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

Incorrect sequence length calculation in usage example #319

Closed kklemon closed 1 year ago

kklemon commented 1 year ago

Bug description

The usage example in the project's README uses the length of the original protein sequence string to determine the slicing indices for the model output with respect to padding. However, the provided code implies that the tokenizer is flexible and allows, for instance, spaces between amino acids or even special tokens such as <mask>. The string length of the original protein sequence may therefore not correspond to the length of the tokenized sequence.

The example code therefore calculates the slicing indices incorrectly, particularly for protein2 with mask and protein3.

Note: this issue may also be present in the notebooks. I haven't checked them so far.
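
For illustration, here is a minimal sketch of the <mask> case (the masked sequence below is made up; it assumes alphabet and batch_converter are loaded as in the README usage example):

# Hypothetical masked sequence for illustration; assumes `alphabet` and
# `batch_converter` are set up as in the README usage example.
masked = [('prot_masked', 'MKTV<mask>RGA')]
*_, masked_tokens = batch_converter(masked)

# '<mask>' is 6 characters in the string but a single token after conversion,
# so the string length (13) does not match the token count
# (10 with the ESM-2 alphabet: BOS + 8 residue/mask tokens + EOS).
print(len(masked[0][1]))
print((masked_tokens != alphabet.padding_idx).sum(1))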

Reproduction steps

Take the following example which is derived from the usage example:

import esm

# Setup as in the README usage example
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

batch_1 = [('prot1', 'MKTV')]
batch_2 = [('prot2', 'M K T V')]

*_, batch_tokens_1 = batch_converter(batch_1)
*_, batch_tokens_2 = batch_converter(batch_2)

# Spaces in 'prot2' are ignored.
# Both samples are tokenized to the same result
assert (batch_tokens_1 == batch_tokens_2).all()

# However, string lengths as used by the README example differ and cannot 
# be used to determine slicing indices of the output sequence
assert len(batch_1[0][1]) == len(batch_2[0][1])

The code in the README implies that the last assertion should hold, but it obviously does not.

Expected behavior

The usage example should calculate the sequence lengths with respect to padding correctly. This can be achieved by counting the non-padding tokens in the tokenized sequence. The following is a corrected version of the original usage example:

...

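# Sequence lengths with respect to padding: count the non-padding tokens per sequence (includes the BOS/EOS tokens)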
seq_lengths = (batch_tokens != alphabet.padding_idx).sum(1)

# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
sequence_representations = []
for i, seq_len in enumerate(seq_lengths):
    sequence_representations.append(token_representations[i, 1 : seq_len - 1].mean(0))

# Look at the unsupervised self-attention map contact predictions
import matplotlib.pyplot as plt
# `data` is the (label, sequence) list from the original README example
for (_, seq), seq_len, attention_contacts in zip(data, seq_lengths, results["contacts"]):
    plt.matshow(attention_contacts[: seq_len, : seq_len])
    plt.title(seq)
    plt.show()
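
As a quick sanity check, the corrected, padding-based lengths yield identical per-sequence representations for equivalent inputs such as 'MKTV' and 'M K T V', which the string-length version would not. A minimal sketch, assuming the 33-layer ESM-2 model, alphabet, and batch_converter from the README are already loaded:

import torch

# Both inputs tokenize to the same tokens, so the padding-based length
# calculation should yield identical mean embeddings for them.
# Assumes `model`, `alphabet` and `batch_converter` from the README (33-layer ESM-2).
model.eval()  # disable dropout, as in the README
check_data = [('prot1', 'MKTV'), ('prot2', 'M K T V')]
*_, check_tokens = batch_converter(check_data)

with torch.no_grad():
    check_results = model(check_tokens, repr_layers=[33], return_contacts=False)
check_reprs = check_results["representations"][33]

check_lens = (check_tokens != alphabet.padding_idx).sum(1)
check_means = [check_reprs[i, 1 : n - 1].mean(0) for i, n in enumerate(check_lens)]

assert torch.allclose(check_means[0], check_means[1])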
nikitos9000 commented 1 year ago

@kklemon Thanks for finding this issue! Here's the related PR that fixes all occurrences: https://github.com/facebookresearch/esm/pull/361