Bruce-ywj / ERNIE-RNA

Code of a structure-enhanced RNA language model named ERNIE-RNA
MIT License

question about fine tuned ss prediction #3

Open FabianGitHub1 opened 4 months ago

FabianGitHub1 commented 4 months ago

Hi, first of all, thank you very much for publishing this code! I have a question about the fine-tuned secondary structure prediction. The predict_ss_rna.py script demonstrates how to use the fine-tuned ERNIE-RNA to predict the secondary structure of an RNA sequence. I applied the same method to all sequences from the TS0 set (the 1305-sequence test set from bpRNA-1m), the only difference being that I converted the predictions to dot-bracket notation at the end. Should this approach be able to reproduce the F1 score of 0.873 reported in the paper? When I calculate the F1 score over all these predictions, I only get around 0.77. I tried both computing the F1 score globally over all pairs and averaging the per-sequence scores. What do I need to do differently to get the reported results? Thanks a lot in advance!
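
For context, my dot-bracket conversion is roughly the following (a simplified sketch of my own helper, not code from the repository; it assumes a binary, symmetric pairing matrix and only handles nested, non-pseudoknotted structures):

import numpy as np

def matrix_to_dot_bracket(pair_matrix: np.ndarray) -> str:
    # Greedy conversion of a binary, symmetric L x L pairing matrix to
    # dot-bracket notation. Assumes at most one partner per base; crossing
    # (pseudoknotted) pairs get the same bracket type, so the output is
    # only meaningful for nested structures.
    length = pair_matrix.shape[0]
    structure = ["."] * length
    paired = set()
    for i in range(length):
        for j in range(i + 1, length):
            if pair_matrix[i, j] > 0 and i not in paired and j not in paired:
                structure[i], structure[j] = "(", ")"
                paired.update((i, j))
    return "".join(structure)

# Example: a single pair between positions 0 and 3 -> "(..)"
m = np.zeros((4, 4))
m[0, 3] = m[3, 0] = 1
print(matrix_to_dot_bracket(m))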

Bruce-ywj commented 3 months ago

@FabianGitHub1 Thank you very much for your question.

The predict_ss_rna.py script provides two types of prediction results: the zero-shot prediction and the fine-tuned prediction. The average F1 scores of these two methods on the 1305-sequence test set of bpRNA-1m are shown in the figures below: the macro average F1 score of the fine-tuned model is 0.873, and the macro average F1 score of the zero-shot model is 0.77.

[figure: fine-tune evaluation] [figure: zero-shot evaluation]

May I ask which prediction result you have tested?
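
To make the metric explicit (a simplified sketch, not our exact evaluation script): the macro average computes one F1 score per sequence and then takes the mean over the 1305 sequences, whereas pooling all base pairs of the whole test set into a single score is the micro/global average. The two can differ noticeably, because long sequences contribute far more pairs to a pooled score:

import numpy as np

def pair_f1(pred: np.ndarray, true: np.ndarray) -> float:
    # F1 over the entries of binary contact matrices (or pooled vectors).
    tp = int(((pred == 1) & (true == 1)).sum())
    fp = int(((pred == 1) & (true == 0)).sum())
    fn = int(((pred == 0) & (true == 1)).sum())
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def macro_f1(preds, trues):
    # Macro average: one F1 per sequence, then the mean over sequences.
    return float(np.mean([pair_f1(p, t) for p, t in zip(preds, trues)]))

def micro_f1(preds, trues):
    # Micro/global average: pool all base pairs of all sequences first.
    pred_all = np.concatenate([p.ravel() for p in preds])
    true_all = np.concatenate([t.ravel() for t in trues])
    return pair_f1(pred_all, true_all)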

FabianGitHub1 commented 3 months ago

Hi @Bruce-ywj, thank you for your reply! I tried both the zero-shot and the fine-tuned prediction, using the methods from the predict_ss_rna.py script. I calculated the scores using PyTorch Ignite metrics. From my understanding, this should be the macro F1 score, since it is calculated across all pairs (and not per sequence). I get the following results:

fine tuned: f1 score: 0.776033161663107 precision: 0.7706231289499944 recall: 0.781519691901161

zero shot: f1 score: 0.6030812221936237 precision: 0.5346386196705637 recall: 0.6916200489135018

This is my code for the fine-tuned version (the zero-shot version is similar):


import torch
from ignite.engine import Engine
from ignite.metrics import Fbeta, Precision, Recall

from ernie_rna.predict_ss_rna import seq_to_rnaindex_and_onehot, post_process
from ernie_rna.src.utils import load_pretrained_ernierna, prepare_input_for_ernierna, ChooseModel

from rna_data.data_utils.sec_struct_data import SecStructData

pretrained_path = '/home/fabian/rna-second-structure-prediction/ernie_rna/checkpoint/ERNIE-RNA_checkpoint/ERNIE-RNA_pretrain.pt'
arg_overrides = {"data": '/home/fabian/rna-second-structure-prediction/ernie_rna/src/dict'}
ss_fine_tuned_path = "/home/fabian/rna-second-structure-prediction/ernie_rna/checkpoint/ERNIE-RNA_ss_prediction_checkpoint/ERNIER-RNA_ss_prediction.pt"

ts0 = SecStructData('TS0', max_len=500)

device = "cuda"

# Load the pretrained ERNIE-RNA backbone, then wrap it with the
# secondary-structure prediction head and load the fine-tuned weights.
model_pre = load_pretrained_ernierna(pretrained_path, arg_overrides)
my_model = ChooseModel(model_pre.encoder)
state_dict = torch.load(ss_fine_tuned_path, map_location="cpu")
# Strip the 'module.' prefix that DataParallel training adds to parameter names.
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
my_model.load_state_dict(new_state_dict)
my_model = my_model.to(device)
my_model.eval()

def inference(engine, batch):
    with torch.no_grad():
        seq = batch["seq"]
        contacts = batch["contact"]

        # Tokenize the sequence and build the 1-D and pairwise (2-D) inputs.
        X, data_seq = seq_to_rnaindex_and_onehot(seq)
        one_d, twod_data = prepare_input_for_ernierna(X, len(seq))

        one_d = one_d.to(device)
        twod_data = twod_data.to(device)
        data_seq = data_seq.to(device)

        fine_pred_ss = my_model(one_d, twod_data)

        # Post-process the raw pair scores into a valid pairing matrix,
        # then binarize at 0.5 to get the predicted contacts.
        pair_attn = fine_pred_ss.unsqueeze(0)
        post_pair_attn = post_process(pair_attn, data_seq, 0.01, 0.1, 100, 1.6, True, 1.5)
        fine_tuned_pred = (post_pair_attn > 0.5).float().squeeze().cpu()

    return {"y": contacts.float(), "y_pred": fine_tuned_pred}

evaluation_engine = Engine(inference)

# Attach precision/recall/F1. Ignite accumulates the counts over every
# update, i.e. over all base pairs of all sequences together (a pooled score).
precision = Precision(average=False)
precision.attach(evaluation_engine, "Precision")
recall = Recall(average=False)
recall.attach(evaluation_engine, "Recall")
f1 = Fbeta(beta=1.0, precision=precision, recall=recall)
f1.attach(evaluation_engine, "F1-Score")

state = evaluation_engine.run(iter(ts0), max_epochs=1)

print(f"f1 score: {state.metrics['F1-Score']}")
print(f"precision: {state.metrics['Precision']}")
print(f"recall: {state.metrics['Recall']}")

Of course this depends on my external data handling, but maybe you can still spot my mistake? Or could you share the evaluation script that produces the numbers from your message above? Thanks a lot!

FabianGitHub1 commented 3 months ago

Hi @Bruce-ywj! I think my issue could be related to the prepare_input_for_ernierna function. As far as I understand, its return value two_d should be the tensor used as the bias for the first attention layer. However, for me it is all zeros. I think the reason is that the function expects its index input to be 1-D but is called with a 2-D input. I tried changing that, and the two_d tensor then looks more like I expected, but I still do not get the correct results for the evaluation on the TS0 dataset. So maybe I am misunderstanding it, and the prepare_input_for_ernierna function is supposed to do something different?

kevin-liuguang commented 2 weeks ago

Hi @FabianGitHub1! My two_d is also all zeros. Have you found the reason or fixed it?

FabianGitHub1 commented 1 week ago

Hi @kevin-liuguang ,

I changed the function call in predict from one_d, twod_data = prepare_input_for_ernierna(X, len(seq)) to one_d, twod_data = prepare_input_for_ernierna(X[0], len(seq)). This made the two_d tensor look more like I expected. However, it did not solve my original problem of not being able to reproduce the F1 score of 0.87 on the test set.
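
For anyone else hitting this, here is roughly the sanity check I used after the change (a quick sketch; X and seq come from my evaluation code above):

from ernie_rna.src.utils import prepare_input_for_ernierna

# With the 2-D X, the pairwise attention-bias input silently came out
# empty for me; after indexing with X[0] it should contain non-zero entries.
one_d, twod_data = prepare_input_for_ernierna(X[0], len(seq))
assert twod_data.abs().sum() > 0, "twod_data is still all zeros"
print(f"non-zero entries in twod_data: {(twod_data != 0).sum().item()}")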