agemagician / ProtTrans

ProtTrans provides state-of-the-art pretrained language models for proteins. ProtTrans was trained on thousands of GPUs from Summit and hundreds of Google TPUs using Transformer models.

Using [CLS] token for classification task? #88

gitwangxf closed this issue 2 years ago

gitwangxf commented 2 years ago

Hello! Thanks for your great work! I'm trying to do some classification tasks on protein-protein interaction. I was wondering: since Next Sentence Prediction was not used when training these models, what does the [CLS] token from the last hidden state output stand for? Could it also capture important biophysical properties of the input proteins and thus be used in downstream prediction tasks?

mheinzinger commented 2 years ago

You are right: we did not train the [CLS] token for any specific task due to the lack of a "next-sentence" notion in protein sequences. However, it is possible that this token still learnt some information, as it essentially allowed the model to use it as a sort of "wild-card" that is not constrained by a specific task but could be used to learn/store some global properties. However, this is all speculation, as we never investigated in detail what this token learnt. If you find something that supports my hypothesis above, it would be great if you could share it :)

As I always say at this point: if you really want to get competitive performance for your downstream task, I would absolutely recommend using ProtT5 instead of ProtBERT.
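For context, a minimal sketch of how ProtT5 embeddings are usually extracted with Hugging Face transformers (the checkpoint name and preprocessing follow the common ProtT5 examples; treat the details as assumptions and adapt them to your setup):

```python
import re
import torch
from transformers import T5Tokenizer, T5EncoderModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Encoder-only ProtT5 checkpoint; sufficient for feature extraction
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device).eval()

seqs = ["MSEQWENCE", "PRTEIN"]  # toy sequences
# ProtT5 expects space-separated residues; rare amino acids are mapped to X
seqs = [" ".join(list(re.sub(r"[UZOB]", "X", s))) for s in seqs]

ids = tokenizer(seqs, add_special_tokens=True, padding="longest", return_tensors="pt").to(device)
with torch.no_grad():
    emb = model(input_ids=ids.input_ids, attention_mask=ids.attention_mask).last_hidden_state
# emb: (batch, max_len, 1024) per-residue embeddings, with one special token appended per sequence
```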

gitwangxf commented 2 years ago

Thanks for your immediate reply! I'll keep trying, and I'd like to share my results if I make any progress. By the way, is it right to extract [CLS] as the first vector of the embedding output, i.e. embedding[seq_num][0], and is there no [CLS] in ProtT5?

mheinzinger commented 2 years ago

There is no 1:1 equivalent in ProtT5; however, there is also a special token appended to the very end of ProtT5 embeddings, so you could check the information content of this special token. And yes, the very first token of ProtBERT should hold the embedding of the [CLS] token.
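To make the indexing concrete, here is a small sketch (assuming embeddings of shape `(batch, seq_len, dim)` and the tokenizer's attention mask; the helper names are made up for illustration):

```python
import torch

def protbert_cls(embeddings: torch.Tensor) -> torch.Tensor:
    # ProtBERT: [CLS] is prepended, so it sits at the first position of each sequence
    return embeddings[:, 0, :]

def prott5_special_token(embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # ProtT5: the special token is appended after the last residue,
    # i.e. it is the last non-padded position of each sequence
    last_idx = attention_mask.sum(dim=1) - 1  # (batch,)
    return embeddings[torch.arange(embeddings.size(0)), last_idx, :]
```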

gitwangxf commented 2 years ago

Got it, I'll try. Greatly appreciate your help!

gitwangxf commented 2 years ago

Hi! I got somewhat better classification results using the special token at the end of the ProtT5 embeddings than mean-pooling over the length of the entire protein, which may support the hypothesis you mentioned above, I think :) Now I'm trying to fine-tune ProtT5. For the classification head, would you recommend using exactly the same classification model I used for prediction? (As I'm working on protein-protein interaction, the classification model I used previously has multiple input branches.) Or would a single linear layer be OK?

mheinzinger commented 2 years ago

Oh wow, that is some good news! - Thanks for sharing, I think this could become useful for many other users, as well :)

I have to admit that it's hard for me to give any guidance without further insight into your exact setup, but I'll still give it a shot: given that you already had good experience with the current setup, I would simply stick to it. My understanding is that you run ProtT5 twice, once for interaction partner A and once for B, and then feed the two resulting embeddings into some sort of downstream prediction layer (either a single FFN that takes the concatenation of the two embeddings of A and B, or a "twin-tower" structure where you first pass each embedding (A and B) through the same FFN (hard parameter sharing) and only concatenate them later on). In any case, for PPI prediction I could imagine that it's beneficial to randomly swap A and B during each training iteration to avoid introducing any dependence on the order in which you feed in the embeddings.
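For illustration only, a minimal PyTorch sketch of such a twin-tower head with per-sample random swapping of A and B (class name, dimensions and dropout are placeholder choices, assuming per-protein embeddings have already been extracted):

```python
import torch
import torch.nn as nn

class TwinTowerPPI(nn.Module):
    """Shared FFN per interaction partner (hard parameter sharing), then a classifier on the concatenation."""

    def __init__(self, emb_dim: int = 1024, hidden: int = 256):
        super().__init__()
        self.tower = nn.Sequential(nn.Linear(emb_dim, hidden), nn.ReLU(), nn.Dropout(0.25))
        self.classifier = nn.Linear(2 * hidden, 1)

    def forward(self, emb_a: torch.Tensor, emb_b: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Randomly swap partners per sample so the head cannot rely on input order
            swap = torch.rand(emb_a.size(0), 1, device=emb_a.device) < 0.5
            emb_a, emb_b = torch.where(swap, emb_b, emb_a), torch.where(swap, emb_a, emb_b)
        h = torch.cat([self.tower(emb_a), self.tower(emb_b)], dim=-1)
        return self.classifier(h).squeeze(-1)  # interaction logit per pair
```

With precomputed per-protein embeddings (e.g. the ProtT5 special token from above), a head like this stays cheap to train even if you later unfreeze parts of ProtT5.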

Good Luck!

gitwangxf commented 2 years ago

Thanks for your constructive advice, and I'll take it! I've got only a few tens of thousands of training samples, but there is a much larger dataset related to this task. Would it be beneficial to do multi-task ProtT5 fine-tuning, in your experience?

mheinzinger commented 2 years ago

We never tried it, but given the multi-task capability of T5 in NLP, I would assume that it should also work in our field. I could imagine that the risk of catastrophic forgetting will just make your training pipeline, and especially the sampling, a bit harder if you go for multi-task fine-tuning. Feel free to keep us posted on your experience once you get results.
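Purely as an illustration of the sampling aspect, a hypothetical sketch that mixes batches from a small target task and a large auxiliary task at a fixed ratio (datasets and the mixing probability are toy placeholders):

```python
import random
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins: a small target-task dataset and a large related-task dataset
small_task = TensorDataset(torch.randn(200, 1024), torch.randint(0, 2, (200,)))
large_task = TensorDataset(torch.randn(5000, 1024), torch.randint(0, 2, (5000,)))

small_loader = DataLoader(small_task, batch_size=16, shuffle=True)
large_loader = DataLoader(large_task, batch_size=16, shuffle=True)

def mixed_batches(small_loader, large_loader, p_small=0.5):
    """Yield (task, batch) pairs, up-sampling the small task so it is not drowned out."""
    small_it, large_it = iter(small_loader), iter(large_loader)
    while True:
        if random.random() < p_small:
            try:
                yield "target", next(small_it)
            except StopIteration:   # restart the small task whenever it is exhausted
                small_it = iter(small_loader)
        else:
            try:
                yield "auxiliary", next(large_it)
            except StopIteration:   # one pass over the large task = one training epoch
                return

for task, (x, y) in mixed_batches(small_loader, large_loader, p_small=0.5):
    pass  # forward/backward pass with a task-specific head would go here
```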

gitwangxf commented 2 years ago

Okay, thanks a lot!