agemagician / ProtTrans

ProtTrans provides state-of-the-art pretrained language models for proteins. ProtTrans was trained on thousands of GPUs from Summit and hundreds of Google TPUs using transformer models.
Academic Free License v3.0

Generating embedding from finetuned model #146


abelavit commented 6 months ago

Hello,

I need help with generating embeddings after ProtT5 has been finetuned. I have carried out finetuning of the model using the sample code `PT5_LoRA_Finetuning_per_residue_class.ipynb` on my own dataset and have the saved model `PT5_secstr_finetuned.pth`. How do we now extract embeddings for new protein sequences, such as `sequence_examples = ["PRTEINO", "SEQWENCE"]`, using the finetuned model?

Thank you for your time.

mheinzinger commented 6 months ago

The model outputs have a field called `hidden_states` which contains the embeddings. Something along these lines: `embeddings = model(input_ids, attention_mask=attention_mask).hidden_states`
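A rough sketch of what that could look like, assuming `model` and `tokenizer` are the objects returned by the notebook's `load_model("./PT5_secstr_finetuned.pth", ...)`; whether the wrapper's forward accepts `output_hidden_states` and which output field holds the embeddings depends on how it is defined, so adapt the field names as needed:

```python
import re
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

sequence_examples = ["PRTEINO", "SEQWENCE"]
# ProtT5 expects space-separated residues with rare amino acids mapped to X
seqs = [" ".join(re.sub(r"[UZOB]", "X", s)) for s in sequence_examples]

batch = tokenizer(seqs, add_special_tokens=True, padding="longest", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(batch.input_ids, attention_mask=batch.attention_mask, output_hidden_states=True)

# hidden_states is a tuple with one tensor per layer; the last entry holds the
# final per-residue embeddings of shape (batch, seq_len, 1024)
embeddings = outputs.hidden_states[-1]

# strip padding and the trailing special token to get one matrix per protein
emb_per_protein = [embeddings[i, : len(s)] for i, s in enumerate(sequence_examples)]
```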

abelavit commented 6 months ago

For loading the original pre-trained model, such as ProtT5, it can be done like so:

```python
# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

# Load the model
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)
```
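Getting embeddings from this pre-trained encoder is then straightforward, roughly along these lines (a sketch based on the ProtTrans usage examples; it assumes `tokenizer`, `model`, and `device` are defined as above):

```python
import re
import torch

sequence_examples = ["PRTEINO", "SEQWENCE"]
# space-separate residues and map rare amino acids to X, as ProtT5 expects
seqs = [" ".join(re.sub(r"[UZOB]", "X", s)) for s in sequence_examples]

ids = tokenizer(seqs, add_special_tokens=True, padding="longest", return_tensors="pt").to(device)

with torch.no_grad():
    out = model(ids.input_ids, attention_mask=ids.attention_mask)

# per-residue embeddings of shape (batch, seq_len, 1024); the last token of
# each sequence is the </s> special token added by the tokenizer
residue_emb = out.last_hidden_state
# mean-pool over residues for a single per-protein embedding
protein_emb = [residue_emb[i, : len(s)].mean(dim=0) for i, s in enumerate(sequence_examples)]
```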

To load the finetuned model, following the `PT5_LoRA_Finetuning_per_residue_class.ipynb` script, the command seems to be:

```python
tokenizer, model_reload = load_model("./PT5_secstr_finetuned.pth", num_labels=3, mixed=False)
```

The `load_model` call above relies on several other functions (e.g. the `PT5_classification_model` function), which makes for a rather chunky script. I am wondering if there is a simpler way to load the finetuned model and obtain embeddings for protein sequences, as can be done with the original pre-trained model (ProtT5).

I am not sure if I am doing it right.

Thanks.

mheinzinger commented 5 months ago

I see your point; however, we currently do not have the bandwidth to work on a nicer interface, sorry. If you find a nicer way, e.g. by using https://github.com/huggingface/peft , feel free to share it or create a pull request :)
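For what it's worth, a hypothetical sketch of how that could look with PEFT. It assumes the LoRA adapters were saved as a PEFT adapter directory via `save_pretrained` (here `./PT5_secstr_lora`, an assumed path), which differs from the `.pth` checkpoint the notebook currently writes:

```python
# Hypothetical sketch using huggingface PEFT; assumes the LoRA adapters were
# saved during finetuning with model.save_pretrained("./PT5_secstr_lora"),
# not as the .pth state dict the notebook currently produces.
import torch
from transformers import T5Tokenizer, T5EncoderModel
from peft import PeftModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False)
base = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")

# attach the finetuned LoRA adapters on top of the frozen base encoder
model = PeftModel.from_pretrained(base, "./PT5_secstr_lora").to(device).eval()

seqs = ["P R T E I N O", "S E Q W E N C E"]
batch = tokenizer(seqs, add_special_tokens=True, padding="longest", return_tensors="pt").to(device)

with torch.no_grad():
    # per-residue embeddings from the adapted encoder
    embeddings = model(batch.input_ids, attention_mask=batch.attention_mask).last_hidden_state
```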