PRBonn / MaskPLS

Mask-Based Panoptic LiDAR Segmentation for Autonomous Driving, RA-L, 2023
MIT License

Load MinkUNet Pre-trained Weights #11

Closed yuyang-cloud closed 1 year ago

yuyang-cloud commented 1 year ago
          > Hi! We leveraged the weights of the feature extractor provided [here](https://github.com/mit-han-lab/spvnas).

Hi! @rmarcuzzi @comradexy I downloaded the backbone weights at SemanticKITTI_val_MinkUNet@114GMACs and used the code below to load the pre-trained weights:

state_dict = torch.load('../ckpts/backbone.pth', map_location='cpu')
self.backbone.load_state_dict(state_dict)

But this raises a missing-keys and unexpected-keys error:

*** RuntimeError: Error(s) in loading state_dict for MinkEncoderDecoder:
        Missing key(s) in state_dict: "stem.0.kernel", "stem.1.bn.weight", "stem.1.bn.bias", "stem.1.bn.running_mean", "stem.1.bn.running_var", "stem.3.kernel", "stem.4.bn.weight", "stem.4.bn.bias", "stem.4.bn.running_mean", "stem.4.bn.running_var", "stage1.0.net.1.bn.weight", "stage1.0.net.1.bn.bias", "stage1.0.net.1.bn.running_mean", "stage1.0.net.1.bn.running_var", "stage1.1.net.1.bn.weight", "stage1.1.net.1.bn.bias", "stage1.1.net.1.bn.running_mean", "stage1.1.net.1.bn.running_var", "stage1.1.net.4.bn.weight", "stage1.1.net.4.bn.bias", "stage1.1.net.4.bn.running_mean", "stage1.1.net.4.bn.running_var", "stage1.2.net.1.bn.weight", "stage1.2.net.1.bn.bias", "stage1.2.net.1.bn.running_mean", "stage1.2.net.1.bn.running_var", "stage1.2.net.4.bn.weight", "stage1.2.net.4.bn.bias", "stage1.2.net.4.bn.running_mean", "stage1.2.net.4.bn.running_var", "stage2.0.net.1.bn.weight", "stage2.0.net.1.bn.bias", "stage2.0.net.1.bn.running_mean", "stage2.0.net.1.bn.running_var", "stage2.1.net.1.bn.weight", "stage2.1.net.1.bn.bias", "stage2.1.net.1.bn.running_mean", "stage2.1.net.1.bn.running_var", "stage2.1.net.4.bn.weight", "stage2.1.net.4.bn.bias", "stage2.1.net.4.bn.running_mean", "stage2.1.net.4.bn.running_var", "stage2.1.downsample.1.bn.weight", "stage2.1.downsample.1.bn.bias", "stage2.1.downsample.1.bn.running_mean", "stage2.1.downsample.1.bn.running_var", "stage2.2.net.1.bn.weight", "stage2.2.net.1.bn.bias", "stage2.2.net.1.bn.running_mean", "stage2.2.net.1.bn.running_var", "stage2.2.net.4.bn.weight", "stage2.2.net.4.bn.bias", "stage2.2.net.4.bn.running_mean", "stage2.2.net.4.bn.running_var", "stage3.0.net.1.bn.weight", "stage3.0.net.1.bn.bias", "stage3.0.net.1.bn.running_mean", "stage3.0.net.1.bn.running_var", "stage3.1.net.1.bn.weight", "stage3.1.net.1.bn.bias", "stage3.1.net.1.bn.running_mean", "stage3.1.net.1.bn.running_var", "stage3.1.net.4.bn.weight", "stage3.1.net.4.bn.bias", "stage3.1.net.4.bn.running_mean", "stage3.1.net.4.bn.running_var", 
"stage3.1.downsample.1.bn.weight", "stage3.1.downsample.1.bn.bias", "stage3.1.downsample.1.bn.running_mean", "stage3.1.downsample.1.bn.running_var", "stage3.2.net.1.bn.weight", "stage3.2.net.1.bn.bias", "stage3.2.net.1.bn.running_mean", "stage3.2.net.1.bn.running_var", "stage3.2.net.4.bn.weight", "stage3.2.net.4.bn.bias", "stage3.2.net.4.bn.running_mean", "stage3.2.net.4.bn.running_var", "stage4.0.net.1.bn.weight", "stage4.0.net.1.bn.bias", "stage4.0.net.1.bn.running_mean", "stage4.0.net.1.bn.running_var", "stage4.1.net.1.bn.weight", "stage4.1.net.1.bn.bias", "stage4.1.net.1.bn.running_mean", "stage4.1.net.1.bn.running_var", "stage4.1.net.4.bn.weight", "stage4.1.net.4.bn.bias", "stage4.1.net.4.bn.running_mean", "stage4.1.net.4.bn.running_var", "stage4.1.downsample.1.bn.weight", "stage4.1.downsample.1.bn.bias", "stage4.1.downsample.1.bn.running_mean", "stage4.1.downsample.1.bn.running_var", "stage4.2.net.1.bn.weight", "stage4.2.net.1.bn.bias", "stage4.2.net.1.bn.running_mean", "stage4.2.net.1.bn.running_var", "stage4.2.net.4.bn.weight", "stage4.2.net.4.bn.bias", "stage4.2.net.4.bn.running_mean", "stage4.2.net.4.bn.running_var", "up1.0.net.1.bn.weight", "up1.0.net.1.bn.bias", "up1.0.net.1.bn.running_mean", "up1.0.net.1.bn.running_var", "up1.1.0.net.1.bn.weight", "up1.1.0.net.1.bn.bias", "up1.1.0.net.1.bn.running_mean", "up1.1.0.net.1.bn.running_var", "up1.1.0.net.4.bn.weight", "up1.1.0.net.4.bn.bias", "up1.1.0.net.4.bn.running_mean", "up1.1.0.net.4.bn.running_var", "up1.1.0.downsample.1.bn.weight", "up1.1.0.downsample.1.bn.bias", "up1.1.0.downsample.1.bn.running_mean", "up1.1.0.downsample.1.bn.running_var", "up1.1.1.net.1.bn.weight", "up1.1.1.net.1.bn.bias", "up1.1.1.net.1.bn.running_mean", "up1.1.1.net.1.bn.running_var", "up1.1.1.net.4.bn.weight", "up1.1.1.net.4.bn.bias", "up1.1.1.net.4.bn.running_mean", "up1.1.1.net.4.bn.running_var", "up2.0.net.1.bn.weight", "up2.0.net.1.bn.bias", "up2.0.net.1.bn.running_mean", "up2.0.net.1.bn.running_var", 
"up2.1.0.net.1.bn.weight", "up2.1.0.net.1.bn.bias", "up2.1.0.net.1.bn.running_mean", "up2.1.0.net.1.bn.running_var", "up2.1.0.net.4.bn.weight", "up2.1.0.net.4.bn.bias", "up2.1.0.net.4.bn.running_mean", "up2.1.0.net.4.bn.running_var", "up2.1.0.downsample.1.bn.weight", "up2.1.0.downsample.1.bn.bias", "up2.1.0.downsample.1.bn.running_mean", "up2.1.0.downsample.1.bn.running_var", "up2.1.1.net.1.bn.weight", "up2.1.1.net.1.bn.bias", "up2.1.1.net.1.bn.running_mean", "up2.1.1.net.1.bn.running_var", "up2.1.1.net.4.bn.weight", "up2.1.1.net.4.bn.bias", "up2.1.1.net.4.bn.running_mean", "up2.1.1.net.4.bn.running_var", "up3.0.net.1.bn.weight", "up3.0.net.1.bn.bias", "up3.0.net.1.bn.running_mean", "up3.0.net.1.bn.running_var", "up3.1.0.net.1.bn.weight", "up3.1.0.net.1.bn.bias", "up3.1.0.net.1.bn.running_mean", "up3.1.0.net.1.bn.running_var", "up3.1.0.net.4.bn.weight", "up3.1.0.net.4.bn.bias", "up3.1.0.net.4.bn.running_mean", "up3.1.0.net.4.bn.running_var", "up3.1.0.downsample.1.bn.weight", "up3.1.0.downsample.1.bn.bias", "up3.1.0.downsample.1.bn.running_mean", "up3.1.0.downsample.1.bn.running_var", "up3.1.1.net.1.bn.weight", "up3.1.1.net.1.bn.bias", "up3.1.1.net.1.bn.running_mean", "up3.1.1.net.1.bn.running_var", "up3.1.1.net.4.bn.weight", "up3.1.1.net.4.bn.bias", "up3.1.1.net.4.bn.running_mean", "up3.1.1.net.4.bn.running_var", "up4.0.net.1.bn.weight", "up4.0.net.1.bn.bias", "up4.0.net.1.bn.running_mean", "up4.0.net.1.bn.running_var", "up4.1.0.net.1.bn.weight", "up4.1.0.net.1.bn.bias", "up4.1.0.net.1.bn.running_mean", "up4.1.0.net.1.bn.running_var", "up4.1.0.net.4.bn.weight", "up4.1.0.net.4.bn.bias", "up4.1.0.net.4.bn.running_mean", "up4.1.0.net.4.bn.running_var", "up4.1.0.downsample.1.bn.weight", "up4.1.0.downsample.1.bn.bias", "up4.1.0.downsample.1.bn.running_mean", "up4.1.0.downsample.1.bn.running_var", "up4.1.1.net.1.bn.weight", "up4.1.1.net.1.bn.bias", "up4.1.1.net.1.bn.running_mean", "up4.1.1.net.1.bn.running_var", "up4.1.1.net.4.bn.weight", "up4.1.1.net.4.bn.bias", 
"up4.1.1.net.4.bn.running_mean", "up4.1.1.net.4.bn.running_var", "sem_head.weight", "sem_head.bias", "out_bnorm.0.weight", "out_bnorm.0.bias", "out_bnorm.0.running_mean", "out_bnorm.0.running_var", "out_bnorm.1.weight", "out_bnorm.1.bias", "out_bnorm.1.running_mean", "out_bnorm.1.running_var", "out_bnorm.2.weight", "out_bnorm.2.bias", "out_bnorm.2.running_mean", "out_bnorm.2.running_var", "out_bnorm.3.weight", "out_bnorm.3.bias", "out_bnorm.3.running_mean", "out_bnorm.3.running_var".
        Unexpected key(s) in state_dict: "stage1.0.net.1.weight", "stage1.0.net.1.bias", "stage1.0.net.1.running_mean", "stage1.0.net.1.running_var", "stage1.0.net.1.num_batches_tracked", "stage1.1.net.1.weight", "stage1.1.net.1.bias", "stage1.1.net.1.running_mean", "stage1.1.net.1.running_var", "stage1.1.net.1.num_batches_tracked", "stage1.1.net.4.weight", "stage1.1.net.4.bias", "stage1.1.net.4.running_mean", "stage1.1.net.4.running_var", "stage1.1.net.4.num_batches_tracked", "stage1.2.net.1.weight", "stage1.2.net.1.bias", "stage1.2.net.1.running_mean", "stage1.2.net.1.running_var", "stage1.2.net.1.num_batches_tracked", "stage1.2.net.4.weight", "stage1.2.net.4.bias", "stage1.2.net.4.running_mean", "stage1.2.net.4.running_var", "stage1.2.net.4.num_batches_tracked", "stage2.0.net.1.weight", "stage2.0.net.1.bias", "stage2.0.net.1.running_mean", "stage2.0.net.1.running_var", "stage2.0.net.1.num_batches_tracked", "stage2.1.net.1.weight", "stage2.1.net.1.bias", "stage2.1.net.1.running_mean", "stage2.1.net.1.running_var", "stage2.1.net.1.num_batches_tracked", "stage2.1.net.4.weight", "stage2.1.net.4.bias", "stage2.1.net.4.running_mean", "stage2.1.net.4.running_var", "stage2.1.net.4.num_batches_tracked", "stage2.1.downsample.1.weight", "stage2.1.downsample.1.bias", "stage2.1.downsample.1.running_mean", "stage2.1.downsample.1.running_var", "stage2.1.downsample.1.num_batches_tracked", "stage2.2.net.1.weight", "stage2.2.net.1.bias", "stage2.2.net.1.running_mean", "stage2.2.net.1.running_var", "stage2.2.net.1.num_batches_tracked", "stage2.2.net.4.weight", "stage2.2.net.4.bias", "stage2.2.net.4.running_mean", "stage2.2.net.4.running_var", "stage2.2.net.4.num_batches_tracked", "stage3.0.net.1.weight", "stage3.0.net.1.bias", "stage3.0.net.1.running_mean", "stage3.0.net.1.running_var", "stage3.0.net.1.num_batches_tracked", "stage3.1.net.1.weight", "stage3.1.net.1.bias", "stage3.1.net.1.running_mean", "stage3.1.net.1.running_var", "stage3.1.net.1.num_batches_tracked", 
"stage3.1.net.4.weight", "stage3.1.net.4.bias", "stage3.1.net.4.running_mean", "stage3.1.net.4.running_var", "stage3.1.net.4.num_batches_tracked", "stage3.1.downsample.1.weight", "stage3.1.downsample.1.bias", "stage3.1.downsample.1.running_mean", "stage3.1.downsample.1.running_var", "stage3.1.downsample.1.num_batches_tracked", "stage3.2.net.1.weight", "stage3.2.net.1.bias", "stage3.2.net.1.running_mean", "stage3.2.net.1.running_var", "stage3.2.net.1.num_batches_tracked", "stage3.2.net.4.weight", "stage3.2.net.4.bias", "stage3.2.net.4.running_mean", "stage3.2.net.4.running_var", "stage3.2.net.4.num_batches_tracked", "stage4.0.net.1.weight", "stage4.0.net.1.bias", "stage4.0.net.1.running_mean", "stage4.0.net.1.running_var", "stage4.0.net.1.num_batches_tracked", "stage4.1.net.1.weight", "stage4.1.net.1.bias", "stage4.1.net.1.running_mean", "stage4.1.net.1.running_var", "stage4.1.net.1.num_batches_tracked", "stage4.1.net.4.weight", "stage4.1.net.4.bias", "stage4.1.net.4.running_mean", "stage4.1.net.4.running_var", "stage4.1.net.4.num_batches_tracked", "stage4.1.downsample.1.weight", "stage4.1.downsample.1.bias", "stage4.1.downsample.1.running_mean", "stage4.1.downsample.1.running_var", "stage4.1.downsample.1.num_batches_tracked", "stage4.2.net.1.weight", "stage4.2.net.1.bias", "stage4.2.net.1.running_mean", "stage4.2.net.1.running_var", "stage4.2.net.1.num_batches_tracked", "stage4.2.net.4.weight", "stage4.2.net.4.bias", "stage4.2.net.4.running_mean", "stage4.2.net.4.running_var", "stage4.2.net.4.num_batches_tracked", "up1.0.net.1.weight", "up1.0.net.1.bias", "up1.0.net.1.running_mean", "up1.0.net.1.running_var", "up1.0.net.1.num_batches_tracked", "up1.1.0.net.1.weight", "up1.1.0.net.1.bias", "up1.1.0.net.1.running_mean", "up1.1.0.net.1.running_var", "up1.1.0.net.1.num_batches_tracked", "up1.1.0.net.4.weight", "up1.1.0.net.4.bias", "up1.1.0.net.4.running_mean", "up1.1.0.net.4.running_var", "up1.1.0.net.4.num_batches_tracked", "up1.1.0.downsample.1.weight", 
"up1.1.0.downsample.1.bias", "up1.1.0.downsample.1.running_mean", "up1.1.0.downsample.1.running_var", "up1.1.0.downsample.1.num_batches_tracked", "up1.1.1.net.1.weight", "up1.1.1.net.1.bias", "up1.1.1.net.1.running_mean", "up1.1.1.net.1.running_var", "up1.1.1.net.1.num_batches_tracked", "up1.1.1.net.4.weight", "up1.1.1.net.4.bias", "up1.1.1.net.4.running_mean", "up1.1.1.net.4.running_var", "up1.1.1.net.4.num_batches_tracked", "up2.0.net.1.weight", "up2.0.net.1.bias", "up2.0.net.1.running_mean", "up2.0.net.1.running_var", "up2.0.net.1.num_batches_tracked", "up2.1.0.net.1.weight", "up2.1.0.net.1.bias", "up2.1.0.net.1.running_mean", "up2.1.0.net.1.running_var", "up2.1.0.net.1.num_batches_tracked", "up2.1.0.net.4.weight", "up2.1.0.net.4.bias", "up2.1.0.net.4.running_mean", "up2.1.0.net.4.running_var", "up2.1.0.net.4.num_batches_tracked", "up2.1.0.downsample.1.weight", "up2.1.0.downsample.1.bias", "up2.1.0.downsample.1.running_mean", "up2.1.0.downsample.1.running_var", "up2.1.0.downsample.1.num_batches_tracked", "up2.1.1.net.1.weight", "up2.1.1.net.1.bias", "up2.1.1.net.1.running_mean", "up2.1.1.net.1.running_var", "up2.1.1.net.1.num_batches_tracked", "up2.1.1.net.4.weight", "up2.1.1.net.4.bias", "up2.1.1.net.4.running_mean", "up2.1.1.net.4.running_var", "up2.1.1.net.4.num_batches_tracked", "up3.0.net.1.weight", "up3.0.net.1.bias", "up3.0.net.1.running_mean", "up3.0.net.1.running_var", "up3.0.net.1.num_batches_tracked", "up3.1.0.net.1.weight", "up3.1.0.net.1.bias", "up3.1.0.net.1.running_mean", "up3.1.0.net.1.running_var", "up3.1.0.net.1.num_batches_tracked", "up3.1.0.net.4.weight", "up3.1.0.net.4.bias", "up3.1.0.net.4.running_mean", "up3.1.0.net.4.running_var", "up3.1.0.net.4.num_batches_tracked", "up3.1.0.downsample.1.weight", "up3.1.0.downsample.1.bias", "up3.1.0.downsample.1.running_mean", "up3.1.0.downsample.1.running_var", "up3.1.0.downsample.1.num_batches_tracked", "up3.1.1.net.1.weight", "up3.1.1.net.1.bias", "up3.1.1.net.1.running_mean", 
"up3.1.1.net.1.running_var", "up3.1.1.net.1.num_batches_tracked", "up3.1.1.net.4.weight", "up3.1.1.net.4.bias", "up3.1.1.net.4.running_mean", "up3.1.1.net.4.running_var", "up3.1.1.net.4.num_batches_tracked", "up4.0.net.1.weight", "up4.0.net.1.bias", "up4.0.net.1.running_mean", "up4.0.net.1.running_var", "up4.0.net.1.num_batches_tracked", "up4.1.0.net.1.weight", "up4.1.0.net.1.bias", "up4.1.0.net.1.running_mean", "up4.1.0.net.1.running_var", "up4.1.0.net.1.num_batches_tracked", "up4.1.0.net.4.weight", "up4.1.0.net.4.bias", "up4.1.0.net.4.running_mean", "up4.1.0.net.4.running_var", "up4.1.0.net.4.num_batches_tracked", "up4.1.0.downsample.1.weight", "up4.1.0.downsample.1.bias", "up4.1.0.downsample.1.running_mean", "up4.1.0.downsample.1.running_var", "up4.1.0.downsample.1.num_batches_tracked", "up4.1.1.net.1.weight", "up4.1.1.net.1.bias", "up4.1.1.net.1.running_mean", "up4.1.1.net.1.running_var", "up4.1.1.net.1.num_batches_tracked", "up4.1.1.net.4.weight", "up4.1.1.net.4.bias", "up4.1.1.net.4.running_mean", "up4.1.1.net.4.running_var", "up4.1.1.net.4.num_batches_tracked".

How do you load the pre-trained weights, and are the backbone weights frozen while training MaskPLS? Could you provide the relevant code so that we can reproduce the results? Thanks in advance!

Originally posted by @yuyang-cloud in https://github.com/PRBonn/MaskPLS/issues/7#issuecomment-1652813073

rmarcuzzi commented 1 year ago

Hi! Sorry, I pointed to the wrong repo. We built the MinkowskiNet following the spvnas repo but reimplemented it with MinkowskiEngine as done by segcontrast, and used their pretrained weights. Since they store the state_dict under the key "model" and they don't have a batchnorm for each intermediate feature level, you have to modify your code as follows (in the train_model.py script):

state_dict = torch.load("checkpoints/lastepoch199_model_segment_contrast.pt", map_location='cpu')
model.backbone.load_state_dict(state_dict["model"], strict=False)

If you remove the strict=False, you'll see that the checkpoint doesn't include the batchnorm layers.
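As a side note, `load_state_dict(..., strict=False)` returns the missing and unexpected keys, which is a convenient way to sanity-check how much of a checkpoint actually matched before training. A minimal sketch with a toy module standing in for the backbone (the layer shapes and the fake checkpoint are illustrative, not the real MinkUNet):

```python
import torch
import torch.nn as nn

# Toy stand-in for the backbone: a linear layer followed by a batchnorm.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))

# Pretend checkpoint that only covers the first layer, mimicking a
# checkpoint that lacks the batchnorm parameters (as in this issue).
ckpt = {"0.weight": torch.zeros(8, 4), "0.bias": torch.zeros(8)}

# strict=False loads whatever matches and reports what did not.
result = model.load_state_dict(ckpt, strict=False)
print("missing:", result.missing_keys)        # batchnorm params/buffers
print("unexpected:", result.unexpected_keys)  # checkpoint keys with no match
```

Printing both lists makes it obvious whether only the batchnorm layers are missing (expected here) or whether whole stages failed to match because of a key-name mismatch.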

Sorry again and I hope this helps!

yuyang-cloud commented 1 year ago

Thanks for the quick reply! I still have two questions:

  1. Are the weights you used the 100%-labels fine-tuned semantic segmentation weights from SegContrast? I downloaded them and got two checkpoints: epoch14_model_segment_contrast_1p0.pt and epoch14_model_head_segment_contrast_1p0.pt. Should I use only the first one to initialize the MinkUNet (without model.backbone.sem_head), or should I use both so that model.backbone.sem_head is initialized as well?

  2. Also, is the backbone frozen or unfrozen during MaskPLS training? If it is unfrozen, does the backbone use the same learning rate as the rest of the model?

Thanks again for your kind reply!

rmarcuzzi commented 1 year ago

Hi! 1) I used just the pre-trained weights with no fine-tuning. I think I tried both checkpoints and the difference was not so big. I only used the weights for the network itself, not for the semantic head.

2) The backbone is unfrozen, since we want the network to learn meaningful multi-level features. If you want to keep the backbone frozen (for example because your GPU is too small), you could also try using only the features of the last layer instead of the multi-level features.
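For anyone who does want to experiment with freezing the backbone or giving it its own learning rate, a common PyTorch pattern (a hedged sketch with toy modules, not the code used in the paper; the learning-rate values are placeholders) is to either switch off gradients or build separate optimizer parameter groups:

```python
import torch
import torch.nn as nn

# Toy model with a "backbone" and a "head", stand-ins for the
# MinkUNet backbone and the mask decoder.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

model = Model()

# Option A: freeze the backbone entirely (no gradients, no updates).
for p in model.backbone.parameters():
    p.requires_grad = False

# Option B: keep it trainable but with a smaller learning rate,
# via per-parameter-group options on the optimizer.
optimizer = torch.optim.AdamW([
    {"params": model.backbone.parameters(), "lr": 1e-5},
    {"params": model.head.parameters(), "lr": 1e-4},
])
```

The two options are alternatives: with option A the backbone group simply receives no updates, while option B lets the backbone adapt slowly without drifting far from the pre-trained weights.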

I hope this helps! I did the training quite some time ago so it might be that I don't remember everything exactly but that's the main idea that I followed.

yuyang-cloud commented 1 year ago

Ok! Thanks for your reply; I will give it a try and report my results here later for further discussion.