agemagician / ProtTrans

ProtTrans provides state-of-the-art pretrained language models for proteins. ProtTrans was trained on thousands of GPUs from Summit and hundreds of Google TPUs using various Transformer models.
Academic Free License v3.0

ProtT5 vs fine-tuned ProtBERT for classification #90

Closed wwwhhhccc closed 2 years ago

wwwhhhccc commented 2 years ago

Hello,

First off, thanks a lot for creating this project!

I'm trying to use the generated cls tokens to perform a downstream classification task. I tried fine-tuning ProtT5_XL, but there were simply too many parameters to tune, and the Colab Pro+ GPU complained that it did not have sufficient memory. (Or perhaps I did something wrong and it is possible to fine-tune ProtT5_XL?)

My question is: to achieve the best classification results, should I use the cls tokens given by ProtT5 directly without fine-tuning (i.e., use the model as a feature extractor), or should I fine-tune ProtBERT? I know this depends on my downstream classification task, but what do you think based on your experience? How large an accuracy boost does fine-tuning usually give?

Thanks, Bill

mheinzinger commented 2 years ago

Hi,

tl;dr: in our hands, fine-tuning usually did not buy us much compared to average-pooling.

Longer story: It sounds like your task is a protein-level task, such as prediction of subcellular localization, as opposed to residue-level tasks which, for example, classify secondary structure elements for each residue in a protein. If that's correct, we have had good experience with average-pooling (summing all residue embeddings and dividing by the number of residues) and using this fixed-size vector (usually 1024-d for our models) as input to a NN. What we did not experiment with: using the special token of ProtT5 (the last one) as input in such a scenario. It would be interesting to see whether this buys you something.

Besides this, there are certain techniques that allow you to fine-tune large models on limited hardware, e.g. https://pytorch-lightning.readthedocs.io/en/stable/advanced/model_parallel.html#deepspeed-zero-stage-3-offload. That said, I have not worked with those myself yet, so I cannot give you advice here.
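The average-pooling described above (sum all residue embeddings, divide by the number of residues) can be sketched in PyTorch as follows. This is a minimal illustration, not code from this repository: the function name, tensor shapes, and the toy inputs are my own assumptions; the 1024-d embedding size matches the models mentioned above, and an attention mask is used so that padding positions do not distort the average.

```python
import torch

def mean_pool(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average per-residue embeddings into one fixed-size vector per protein.

    residue_embeddings: (batch, seq_len, dim) per-residue embeddings
                        (dim is 1024 for the ProtTrans models discussed here).
    attention_mask:     (batch, seq_len), 1 for real residues, 0 for padding.
    """
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq_len, 1)
    summed = (residue_embeddings * mask).sum(dim=1)  # sum over real residues only
    counts = mask.sum(dim=1).clamp(min=1)            # number of residues per protein
    return summed / counts                           # (batch, dim), fixed-size input for a NN

# Toy example: two "proteins", the second one padded to length 4.
emb = torch.zeros(2, 4, 1024)
emb[0] = 2.0        # protein 1: four residues, each embedding is all 2.0
emb[1, :2] = 4.0    # protein 2: two real residues (all 4.0), two padding positions
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
pooled = mean_pool(emb, mask)  # (2, 1024); rows average to 2.0 and 4.0 respectively
```

The resulting fixed-size vector can then be fed to any downstream classifier, independent of protein length.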