nickgkan / butd_detr

Code for the ECCV22 paper "Bottom Up Top Down Detection Transformers for Language Grounding in Images and Point Clouds"

Fixing PyTorch Checkpoint Loading Issue in PointNet++ #23

Closed: Hiusam closed this issue 1 year ago.

Hiusam commented 1 year ago

Hi, thank you for your great work on this project. I noticed a problem with how the pre-trained checkpoint is loaded. The code passes the whole object returned by torch.load to load_state_dict. However, gf_detector_l6o256.pth stores the weights under the key 'model', so none of its top-level keys match the backbone's parameter names. I suggest the following change:

Change:

self.backbone_net.load_state_dict(torch.load(pointnet_ckpt), strict=False)

to:

self.backbone_net.load_state_dict(torch.load(pointnet_ckpt)['model'], strict=False)
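Below is a minimal sketch of the ['model'] lookup that also tolerates checkpoints saved as a bare state dict. The helper name unwrap_state_dict is hypothetical and does not come from the repository. The sketch assumes the checkpoint is a mapping with the weights under 'model', as reported above. Plain ints stand in for tensors so the example runs without PyTorch.

```python
from collections.abc import Mapping
from typing import Any


def unwrap_state_dict(checkpoint: Mapping[str, Any], key: str = "model") -> Mapping[str, Any]:
    """Return the state dict stored under `key`, or the checkpoint itself.

    Hypothetical helper. Assumes a checkpoint saved as {"model": state_dict, ...};
    a bare state dict (whose values are tensors, not mappings) passes through.
    """
    inner = checkpoint.get(key)
    if isinstance(inner, Mapping):
        return inner
    return checkpoint


# Toy stand-in for torch.load(pointnet_ckpt); key names are illustrative only.
ckpt = {"model": {"module.backbone_net.sa1.weight": 0}, "epoch": 10}
print(list(unwrap_state_dict(ckpt)))  # ['module.backbone_net.sa1.weight']
```

In the real code, the result would be passed to self.backbone_net.load_state_dict.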

Additionally, the keys in the pre-trained weights carry the prefix module.backbone_net. This prefix must be removed before the keys match the backbone's parameter names.
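The following is one possible way to strip that prefix, written as a sketch. It assumes the prefix is followed by a dot, i.e. "module.backbone_net.". It also assumes that entries belonging to the detector's other sub-modules should be dropped, since the backbone would not recognise them anyway. The helper name strip_prefix and the toy keys are hypothetical.

```python
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.backbone_net.") -> "OrderedDict[str, Any]":
    """Keep only entries whose key starts with `prefix`, with the prefix removed.

    Hypothetical helper; keys for other sub-modules of the full detector are
    dropped because the backbone's load_state_dict would not use them.
    """
    return OrderedDict(
        (k[len(prefix):], v) for k, v in state_dict.items() if k.startswith(prefix)
    )


# Toy keys; a real checkpoint holds tensors under the detector's actual names.
full = {
    "module.backbone_net.sa1.weight": 1,
    "module.backbone_net.fp2.bias": 2,
    "module.decoder.weight": 3,
}
print(list(strip_prefix(full)))  # ['sa1.weight', 'fp2.bias']
```

Filtering and stripping happen in one pass. The result can then be loaded without relying on strict=False to hide leftover keys.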

Regarding the strict=False argument: because the code passes strict=False, these mismatched keys are silently ignored. As a result, no weights are actually loaded and the backbone keeps its random initialization instead of using the pre-trained checkpoint. Is that intended?
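To see why nothing gets loaded, here is a minimal sketch of the key bookkeeping that strict=False tolerates without complaint. Model keys are compared against checkpoint keys. The names compare_keys and KeyReport are hypothetical, and the toy checkpoint is assumed to have 'model' as its only top-level key.

```python
from typing import Iterable, List, NamedTuple


class KeyReport(NamedTuple):
    matched: List[str]
    missing: List[str]     # expected by the model but absent from the checkpoint
    unexpected: List[str]  # present in the checkpoint but unknown to the model


def compare_keys(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> KeyReport:
    """Hypothetical check mirroring what strict=False silently tolerates."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    return KeyReport(
        matched=sorted(model_set & ckpt_set),
        missing=sorted(model_set - ckpt_set),
        unexpected=sorted(ckpt_set - model_set),
    )


# Passing the raw checkpoint: its only top-level key is 'model', so nothing matches.
report = compare_keys(["sa1.weight", "fp2.bias"], ["model"])
print(report.matched)     # []
print(report.missing)     # ['fp2.bias', 'sa1.weight']
print(report.unexpected)  # ['model']
if not report.matched:
    print("warning: no pre-trained weights were loaded")
```

In PyTorch itself, nn.Module.load_state_dict returns a named tuple with missing_keys and unexpected_keys fields. The same check can therefore be applied to its return value, for example by raising an error when missing_keys is non-empty, even while strict=False is kept.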

Best regards.

ayushjain1144 commented 1 year ago

Hi @Hiusam ,

Good catch, this is certainly not intended behaviour. Feel free to make a PR if you like, or we will push a fix in some time. Thank you!

Best, Ayush