mana438 / RNABERT


get embeddings error #4

Open pillowill opened 5 months ago

pillowill commented 5 months ago

Dear author, when I run the command "python MLM_SFP.py --pretraining bert_mul_2.pth --data_embedding my_rna.fa --embedding_output rRNABert_emb.csv --batch 40", I get the following error: RuntimeError: Error(s) in loading state_dict for BertForMaskedLM: Missing key(s) in state_dict: "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.token_type_embeddings.weight", "bert.embeddings.LayerNorm.gamma", "bert.embeddings.LayerNorm.beta", "bert.encoder.layer.0.attention.selfattn.query.weight", "bert.encoder.layer.0.attention.selfattn.query.bias", "bert.encoder.layer.0.attention.selfa...

sunyunlee commented 1 week ago

Hi, I am getting the same error message when trying to extract embeddings from the pretrained model without fine-tuning it. I assume it comes from a mismatch between the parameter names of the initialized model and those stored in the weight file. Has the issue been addressed or fixed? Thanks in advance.
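One way to check that assumption would be to compare the two sets of parameter names directly. The snippet below is only a minimal sketch: `diff_keys` is a hypothetical helper, and the key names are toy values (the "extra." prefix is made up for illustration). With PyTorch, the inputs would be `model.state_dict().keys()` and the keys of the dict returned by `torch.load`.

```python
def diff_keys(expected_keys, found_keys):
    """Return (missing, unexpected) parameter names, each sorted."""
    expected, found = set(expected_keys), set(found_keys)
    missing = sorted(expected - found)      # model wants these, file lacks them
    unexpected = sorted(found - expected)   # file has these, model does not
    return missing, unexpected


# Toy names for illustration only; "extra." is an assumed prefix.
expected = ["bert.embeddings.word_embeddings.weight"]
found = ["extra.bert.embeddings.word_embeddings.weight"]

missing, unexpected = diff_keys(expected, found)
print("missing:", missing)
print("unexpected:", unexpected)
```

If every missing name reappears in the unexpected list with the same extra leading component, the weights themselves are probably fine and only the key names need adjusting.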

sunyunlee commented 1 week ago

I was able to figure out the issue. The OrderedDict in the pretrained file uses different parameter names from the ones the Bert class object expects: each key has one additional leading word. The following script strips that word and saves the result to a new file.

import torch
from collections import OrderedDict

file_path = 'bert_mul_2.pth'

state_dict = torch.load(file_path, map_location="cpu")

new_state_dict = OrderedDict()

for key, value in state_dict.items():
    # Drop the first dot-separated component of the key (the extra leading word)
    new_key = ".".join(key.split(".")[1:])
    new_state_dict[new_key] = value.clone()

torch.save(new_state_dict, 'bert_mul_2_correction.pth')

for key in new_state_dict.keys():
    print(key)

I ran this script first to produce a new weight file (bert_mul_2_correction.pth).
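Note that the loop above removes the first component from every key, whether or not it is the extra word. A slightly more defensive variant would strip a known prefix only where it is present. The following is a sketch under assumptions: `strip_prefix` is a hypothetical helper, and "extra." is a placeholder, since the thread does not say what the additional word is (models saved through `torch.nn.DataParallel` typically carry a "module." prefix, but that is not confirmed here).

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Leave keys without the prefix untouched instead of cutting them blindly.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in cleaned:
            raise KeyError(f"duplicate key after stripping: {new_key}")
        cleaned[new_key] = value
    return cleaned


# Toy stand-in for a loaded checkpoint; the values would normally be tensors.
demo = OrderedDict([("extra.bert.encoder.weight", 1), ("cls.bias", 2)])
fixed = strip_prefix(demo, "extra.")
print(list(fixed))
```

With a real checkpoint, the dict returned by `torch.load(file_path, map_location="cpu")` would be passed in place of `demo`, and the result saved with `torch.save` as in the script above.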