aehrc / cxrmate

CXRMate: Longitudinal Data and a Semantic Similarity Reward for Chest X-Ray Report Generation
https://huggingface.co/aehrc/cxrmate
Apache License 2.0

GTPrompt model training problem #16

Open yihp opened 1 month ago

yihp commented 1 month ago

Hi! Thanks for your contribution. It is an excellent piece of work!

My task language is Chinese. I have trained the MultiCXR model with my own vocabulary, and I have the following problems when training the GTPrompt model:

I cannot load the multi_ckpt_name: aehrc/cxrmate-multi-tf checkpoint you trained because the word embedding dimensions are different, and the cxrmate-multi-tf-cn model I trained myself was not saved in the pytorch_model.bin format, so I don't know how to load it.

How should I load the trained MultiCXR model in .ckpt format?

# Load multi checkpoint:
if encoder_decoder_ckpt_name:
    encoder_decoder = AutoModel.from_pretrained(encoder_decoder_ckpt_name, trust_remote_code=True)
    self.load_state_dict(encoder_decoder.state_dict())
else:
    warnings.warn('The encoder-to-decoder model was not warm-started before applying low-rank approximation.')
anicolson commented 1 month ago

Hi @yihp,

If test_ckpt_name is in your config, it will use the Hugging Face from_pretrained method during testing.

Otherwise, it will automatically load the .ckpt as the Lightning module from your exp_dir: https://github.com/aehrc/cxrmate/blob/820607a5511b9cb4131b09713c32655e7d9cbb03/tools/stages.py#L103

and

https://github.com/aehrc/cxrmate/blob/820607a5511b9cb4131b09713c32655e7d9cbb03/tools/stages.py#L110

Hence, remove test_ckpt_name from your config to test the .ckpt file.
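
For illustration, a minimal sketch of the two loading paths described above (the function name and arguments here are hypothetical; the real logic lives in tools/stages.py and uses the LightningModule class configured there):

from transformers import AutoModel


def load_for_testing(task_model_cls, test_ckpt_name=None, ckpt_path=None, **kwargs):
    # Hugging Face path: used when test_ckpt_name is set in the config, e.g. 'aehrc/cxrmate-multi-tf'.
    # The vocabulary and word-embedding sizes must match the checkpoint.
    if test_ckpt_name:
        return AutoModel.from_pretrained(test_ckpt_name, trust_remote_code=True)
    # Lightning path: restore the .ckpt saved under exp_dir.
    return task_model_cls.load_from_checkpoint(checkpoint_path=ckpt_path, **kwargs)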

yihp commented 1 month ago

Hi @anicolson ,

Thank you very much for your reply.

Are the network structures of the GTPrompt model and the MultiCXR model the same? If so, can I load the MultiCXR model's .ckpt checkpoint when training the GTPrompt model?

Looking forward to your reply!

anicolson commented 1 month ago

Hi @yihp,

Yes, the MultiCXR model is used to warm-start GTPrompt.

yihp commented 1 month ago

Hi @anicolson ,

But I don't know whether there is a problem with the pytorch_model.bin I saved when training the MultiCXR model, which causes garbled output during validation.

So can I specify the last.ckpt of the MultiCXR model?

anicolson commented 1 month ago

Hi @yihp,

Specify warm_start_ckpt_path in your config:

https://github.com/aehrc/cxrmate/blob/820607a5511b9cb4131b09713c32655e7d9cbb03/tools/stages.py#L53

yihp commented 1 month ago

OK, I will go to the lab to try it later. Thank you very much for your reply!

yihp commented 1 month ago

Hi @anicolson ,

I specified warm_start_ckpt_path for training:

dlhpcstarter -t cxrmate -c config/train/longitudinal_gt_prompt_tf_qwen.yaml --stages_module tools.stages --train --trial 5 --warm-start-ckpt-path experiments/cxrmate/multi_tf/trial_0/epoch=3-step=7840-val_report_nlg_bleu_4=0.017195.ckpt

But the following error occurred:

Traceback (most recent call last):
  File "/home/maiyue/anaconda3/envs/cxrmate/bin/dlhpcstarter", line 8, in <module>
    sys.exit(main())
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 126, in main
    submit(args, cmd_line_args, stages_fnc)
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 21, in submit
    stages_fnc(args)
  File "/public-data/yhp/cxrmate/tools/stages.py", line 49, in stages
    model = TaskModel.load_from_checkpoint(checkpoint_path=args.warm_start_ckpt_path, **vars(args))
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/lightning/pytorch/utilities/model_helpers.py", line 125, in wrapper
    return self.method(cls, *args, **kwargs)
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/lightning/pytorch/core/module.py", line 1586, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 91, in _load_from_checkpoint
    model = _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 187, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GTPrompt:
        Missing key(s) in state_dict: "encoder_decoder.decoder.base_model.model.bert.embeddings.word_embeddings.weight", "encoder_decoder.decoder.base_model.model.bert.embeddings.position_embeddings.weight", "encoder_decoder.decoder.base_model.model.bert.embeddings.token_type_embeddings.weight", "encoder_decoder.decoder.base_model.model.bert.embeddings.LayerNorm.weight", "encoder_decoder.decoder.base_model.model.bert.embeddings.LayerNorm.bias", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.query.base_layer.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.query.base_layer.bias", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.query.lora_A.default.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.query.lora_B.default.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.key.base_layer.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.key.base_layer.bias", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.key.lora_A.default.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.key.lora_B.default.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.value.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.value.bias", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.output.dense.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.output.dense.bias",

Are the network structures of the GTPrompt model and the MultiCXR model the same? Why don't they match?

Looking forward to your reply!

anicolson commented 1 month ago

Hi @yihp,

Ah, I am sorry, I forgot about LoRA. GTPrompt is MultiCXR + LoRA: the model is warm-started and then LoRA is added. You can see this here:

https://github.com/aehrc/cxrmate/blob/820607a5511b9cb4131b09713c32655e7d9cbb03/modules/lightning_modules/longitudinal/gt_prompt.py#L62

and here:

https://github.com/aehrc/cxrmate/blob/820607a5511b9cb4131b09713c32655e7d9cbb03/modules/transformers/longitudinal_model/modelling_longitudinal.py#L151

So all this has to happen within the class due to the differences.
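
To sketch the order of operations (this is not the repo's exact code; the LoRA rank and target modules below are assumptions, although the missing lora_A/lora_B keys on the decoder's query and key layers in your traceback are consistent with them):

from peft import LoraConfig, get_peft_model


def warm_start_then_add_lora(encoder_decoder, multi_cxr_state_dict):
    # 1. Warm start from the MultiCXR weights while the two architectures still match.
    encoder_decoder.load_state_dict(multi_cxr_state_dict)
    # 2. Only then wrap the decoder with LoRA; this renames the decoder parameters
    #    (base_layer, lora_A, lora_B), which is why loading the raw .ckpt into GTPrompt fails.
    lora_config = LoraConfig(r=8, target_modules=['query', 'key'])
    encoder_decoder.decoder = get_peft_model(encoder_decoder.decoder, lora_config)
    return encoder_decoder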

This is a bit annoying, but you have to save the .ckpt as a Hugging Face model checkpoint: https://github.com/aehrc/cxrmate/blob/main/modules/transformers/multi_tf_model_to_hub.ipynb

Instead of setting warm_start_ckpt_path, set multi_ckpt_name in your config; multi_ckpt_name should be the save_path from that notebook.
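
As a condensed sketch of what that notebook does (the exact key renaming is in multi_tf_model_to_hub.ipynb; the function name and arguments here are placeholders):

import torch


def lightning_ckpt_to_hf(encoder_decoder, ckpt_path, save_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt['state_dict']  # Lightning stores the model weights under 'state_dict'.
    # Strip the LightningModule attribute prefix so the keys match the Hugging Face model
    # (the notebook also renames the encoder projection-head keys).
    state_dict = {k.replace('encoder_decoder.', ''): v for k, v in state_dict.items()}
    encoder_decoder.load_state_dict(state_dict)
    encoder_decoder.save_pretrained(save_path)  # then point multi_ckpt_name at save_path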

Sorry for the confusion.

anicolson commented 1 month ago

Hi, I see you removed your comment, are you still interested in this?

From: yihp. Date: Thursday, 19 September 2024 at 7:08 pm. Subject: Re: [aehrc/cxrmate] GTPrompt model training problem (Issue #16)

Hi @anicolson,

I have another question about how to save the aehrc/cxrmate-tf Hugging Face model checkpoint. Is aehrc/cxrmate-tf the LongitudinalPromptMultiCXREncoderDecoderModel model class? Am I converting it in the following way:

# Encoder & decoder config:
config_decoder = transformers.BertConfig(
    vocab_size=151659,
    num_hidden_layers=6,
    type_vocab_size=2,
)  # BERT as it includes token_type_ids.

encoder_ckpt_name = 'microsoft/cvt-21-384-22k'
config_encoder = CvtWithProjectionHeadConfig.from_pretrained(
    '/public-data/yhp/cxrmate/microsoft/cvt-21-384-22k',
    # os.path.join(ckpt_zoo_dir, encoder_ckpt_name),
    local_files_only=True,
    projection_size=config_decoder.hidden_size,
)
config = transformers.VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# Encoder-to-decoder instance:
LongitudinalPromptMultiCXREncoderDecoderModel.register_for_auto_class("AutoModel")
encoder_decoder = LongitudinalPromptMultiCXREncoderDecoderModel(config=config)

for key in list(state_dict.keys()):
    if 'encoder_projection' in key:
        state_dict[key.replace('encoder_projection', 'encoder.projection_head.projection')] = state_dict.pop(key)
    elif 'last_hidden_state_layer_norm' in key:
        state_dict[key.replace('last_hidden_state_layer_norm', 'encoder.projection_head.layer_norm')] = state_dict.pop(key)
    elif 'encoder.encoder' in key:
        state_dict[key.replace('encoder.encoder', 'encoder.cvt.encoder')] = state_dict.pop(key)
    elif 'encoder_decoder.' in key:
        state_dict[key.replace('encoder_decoder.', '')] = state_dict.pop(key)
    else:
        warnings.warn(f'Key not found: {key}')

encoder_decoder.load_state_dict(state_dict)
encoder_decoder.save_pretrained(save_path)

I converted it like this and trained the SCST model using the config public-longitudinal_gt_prompt_cxr-bert.yaml, but the model output was garbled. Looking forward to your reply!


yihp commented 1 month ago

Hi @anicolson ,

Thank you very much for your reply.

I have two questions:

  1. First, how do I save the aehrc/cxrmate-tf Hugging Face model checkpoint?
  2. Second, the paper states that for SCST, validation was performed every 1/10 of an epoch. How should this be set? Setting every_n_epochs: 0.1 did not work.

Looking forward to your reply!

anicolson commented 1 month ago

With 1), I've added the remaining notebooks for saving to Hugging Face checkpoints here: https://github.com/aehrc/cxrmate/tree/main/modules/transformers. It can be a bit of a nightmare getting the key names right in the state_dict, so you might have to play around with that.

With 2), this has been added back into the configs: https://github.com/aehrc/cxrmate/blob/b106927021e7037e4198bdc1dd36524c227303c8/config/train/longitudinal_gt_prompt_cxr-bert.yaml#L17.
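
In case it helps to relate 2) to plain PyTorch Lightning (whether the repo's config exposes it under this exact name is an assumption): fractional-epoch validation corresponds to the Trainer's val_check_interval argument, where a float is a fraction of the training epoch.

from lightning.pytorch import Trainer

trainer = Trainer(val_check_interval=0.1)  # run validation every 1/10 of a training epoch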

yihp commented 1 month ago

Hi @anicolson ,

Thank you very much for your reply!!! I have two questions:

First, if I want to use bert_score as a reward, do you have any related experiments? Do I just need to change ckpt_name = 'microsoft/BiomedVLP-CXR-BERT-specialized' to ckpt_name = 'microsoft/bert-base-chinese', and should I use the last layer's [CLS] output as the embedding vector to calculate the cosine similarity? https://github.com/aehrc/cxrmate/blob/b106927021e7037e4198bdc1dd36524c227303c8/tools/rewards/cxrbert.py#L15

Second, during training of the different models (single_tf, multi_tf, longitudinal_gt_prompt_tf, longitudinal_gt_prompt_cxr-bert.yaml), how did you set the following training parameters:

devices: 
max_epochs: 
mbatch_size: 
accumulated_mbatch_size: 

Looking forward to your reply!

anicolson commented 1 month ago

Hi @yihp,

See https://github.com/aehrc/cxrmate-ed/blob/main/tools/rewards/bertscore.py

And https://github.com/aehrc/cxrmate-ed/blob/17bb8f1131f58c151ccb7b46667ed5a98e79e660/modules/lightning_modules/cxrmate_ed/scst_rewards.py#L9

Note that the cxrmate-ed repo will be heavily refactored in a couple of weeks.
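
For what it's worth, a rough sketch of the [CLS] cosine-similarity reward you describe (this is not the cxrmate-ed implementation; 'bert-base-chinese' and the use of the last layer's [CLS] token are the assumptions from your question):

import torch
from transformers import AutoModel, AutoTokenizer


class ClsCosineReward:
    def __init__(self, ckpt_name='bert-base-chinese', device='cpu'):
        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
        self.model = AutoModel.from_pretrained(ckpt_name).to(device).eval()
        self.device = device

    @torch.no_grad()
    def __call__(self, predictions, references):
        def cls_embed(texts):
            batch = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(self.device)
            return self.model(**batch).last_hidden_state[:, 0]  # [CLS] embedding from the last layer
        # One cosine-similarity reward per (prediction, reference) pair.
        return torch.nn.functional.cosine_similarity(cls_embed(predictions), cls_embed(references), dim=-1)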

So I was using 4xP100 GPUs to train the model.

For single_tf and multi_tf:

devices: 4
max_epochs: 32
mbatch_size: 8
accumulated_mbatch_size: 32

For longitudinal_gt_prompt_tf:

devices: 4
max_epochs: 32
mbatch_size: 2
accumulated_mbatch_size: 32

For longitudinal_gt_prompt_cxr-bert:

devices: 4
max_epochs: 32
mbatch_size: 1   # See paper for explanation of this.
accumulated_mbatch_size: 32
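
If it helps to interpret these, assuming accumulated_mbatch_size is the effective batch size across devices and gradient accumulation (an assumption, not something stated in the repo), the number of accumulation steps works out as accumulated_mbatch_size / (devices * mbatch_size):

# Assumed relationship: accumulation steps = accumulated_mbatch_size / (devices * mbatch_size).
settings = {
    'single_tf / multi_tf': dict(devices=4, mbatch_size=8, accumulated_mbatch_size=32),
    'longitudinal_gt_prompt_tf': dict(devices=4, mbatch_size=2, accumulated_mbatch_size=32),
    'longitudinal_gt_prompt_cxr-bert': dict(devices=4, mbatch_size=1, accumulated_mbatch_size=32),
}
for name, s in settings.items():
    steps = s['accumulated_mbatch_size'] // (s['devices'] * s['mbatch_size'])
    print(f'{name}: {steps} gradient accumulation step(s)')  # 1, 4, and 8 respectively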