CAMMA-public / rendezvous

A transformer-inspired neural network for surgical action triplet recognition from laparoscopic videos.

Model fails to build when tested on pre-trained weights #11

Closed SimeonAllmendinger closed 1 year ago

SimeonAllmendinger commented 1 year ago

Hi Chinedu,

I am trying to use the `rendezvous_l8_cholect50_crossval_k4_layernorm_lowres.pth` file as my weights for testing. Unfortunately, the algorithm fails to build the model when calling:

```
python3 run.py -e --data_dir="/path/to/CholecT45dataset" --dataset_variant=cholect45-crossval --kfold 3 --batch 32 --version=1 --test_ckpt="/path/to/weights"
```

It seems that the weights in the file have size (100), whereas the model itself is prebuilt with size (100, 8, 14), so it stops when running `model.load_state_dict(torch.load(test_ckpt))`. I have the same issue with the `rendezvous_l8_cholect50_crossval_k5_layernorm_lowres.pth` file.

Do you have any idea if there is an easy fix for this problem?

Thank you very much in advance! I have added the error message below:

```
Missing key(s) in state_dict:
"decoder.mhma.0.ln.running_mean", "decoder.mhma.0.ln.running_var",
"decoder.mhma.1.ln.running_mean", "decoder.mhma.1.ln.running_var",
"decoder.mhma.2.ln.running_mean", "decoder.mhma.2.ln.running_var",
"decoder.mhma.3.ln.running_mean", "decoder.mhma.3.ln.running_var",
"decoder.mhma.4.ln.running_mean", "decoder.mhma.4.ln.running_var",
"decoder.mhma.5.ln.running_mean", "decoder.mhma.5.ln.running_var",
"decoder.mhma.6.ln.running_mean", "decoder.mhma.6.ln.running_var",
"decoder.mhma.7.ln.running_mean", "decoder.mhma.7.ln.running_var",
"decoder.ffnet.0.ln.running_mean", "decoder.ffnet.0.ln.running_var",
"decoder.ffnet.1.ln.running_mean", "decoder.ffnet.1.ln.running_var",
"decoder.ffnet.2.ln.running_mean", "decoder.ffnet.2.ln.running_var",
"decoder.ffnet.3.ln.running_mean", "decoder.ffnet.3.ln.running_var",
"decoder.ffnet.4.ln.running_mean", "decoder.ffnet.4.ln.running_var",
"decoder.ffnet.5.ln.running_mean", "decoder.ffnet.5.ln.running_var",
"decoder.ffnet.6.ln.running_mean", "decoder.ffnet.6.ln.running_var",
"decoder.ffnet.7.ln.running_mean", "decoder.ffnet.7.ln.running_var".
size mismatch for decoder.mhma.0.ln.weight: copying a param with shape
torch.Size([100, 8, 14]) from checkpoint, the shape in current model is
torch.Size([100]). ...
```

Regards, Simeon
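The error above can be reproduced in isolation: a LayerNorm over a (100, 8, 14) feature map stores a weight of that full shape and no running statistics, while a `BatchNorm1d(100)` expects a (100)-shaped weight plus `running_mean`/`running_var` buffers. A minimal sketch (the shapes are taken from the error message; the stand-in modules are illustrative, not the actual Rendezvous decoder):

```python
import torch
import torch.nn as nn

# Illustrative stand-ins for the decoder's normalization layers:
# a LayerNorm checkpoint loaded into a BatchNorm-configured model.
ln_model = nn.Sequential(nn.LayerNorm([100, 8, 14]))  # layernorm build
bn_model = nn.Sequential(nn.BatchNorm1d(100))         # batchnorm build

ckpt = ln_model.state_dict()  # what a *_layernorm_*.pth file contains

try:
    bn_model.load_state_dict(ckpt)
except RuntimeError as e:
    # Reports both the missing running statistics and the
    # (100, 8, 14) vs (100) size mismatch, as in the issue.
    print(type(e).__name__)

# BatchNorm expects buffer keys that the LayerNorm checkpoint lacks:
missing = sorted(set(bn_model.state_dict()) - set(ckpt))
print(missing)
```

This produces the same two failure modes shown in the traceback: missing `running_mean`/`running_var` keys and a weight-shape mismatch.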

nwoyecid commented 1 year ago

Please familiarize yourself with the argparse configurations. Every pretrained weight is saved with a filename that describes its minimal configuration (e.g. layernorm (`use_ln`) vs. batchnorm, `lowres` for low resolution, cross-val split k, etc.), and run.py has an argparser that lets you set the matching arguments. It seems you are loading layernorm model weights in a batchnorm configuration.
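Since BatchNorm layers serialize running statistics and LayerNorm layers do not, one quick way to confirm which configuration a checkpoint expects is to inspect its keys before building the model. A hypothetical helper (not part of the repo; `use_ln` is the flag mentioned above):

```python
import torch

def infer_norm_type(state_dict):
    """Guess whether a checkpoint was saved with LayerNorm or BatchNorm.

    BatchNorm layers serialize running_mean/running_var buffers;
    LayerNorm layers do not. Illustrative helper, not part of the repo.
    """
    if any(k.endswith(".running_mean") for k in state_dict):
        return "batchnorm"
    return "layernorm"

# Synthetic key sets standing in for real checkpoints:
ln_keys = {"decoder.mhma.0.ln.weight": None, "decoder.mhma.0.ln.bias": None}
bn_keys = dict(ln_keys, **{"decoder.mhma.0.ln.running_mean": None})
print(infer_norm_type(ln_keys))  # layernorm -> build with the use_ln option
print(infer_norm_type(bn_keys))  # batchnorm -> default build

# For a real file:
# print(infer_norm_type(torch.load("/path/to/weights", map_location="cpu")))
```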