agemagician / ProtTrans

ProtTrans provides state-of-the-art pretrained language models for proteins. ProtTrans was trained on thousands of GPUs from Summit and hundreds of Google TPUs using Transformer models.
Academic Free License v3.0

Fine-Tune T5 for Residue Classification #117

Closed. MoaazK closed this issue 1 year ago.

MoaazK commented 1 year ago

Hi,

I was wondering if I can fine-tune ProtT5-XL-UniRef for token classification. I found an example of ProtBERT fine-tuning for SS3 which uses AutoModelForTokenClassification, BertTokenizerFast, etc. Can you please guide me on what changes I can make in the notebook (https://github.com/agemagician/ProtTrans/blob/master/Fine-Tuning/ProtBert-BFD-FineTune-SS3.ipynb) for T5?

Secondly, the Quick Start section (https://github.com/agemagician/ProtTrans/tree/master#-quick-start) gives an example of feature extraction with T5 which uses mean() to get sequence-level embeddings. However, in the case of ProtBERT (AutoModelForSequenceClassification), a linear layer is applied to the embedding of the CLS token to represent the whole sequence. What is the difference between getting features via mean() and via the CLS token?

Thanks in advance

mheinzinger commented 1 year ago

Hi,

yes, you can fine-tune ProtT5-XL-U50 for token classification. You can try to plug in T5EncoderModel, which uses only the encoder part of T5. This way you have a setup very similar to BERT (encoder-only, but with learnt positional encoding). Unfortunately, we have no notebook for fine-tuning ProtT5 on this task which I could share, but we are working on it and I will update the README once it is ready.
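
In the meantime, here is a minimal sketch of what such a setup could look like: the ProtT5 encoder loaded via T5EncoderModel with a per-residue classification head on top. This is not the authors' notebook; the checkpoint name (Rostlab/prot_t5_xl_half_uniref50-enc), head size, dropout value, and label handling are illustrative assumptions.

```python
import re
import torch
import torch.nn as nn
from transformers import T5EncoderModel, T5Tokenizer

class ProtT5TokenClassifier(nn.Module):
    """Encoder-only ProtT5 with a linear per-residue classification head (sketch)."""
    def __init__(self, num_labels: int = 3,
                 model_name: str = "Rostlab/prot_t5_xl_half_uniref50-enc"):
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.d_model   # 1024 for ProtT5-XL
        self.dropout = nn.Dropout(0.1)              # illustrative value
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(self.dropout(out.last_hidden_state))
        loss = None
        if labels is not None:
            # ignore_index=-100 lets you mask padding / special-token positions
            loss = nn.CrossEntropyLoss(ignore_index=-100)(
                logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss, logits

tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
# ProtT5 expects space-separated residues with rare amino acids mapped to X
seq = " ".join(re.sub(r"[UZOB]", "X", "PRTEINO"))
batch = tokenizer(seq, return_tensors="pt", padding=True)

model = ProtT5TokenClassifier(num_labels=3)
dummy_labels = torch.zeros_like(batch["input_ids"])  # placeholder SS3 labels
loss, logits = model(batch["input_ids"], batch["attention_mask"], labels=dummy_labels)
```

From there you can train it like any other PyTorch module (or wrap it in a HuggingFace Trainer); the main change compared to the ProtBERT notebook is swapping the BERT backbone and tokenizer for T5EncoderModel and T5Tokenizer.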

On your second question: the CLS token is not conditioned/used in our case. The original BERT implementation used this token to classify whether two sentences follow each other in a document (next-sentence prediction). This concept does not exist for proteins, so we dropped that objective; the CLS token is an artifact of it. Given that it was never conditioned on anything, it is hard to tell what it learnt. In our hands it was better to apply average-pooling over the whole sequence. Since you mentioned ProtT5 anyway: I would strongly recommend simply average-pooling over ProtT5 embeddings instead of using ProtBERT.
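
For illustration, a minimal sketch of that average-pooling over ProtT5 embeddings, in the spirit of the Quick Start example. The checkpoint name and masking details are assumptions; note that this simple version also averages over the trailing </s> token, which you can slice off per sequence if you want to match the Quick Start script exactly.

```python
import re
import torch
from transformers import T5EncoderModel, T5Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = "Rostlab/prot_t5_xl_half_uniref50-enc"  # assumed encoder-only checkpoint
tokenizer = T5Tokenizer.from_pretrained(ckpt)
model = T5EncoderModel.from_pretrained(ckpt).to(device).eval()

seqs = ["PRTEINO", "SEQWENCE"]
# space-separate residues and map rare amino acids (U, Z, O, B) to X
seqs = [" ".join(re.sub(r"[UZOB]", "X", s)) for s in seqs]
batch = tokenizer(seqs, padding=True, return_tensors="pt").to(device)

with torch.no_grad():
    hidden = model(**batch).last_hidden_state        # (batch, length, 1024)

mask = batch["attention_mask"].unsqueeze(-1)          # 1 for real tokens, 0 for padding
# mean over valid positions only -> one 1024-d embedding per protein
per_protein = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
print(per_protein.shape)  # torch.Size([2, 1024])
```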