xiaoman-zhang / KAD

MIT License
122 stars, 10 forks

Mismatch bert model checkpoint #8

Closed: HieuPhan33 closed this issue 1 year ago

HieuPhan33 commented 1 year ago

Hi,

Thanks for sharing the work.

I would like to test zero-shot performance on CheXpert using your provided ckpt.

I'm running the script as python test_chexpert.py --bert_pretrained ../KAD_Models/Knowledge_Encoder/epoch_latest.pt --checkpoint ../KAD_Models/KAD_512/best_valid.pt

I got the error: size mismatch for bert_model.embeddings.word_embeddings.weight: copying a param with shape torch.Size([30522, 768]) from checkpoint, the shape in current model is torch.Size([28996, 768]).

Since the README doesn't mention which bert_model_name to use, I used the default one in test_chexpert.py, which is emilyalsentzer/Bio_ClinicalBERT.

Could you please give me some guidance?
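The error means the checkpoint's word-embedding table has 30522 rows, while the model built from emilyalsentzer/Bio_ClinicalBERT has 28996, so the Knowledge_Encoder checkpoint was trained on top of a base model with a different tokenizer vocabulary. When trying candidate base models, it can help to compare parameter shapes up front. The helper below is a hypothetical sketch (shape_mismatches is not part of KAD or PyTorch); it works on plain name-to-shape dicts, which can be built from text_encoder.state_dict() and checkpoint["state_dict"].

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def shape_mismatches(model_shapes: Dict[str, Shape],
                     ckpt_shapes: Dict[str, Shape]) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, checkpoint_shape, model_shape) for every parameter whose
    shape differs between the checkpoint and the freshly built model."""
    mismatches = []
    for name, ckpt_shape in ckpt_shapes.items():
        model_shape = model_shapes.get(name)
        # Keys absent from the model show up as "unexpected" in load_state_dict;
        # this helper only reports shape conflicts on keys both sides share.
        if model_shape is not None and model_shape != ckpt_shape:
            mismatches.append((name, ckpt_shape, model_shape))
    return mismatches


# With real PyTorch objects (not run here):
#   model_shapes = {k: tuple(v.shape) for k, v in text_encoder.state_dict().items()}
#   ckpt_shapes = {k: tuple(v.shape) for k, v in checkpoint["state_dict"].items()}

# Shapes copied from the error message in this issue.
key = "bert_model.embeddings.word_embeddings.weight"
print(shape_mismatches({key: (28996, 768)}, {key: (30522, 768)}))
```

An empty list means every shared parameter has a compatible shape; a mismatch on the word-embedding row count points at the wrong bert_model_name rather than a corrupted checkpoint.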

xiaoman-zhang commented 1 year ago

Sorry for the confusion. The bert_model_name used is 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'.

HieuPhan33 commented 1 year ago

Hi, following your suggestion, I used this command:

python test_chexpert.py --bert_pretrained ../KAD_Models/Knowledge_Encoder/epoch_latest.pt --checkpoint ../KAD_Models/KAD_512/best_valid.pt --bert_model_name microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext

But I got another error: RuntimeError: Error(s) in loading state_dict for CLP_clinical: Unexpected key(s) in state_dict: "bert_model.embeddings.position_ids".

Could you help me identify the issue?
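The thread never pins down why bert_model.embeddings.position_ids is unexpected. One plausible cause, not confirmed here, is a transformers version difference: recent releases register position_ids in BertEmbeddings as a non-persistent buffer, so it no longer appears in a freshly built model's state_dict, while checkpoints saved with older versions still contain it. Under that assumption the key only holds the index range 0..max_position_embeddings-1, which the model rebuilds itself, and it can be dropped before loading. drop_stale_buffers is a hypothetical helper, not part of KAD.

```python
from typing import Dict, Iterable, TypeVar

T = TypeVar("T")


def drop_stale_buffers(state_dict: Dict[str, T],
                       suffixes: Iterable[str] = ("embeddings.position_ids",)) -> Dict[str, T]:
    """Return a copy of state_dict without keys ending in any of `suffixes`.

    Assumption: the unexpected key is a buffer that newer transformers
    versions no longer save, so discarding it loses no learned weights.
    """
    suffixes = tuple(suffixes)  # str.endswith accepts a tuple of suffixes
    return {k: v for k, v in state_dict.items() if not k.endswith(suffixes)}


# Hypothetical usage inside test_chexpert.py (not run here):
#   state_dict = drop_stale_buffers(checkpoint["state_dict"])
#   text_encoder.load_state_dict(state_dict)

sd = {"bert_model.embeddings.position_ids": "buffer",
      "bert_model.embeddings.word_embeddings.weight": "weights"}
print(sorted(drop_stale_buffers(sd)))
```

Unlike strict=False, this removes only the one known key, so any other missing or unexpected key still raises.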

xiaoman-zhang commented 1 year ago

I tried the following:

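import torch
from transformers import AutoTokenizer
from models.clip_tqn import CLP_clinical  # import path assumed; adjust to where CLP_clinical lives in the KAD repo

bert_model_name = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
# local_files_only=True requires the model to already be in the local Hugging Face cache
tokenizer = AutoTokenizer.from_pretrained(bert_model_name, do_lower_case=True, local_files_only=True)
text_encoder = CLP_clinical(bert_model_name=bert_model_name).to('cuda')
bert_pretrained = './epoch_latest.pt'
if bert_pretrained:
    # load on CPU first; load_state_dict then copies the weights into the CUDA model
    checkpoint = torch.load(bert_pretrained, map_location='cpu')
    state_dict = checkpoint["state_dict"]
    text_encoder.load_state_dict(state_dict)
    print('Load pretrained bert success from: ', bert_pretrained)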

The output is:

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext
Text feature extractor: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext
bert encoder layers: 12
Load pretrained bert success from: ./epoch_latest.pt
xiaoman-zhang commented 1 year ago

Hello, I trained multiple versions of the text encoder with different initializations. The provided checkpoint should be used with 'GanjinZero/UMLSBert_ENG' as the initialization, or you can use the model I uploaded to Hugging Face, 'xmcmic/Med-KEBERT'. Sorry for not checking carefully when pushing the code.

HieuPhan33 commented 1 year ago

I still get the same error, "Unexpected key(s) in state_dict: "bert_model.embeddings.position_ids"", with 'GanjinZero/UMLSBert_ENG'.

I then used 'xmcmic/Med-KEBERT' and commented out the BERT loading part:

    # if args.bert_pretrained:
    #     checkpoint = torch.load(args.bert_pretrained, map_location='cpu')
    #     state_dict = checkpoint["state_dict"]
    #     text_encoder.load_state_dict(state_dict)
    #     print('Load pretrained bert success from: ',args.bert_pretrained)
    #     if args.freeze_bert:
    #         for param in text_encoder.parameters():
    #             param.requires_grad = False

I then got an error at line 87: https://github.com/xiaoman-zhang/KAD/blob/746f0c677e3b465bb17a8e8893dcf761480caa04/A3_CLIP/test_chexpert.py#L87

File "test_chexpert.py", line 87, in main
    text_encoder.load_state_dict(text_state_dict)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CLP_clinical:
        Unexpected key(s) in state_dict: "bert_model.embeddings.position_ids".
xiaoman-zhang commented 1 year ago

That's a bit strange. Have you tried it without commenting out the BERT loading part?

HieuPhan33 commented 1 year ago

Without commenting it out, I get the same error, "Unexpected key(s) in state_dict: "bert_model.embeddings.position_ids"", inside the bert_pretrained loading part, with both 'GanjinZero/UMLSBert_ENG' and 'xmcmic/Med-KEBERT'.

Command:

CUDA_VISIBLE_DEVICES=1 python test_chexpert.py --bert_pretrained ../KAD_Models/Knowledge_Encoder/epoch_latest.pt --checkpoint ../KAD_Models/KAD_512/best_valid.pt --bert_model_name xmcmic/Med-KEBERT (or GanjinZero/UMLSBert_ENG)

Error:

File "test_chexpert.py", line 76, in main 
    text_encoder.load_state_dict(state_dict)  
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict                                                                           
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CLP_clinical: 
        Unexpected key(s) in state_dict: "bert_model.embeddings.position_ids".
xiaoman-zhang commented 1 year ago

Maybe you can try adding strict=False: text_encoder.load_state_dict(state_dict, strict=False).
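strict=False makes load_state_dict skip every missing and unexpected key, not only position_ids, so a checkpoint paired with the wrong base model could load without complaint. (Shape mismatches, like the 30522 vs 28996 embedding error earlier, still raise even with strict=False.) load_state_dict returns an object with missing_keys and unexpected_keys attributes; a hypothetical check like the one below, not part of KAD, keeps the load non-strict for the one known key but fails loudly on anything else.

```python
from collections import namedtuple
from typing import Iterable

# Stand-in for the object torch.nn.Module.load_state_dict returns; the real
# one exposes the same .missing_keys and .unexpected_keys attributes.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])


def check_non_strict_load(result,
                          allowed_unexpected: Iterable[str] = ("bert_model.embeddings.position_ids",)) -> None:
    """Raise if a strict=False load skipped anything other than the allowed keys."""
    allowed = set(allowed_unexpected)
    unexpected = [k for k in result.unexpected_keys if k not in allowed]
    if result.missing_keys or unexpected:
        raise RuntimeError(f"missing: {list(result.missing_keys)}, unexpected: {unexpected}")


# Hypothetical usage (not run here):
#   result = text_encoder.load_state_dict(state_dict, strict=False)
#   check_non_strict_load(result)

# Only the known stale key was skipped, so this returns without raising.
check_non_strict_load(IncompatibleKeys([], ["bert_model.embeddings.position_ids"]))
```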