agemagician / ProtTrans

ProtTrans provides state-of-the-art pre-trained language models for proteins. ProtTrans was trained on thousands of GPUs from Summit and hundreds of Google TPUs using transformer models.

Formatting for ProtT5 labels #137

Closed exs-fdreyer closed 7 months ago

exs-fdreyer commented 7 months ago

Hello,

I have been trying to understand the ProtT5 model and how to compute a loss for the full encoder-decoder. Looking through the GitHub issues on this repository, it is suggested in several places that the format for predicting masked residues should be, e.g. for a poly-alanine sequence "AAAAA":

input: "A A A <extra_id_0>"
label: "<extra_id_0> A A"

which is similar to how HuggingFace describes T5 training: https://huggingface.co/docs/transformers/model_doc/t5#training

However, trying this results in a substantially worse loss than simply using the original sequence as the label. E.g., running the following code:

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_uniref50', do_lower_case=False)
model = T5ForConditionalGeneration.from_pretrained('Rostlab/prot_t5_xl_uniref50')
# input sequences "EVQLVESGAE" and "AAAAAAAAAA"
label_seq = tokenizer(["E V Q L V E S G A E", "A A A A A A A A A A"], return_tensors="pt").input_ids
# mask some of the residues with sentinel tokens
input_seq = tokenizer(["E V <extra_id_0> L <extra_id_1> E S G <extra_id_2> E", "A A <extra_id_0> A A <extra_id_1> A A <extra_id_2> A"], return_tensors="pt").input_ids
# suggested format for the labelling of masked tokens ("<extra_id_0> Q <extra_id_1> V <extra_id_2> A", etc)
label_seq_alt = tokenizer(["<extra_id_0> Q <extra_id_1> V <extra_id_2> A", "<extra_id_0> A <extra_id_1> A <extra_id_2> A"], return_tensors="pt").input_ids

print(model(input_ids=input_seq, labels=label_seq).loss)
print(model(input_ids=input_seq, labels=label_seq_alt).loss)

shows a negative log-likelihood loss of 1.2 for the first case and 40 for the second, with the first going down as expected when fewer residues are masked, while the second stays roughly constant. This makes me think that the correct way to further pre-train the model would be to pass the full unmasked sequence as the label rather than only the masked tokens; is that correct?
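For reference, here is a minimal sketch (assuming a recent Hugging Face transformers version; _shift_right is an internal helper, used here only to show the equivalence) of how T5ForConditionalGeneration turns labels into a loss: the labels are shifted right to become the decoder inputs for teacher forcing, and the loss is the mean token-level cross-entropy between the decoder logits and the unshifted labels.

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_uniref50', do_lower_case=False)
model = T5ForConditionalGeneration.from_pretrained('Rostlab/prot_t5_xl_uniref50')

input_ids = tokenizer(["A A <extra_id_0> A A"], return_tensors="pt").input_ids
labels = tokenizer(["A A A A A"], return_tensors="pt").input_ids

with torch.no_grad():
    # the labels, shifted right and prepended with the decoder start token
    # (the pad token for T5), become the decoder inputs for teacher forcing
    decoder_input_ids = model._shift_right(labels)
    logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
    # mean cross-entropy over all label positions (ignore_index=-100 by default)
    manual_loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1)
    )
    print(manual_loss, model(input_ids=input_ids, labels=labels).loss)  # should match

Since the loss is averaged only over the label tokens that are provided, the two label formats above are scored against very different targets.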

agemagician commented 7 months ago

Hello,

In section 2.4 under "ProtT5" in our paper, we mentioned the following:

Contrary to the original T5 model which masks spans of multiple tokens, we adopted BERT’s denoising objective to corrupt and reconstruct single tokens using a masking probability of 15 percent.
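To make the quoted objective concrete, here is a rough sketch (not the ProtTrans training code; the use of one shared sentinel token for every corrupted position is an assumption made purely for illustration) of corrupting single residues with a 15% masking probability:

import random

def corrupt(sequence, mask_prob=0.15, sentinel="<extra_id_0>"):
    # BERT-style single-token corruption: each residue is independently
    # replaced with a sentinel with probability mask_prob (15% in the paper);
    # the shared sentinel here is an illustrative assumption
    residues = sequence.split()
    return " ".join(sentinel if random.random() < mask_prob else aa for aa in residues)

print(corrupt("E V Q L V E S G A E"))
# on average roughly 1-2 of the 10 residues are replaced per call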

So we followed BERT-style noising and denoising with a single sentinel. This means that if the original sequence is "E V Q L V E S G A E", then: