PRBonn / MaskPLS

Mask-Based Panoptic LiDAR Segmentation for Autonomous Driving, RA-L, 2023
MIT License

After 100 epochs, the model still cannot achieve the metrics of pretrained model #7

Closed comradexy closed 1 year ago

comradexy commented 1 year ago

Hi, I used your default config to train the model. After about 40 epochs, overfitting began to occur. After 100 epochs, the best result still does not reach the metrics of the pretrained model; PQ is lower by about 4 percentage points. Are there any settings that I need to change?

My trained model results:

DATALOADER:0 VALIDATE RESULTS {'metrics/iou': 0.5610383749008179, 'metrics/pq': 0.5602813959121704, 'metrics/rq': 0.6541024446487427}

Pretrained model results:

DATALOADER:0 VALIDATE RESULTS {'metrics/iou': 0.6184651255607605, 'metrics/pq': 0.5990720987319946, 'metrics/rq': 0.6913090944290161}

rmarcuzzi commented 1 year ago

Hi! We leveraged the weights of the feature extractor provided here.

yuyang-cloud commented 1 year ago

Hi! We leveraged the weights of the feature extractor provided here.

Hi! @rmarcuzzi @comradexy I downloaded the backbone weights at SemanticKITTI_val_MinkUNet@114GMACs and used the code below to load the pre-trained weights:

state_dict = torch.load('../ckpts/backbone.pth', map_location='cpu')
self.backbone.load_state_dict(state_dict)

But loading fails with missing and unexpected keys:

*** RuntimeError: Error(s) in loading state_dict for MinkEncoderDecoder:
        Missing key(s) in state_dict: "stem.0.kernel", "stem.1.bn.weight", "stem.1.bn.bias", "stem.1.bn.running_mean", "stem.1.bn.running_var", "stem.3.kernel", "stem.4.bn.weight", "stem.4.bn.bias", "stem.4.bn.running_mean", "stem.4.bn.running_var", "stage1.0.net.1.bn.weight", "stage1.0.net.1.bn.bias", "stage1.0.net.1.bn.running_mean", "stage1.0.net.1.bn.running_var", "stage1.1.net.1.bn.weight", "stage1.1.net.1.bn.bias", "stage1.1.net.1.bn.running_mean", "stage1.1.net.1.bn.running_var", "stage1.1.net.4.bn.weight", "stage1.1.net.4.bn.bias", "stage1.1.net.4.bn.running_mean", "stage1.1.net.4.bn.running_var", "stage1.2.net.1.bn.weight", "stage1.2.net.1.bn.bias", "stage1.2.net.1.bn.running_mean", "stage1.2.net.1.bn.running_var", "stage1.2.net.4.bn.weight", "stage1.2.net.4.bn.bias", "stage1.2.net.4.bn.running_mean", "stage1.2.net.4.bn.running_var", "stage2.0.net.1.bn.weight", "stage2.0.net.1.bn.bias", "stage2.0.net.1.bn.running_mean", "stage2.0.net.1.bn.running_var", "stage2.1.net.1.bn.weight", "stage2.1.net.1.bn.bias", "stage2.1.net.1.bn.running_mean", "stage2.1.net.1.bn.running_var", "stage2.1.net.4.bn.weight", "stage2.1.net.4.bn.bias", "stage2.1.net.4.bn.running_mean", "stage2.1.net.4.bn.running_var", "stage2.1.downsample.1.bn.weight", "stage2.1.downsample.1.bn.bias", "stage2.1.downsample.1.bn.running_mean", "stage2.1.downsample.1.bn.running_var", "stage2.2.net.1.bn.weight", "stage2.2.net.1.bn.bias", "stage2.2.net.1.bn.running_mean", "stage2.2.net.1.bn.running_var", "stage2.2.net.4.bn.weight", "stage2.2.net.4.bn.bias", "stage2.2.net.4.bn.running_mean", "stage2.2.net.4.bn.running_var", "stage3.0.net.1.bn.weight", "stage3.0.net.1.bn.bias", "stage3.0.net.1.bn.running_mean", "stage3.0.net.1.bn.running_var", "stage3.1.net.1.bn.weight", "stage3.1.net.1.bn.bias", "stage3.1.net.1.bn.running_mean", "stage3.1.net.1.bn.running_var", "stage3.1.net.4.bn.weight", "stage3.1.net.4.bn.bias", "stage3.1.net.4.bn.running_mean", "stage3.1.net.4.bn.running_var", 
"stage3.1.downsample.1.bn.weight", "stage3.1.downsample.1.bn.bias", "stage3.1.downsample.1.bn.running_mean", "stage3.1.downsample.1.bn.running_var", "stage3.2.net.1.bn.weight", "stage3.2.net.1.bn.bias", "stage3.2.net.1.bn.running_mean", "stage3.2.net.1.bn.running_var", "stage3.2.net.4.bn.weight", "stage3.2.net.4.bn.bias", "stage3.2.net.4.bn.running_mean", "stage3.2.net.4.bn.running_var", "stage4.0.net.1.bn.weight", "stage4.0.net.1.bn.bias", "stage4.0.net.1.bn.running_mean", "stage4.0.net.1.bn.running_var", "stage4.1.net.1.bn.weight", "stage4.1.net.1.bn.bias", "stage4.1.net.1.bn.running_mean", "stage4.1.net.1.bn.running_var", "stage4.1.net.4.bn.weight", "stage4.1.net.4.bn.bias", "stage4.1.net.4.bn.running_mean", "stage4.1.net.4.bn.running_var", "stage4.1.downsample.1.bn.weight", "stage4.1.downsample.1.bn.bias", "stage4.1.downsample.1.bn.running_mean", "stage4.1.downsample.1.bn.running_var", "stage4.2.net.1.bn.weight", "stage4.2.net.1.bn.bias", "stage4.2.net.1.bn.running_mean", "stage4.2.net.1.bn.running_var", "stage4.2.net.4.bn.weight", "stage4.2.net.4.bn.bias", "stage4.2.net.4.bn.running_mean", "stage4.2.net.4.bn.running_var", "up1.0.net.1.bn.weight", "up1.0.net.1.bn.bias", "up1.0.net.1.bn.running_mean", "up1.0.net.1.bn.running_var", "up1.1.0.net.1.bn.weight", "up1.1.0.net.1.bn.bias", "up1.1.0.net.1.bn.running_mean", "up1.1.0.net.1.bn.running_var", "up1.1.0.net.4.bn.weight", "up1.1.0.net.4.bn.bias", "up1.1.0.net.4.bn.running_mean", "up1.1.0.net.4.bn.running_var", "up1.1.0.downsample.1.bn.weight", "up1.1.0.downsample.1.bn.bias", "up1.1.0.downsample.1.bn.running_mean", "up1.1.0.downsample.1.bn.running_var", "up1.1.1.net.1.bn.weight", "up1.1.1.net.1.bn.bias", "up1.1.1.net.1.bn.running_mean", "up1.1.1.net.1.bn.running_var", "up1.1.1.net.4.bn.weight", "up1.1.1.net.4.bn.bias", "up1.1.1.net.4.bn.running_mean", "up1.1.1.net.4.bn.running_var", "up2.0.net.1.bn.weight", "up2.0.net.1.bn.bias", "up2.0.net.1.bn.running_mean", "up2.0.net.1.bn.running_var", 
"up2.1.0.net.1.bn.weight", "up2.1.0.net.1.bn.bias", "up2.1.0.net.1.bn.running_mean", "up2.1.0.net.1.bn.running_var", "up2.1.0.net.4.bn.weight", "up2.1.0.net.4.bn.bias", "up2.1.0.net.4.bn.running_mean", "up2.1.0.net.4.bn.running_var", "up2.1.0.downsample.1.bn.weight", "up2.1.0.downsample.1.bn.bias", "up2.1.0.downsample.1.bn.running_mean", "up2.1.0.downsample.1.bn.running_var", "up2.1.1.net.1.bn.weight", "up2.1.1.net.1.bn.bias", "up2.1.1.net.1.bn.running_mean", "up2.1.1.net.1.bn.running_var", "up2.1.1.net.4.bn.weight", "up2.1.1.net.4.bn.bias", "up2.1.1.net.4.bn.running_mean", "up2.1.1.net.4.bn.running_var", "up3.0.net.1.bn.weight", "up3.0.net.1.bn.bias", "up3.0.net.1.bn.running_mean", "up3.0.net.1.bn.running_var", "up3.1.0.net.1.bn.weight", "up3.1.0.net.1.bn.bias", "up3.1.0.net.1.bn.running_mean", "up3.1.0.net.1.bn.running_var", "up3.1.0.net.4.bn.weight", "up3.1.0.net.4.bn.bias", "up3.1.0.net.4.bn.running_mean", "up3.1.0.net.4.bn.running_var", "up3.1.0.downsample.1.bn.weight", "up3.1.0.downsample.1.bn.bias", "up3.1.0.downsample.1.bn.running_mean", "up3.1.0.downsample.1.bn.running_var", "up3.1.1.net.1.bn.weight", "up3.1.1.net.1.bn.bias", "up3.1.1.net.1.bn.running_mean", "up3.1.1.net.1.bn.running_var", "up3.1.1.net.4.bn.weight", "up3.1.1.net.4.bn.bias", "up3.1.1.net.4.bn.running_mean", "up3.1.1.net.4.bn.running_var", "up4.0.net.1.bn.weight", "up4.0.net.1.bn.bias", "up4.0.net.1.bn.running_mean", "up4.0.net.1.bn.running_var", "up4.1.0.net.1.bn.weight", "up4.1.0.net.1.bn.bias", "up4.1.0.net.1.bn.running_mean", "up4.1.0.net.1.bn.running_var", "up4.1.0.net.4.bn.weight", "up4.1.0.net.4.bn.bias", "up4.1.0.net.4.bn.running_mean", "up4.1.0.net.4.bn.running_var", "up4.1.0.downsample.1.bn.weight", "up4.1.0.downsample.1.bn.bias", "up4.1.0.downsample.1.bn.running_mean", "up4.1.0.downsample.1.bn.running_var", "up4.1.1.net.1.bn.weight", "up4.1.1.net.1.bn.bias", "up4.1.1.net.1.bn.running_mean", "up4.1.1.net.1.bn.running_var", "up4.1.1.net.4.bn.weight", "up4.1.1.net.4.bn.bias", 
"up4.1.1.net.4.bn.running_mean", "up4.1.1.net.4.bn.running_var", "sem_head.weight", "sem_head.bias", "out_bnorm.0.weight", "out_bnorm.0.bias", "out_bnorm.0.running_mean", "out_bnorm.0.running_var", "out_bnorm.1.weight", "out_bnorm.1.bias", "out_bnorm.1.running_mean", "out_bnorm.1.running_var", "out_bnorm.2.weight", "out_bnorm.2.bias", "out_bnorm.2.running_mean", "out_bnorm.2.running_var", "out_bnorm.3.weight", "out_bnorm.3.bias", "out_bnorm.3.running_mean", "out_bnorm.3.running_var".
        Unexpected key(s) in state_dict: "stage1.0.net.1.weight", "stage1.0.net.1.bias", "stage1.0.net.1.running_mean", "stage1.0.net.1.running_var", "stage1.0.net.1.num_batches_tracked", "stage1.1.net.1.weight", "stage1.1.net.1.bias", "stage1.1.net.1.running_mean", "stage1.1.net.1.running_var", "stage1.1.net.1.num_batches_tracked", "stage1.1.net.4.weight", "stage1.1.net.4.bias", "stage1.1.net.4.running_mean", "stage1.1.net.4.running_var", "stage1.1.net.4.num_batches_tracked", "stage1.2.net.1.weight", "stage1.2.net.1.bias", "stage1.2.net.1.running_mean", "stage1.2.net.1.running_var", "stage1.2.net.1.num_batches_tracked", "stage1.2.net.4.weight", "stage1.2.net.4.bias", "stage1.2.net.4.running_mean", "stage1.2.net.4.running_var", "stage1.2.net.4.num_batches_tracked", "stage2.0.net.1.weight", "stage2.0.net.1.bias", "stage2.0.net.1.running_mean", "stage2.0.net.1.running_var", "stage2.0.net.1.num_batches_tracked", "stage2.1.net.1.weight", "stage2.1.net.1.bias", "stage2.1.net.1.running_mean", "stage2.1.net.1.running_var", "stage2.1.net.1.num_batches_tracked", "stage2.1.net.4.weight", "stage2.1.net.4.bias", "stage2.1.net.4.running_mean", "stage2.1.net.4.running_var", "stage2.1.net.4.num_batches_tracked", "stage2.1.downsample.1.weight", "stage2.1.downsample.1.bias", "stage2.1.downsample.1.running_mean", "stage2.1.downsample.1.running_var", "stage2.1.downsample.1.num_batches_tracked", "stage2.2.net.1.weight", "stage2.2.net.1.bias", "stage2.2.net.1.running_mean", "stage2.2.net.1.running_var", "stage2.2.net.1.num_batches_tracked", "stage2.2.net.4.weight", "stage2.2.net.4.bias", "stage2.2.net.4.running_mean", "stage2.2.net.4.running_var", "stage2.2.net.4.num_batches_tracked", "stage3.0.net.1.weight", "stage3.0.net.1.bias", "stage3.0.net.1.running_mean", "stage3.0.net.1.running_var", "stage3.0.net.1.num_batches_tracked", "stage3.1.net.1.weight", "stage3.1.net.1.bias", "stage3.1.net.1.running_mean", "stage3.1.net.1.running_var", "stage3.1.net.1.num_batches_tracked", 
"stage3.1.net.4.weight", "stage3.1.net.4.bias", "stage3.1.net.4.running_mean", "stage3.1.net.4.running_var", "stage3.1.net.4.num_batches_tracked", "stage3.1.downsample.1.weight", "stage3.1.downsample.1.bias", "stage3.1.downsample.1.running_mean", "stage3.1.downsample.1.running_var", "stage3.1.downsample.1.num_batches_tracked", "stage3.2.net.1.weight", "stage3.2.net.1.bias", "stage3.2.net.1.running_mean", "stage3.2.net.1.running_var", "stage3.2.net.1.num_batches_tracked", "stage3.2.net.4.weight", "stage3.2.net.4.bias", "stage3.2.net.4.running_mean", "stage3.2.net.4.running_var", "stage3.2.net.4.num_batches_tracked", "stage4.0.net.1.weight", "stage4.0.net.1.bias", "stage4.0.net.1.running_mean", "stage4.0.net.1.running_var", "stage4.0.net.1.num_batches_tracked", "stage4.1.net.1.weight", "stage4.1.net.1.bias", "stage4.1.net.1.running_mean", "stage4.1.net.1.running_var", "stage4.1.net.1.num_batches_tracked", "stage4.1.net.4.weight", "stage4.1.net.4.bias", "stage4.1.net.4.running_mean", "stage4.1.net.4.running_var", "stage4.1.net.4.num_batches_tracked", "stage4.1.downsample.1.weight", "stage4.1.downsample.1.bias", "stage4.1.downsample.1.running_mean", "stage4.1.downsample.1.running_var", "stage4.1.downsample.1.num_batches_tracked", "stage4.2.net.1.weight", "stage4.2.net.1.bias", "stage4.2.net.1.running_mean", "stage4.2.net.1.running_var", "stage4.2.net.1.num_batches_tracked", "stage4.2.net.4.weight", "stage4.2.net.4.bias", "stage4.2.net.4.running_mean", "stage4.2.net.4.running_var", "stage4.2.net.4.num_batches_tracked", "up1.0.net.1.weight", "up1.0.net.1.bias", "up1.0.net.1.running_mean", "up1.0.net.1.running_var", "up1.0.net.1.num_batches_tracked", "up1.1.0.net.1.weight", "up1.1.0.net.1.bias", "up1.1.0.net.1.running_mean", "up1.1.0.net.1.running_var", "up1.1.0.net.1.num_batches_tracked", "up1.1.0.net.4.weight", "up1.1.0.net.4.bias", "up1.1.0.net.4.running_mean", "up1.1.0.net.4.running_var", "up1.1.0.net.4.num_batches_tracked", "up1.1.0.downsample.1.weight", 
"up1.1.0.downsample.1.bias", "up1.1.0.downsample.1.running_mean", "up1.1.0.downsample.1.running_var", "up1.1.0.downsample.1.num_batches_tracked", "up1.1.1.net.1.weight", "up1.1.1.net.1.bias", "up1.1.1.net.1.running_mean", "up1.1.1.net.1.running_var", "up1.1.1.net.1.num_batches_tracked", "up1.1.1.net.4.weight", "up1.1.1.net.4.bias", "up1.1.1.net.4.running_mean", "up1.1.1.net.4.running_var", "up1.1.1.net.4.num_batches_tracked", "up2.0.net.1.weight", "up2.0.net.1.bias", "up2.0.net.1.running_mean", "up2.0.net.1.running_var", "up2.0.net.1.num_batches_tracked", "up2.1.0.net.1.weight", "up2.1.0.net.1.bias", "up2.1.0.net.1.running_mean", "up2.1.0.net.1.running_var", "up2.1.0.net.1.num_batches_tracked", "up2.1.0.net.4.weight", "up2.1.0.net.4.bias", "up2.1.0.net.4.running_mean", "up2.1.0.net.4.running_var", "up2.1.0.net.4.num_batches_tracked", "up2.1.0.downsample.1.weight", "up2.1.0.downsample.1.bias", "up2.1.0.downsample.1.running_mean", "up2.1.0.downsample.1.running_var", "up2.1.0.downsample.1.num_batches_tracked", "up2.1.1.net.1.weight", "up2.1.1.net.1.bias", "up2.1.1.net.1.running_mean", "up2.1.1.net.1.running_var", "up2.1.1.net.1.num_batches_tracked", "up2.1.1.net.4.weight", "up2.1.1.net.4.bias", "up2.1.1.net.4.running_mean", "up2.1.1.net.4.running_var", "up2.1.1.net.4.num_batches_tracked", "up3.0.net.1.weight", "up3.0.net.1.bias", "up3.0.net.1.running_mean", "up3.0.net.1.running_var", "up3.0.net.1.num_batches_tracked", "up3.1.0.net.1.weight", "up3.1.0.net.1.bias", "up3.1.0.net.1.running_mean", "up3.1.0.net.1.running_var", "up3.1.0.net.1.num_batches_tracked", "up3.1.0.net.4.weight", "up3.1.0.net.4.bias", "up3.1.0.net.4.running_mean", "up3.1.0.net.4.running_var", "up3.1.0.net.4.num_batches_tracked", "up3.1.0.downsample.1.weight", "up3.1.0.downsample.1.bias", "up3.1.0.downsample.1.running_mean", "up3.1.0.downsample.1.running_var", "up3.1.0.downsample.1.num_batches_tracked", "up3.1.1.net.1.weight", "up3.1.1.net.1.bias", "up3.1.1.net.1.running_mean", 
"up3.1.1.net.1.running_var", "up3.1.1.net.1.num_batches_tracked", "up3.1.1.net.4.weight", "up3.1.1.net.4.bias", "up3.1.1.net.4.running_mean", "up3.1.1.net.4.running_var", "up3.1.1.net.4.num_batches_tracked", "up4.0.net.1.weight", "up4.0.net.1.bias", "up4.0.net.1.running_mean", "up4.0.net.1.running_var", "up4.0.net.1.num_batches_tracked", "up4.1.0.net.1.weight", "up4.1.0.net.1.bias", "up4.1.0.net.1.running_mean", "up4.1.0.net.1.running_var", "up4.1.0.net.1.num_batches_tracked", "up4.1.0.net.4.weight", "up4.1.0.net.4.bias", "up4.1.0.net.4.running_mean", "up4.1.0.net.4.running_var", "up4.1.0.net.4.num_batches_tracked", "up4.1.0.downsample.1.weight", "up4.1.0.downsample.1.bias", "up4.1.0.downsample.1.running_mean", "up4.1.0.downsample.1.running_var", "up4.1.0.downsample.1.num_batches_tracked", "up4.1.1.net.1.weight", "up4.1.1.net.1.bias", "up4.1.1.net.1.running_mean", "up4.1.1.net.1.running_var", "up4.1.1.net.1.num_batches_tracked", "up4.1.1.net.4.weight", "up4.1.1.net.4.bias", "up4.1.1.net.4.running_mean", "up4.1.1.net.4.running_var", "up4.1.1.net.4.num_batches_tracked".

How do you load the pre-trained weights, and are the backbone weights frozen while training MaskPLS? Could you share the relevant code so that we can reproduce the results? Thanks in advance!

yuyang-cloud commented 1 year ago

I noticed that the MinkUNet in spvnas uses torchsparse.nn.Conv3d, while MaskPLS uses MinkowskiEngine.MinkowskiConvolution. How can the checkpoints provided in spvnas be used with MaskPLS?
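One pattern visible in the error message is that the checkpoint stores batch-norm tensors as `<prefix>.weight` while the model expects `<prefix>.bn.weight`. A minimal sketch of a key remap under that assumption follows. The helper name `remap_bn_keys`, the sample keys, and the `.bn` rule itself are inferred from the log, not taken from the MaskPLS or spvnas code, and matching names does not guarantee that tensor shapes or kernel layouts are compatible between torchsparse and MinkowskiEngine.

```python
# Hypothetical helper: the ".bn" renaming rule is inferred from the error
# message in this thread, not from the MaskPLS or spvnas source code.
BN_SUFFIXES = ("weight", "bias", "running_mean", "running_var", "num_batches_tracked")


def remap_bn_keys(ckpt_state, model_keys):
    """Rename '<prefix>.<suffix>' to '<prefix>.bn.<suffix>' when the model expects it.

    Returns (remapped, leftover, still_missing):
      remapped      - dict keyed by names the model knows
      leftover      - checkpoint keys that could not be placed
      still_missing - model keys with no checkpoint entry
    """
    model_keys = set(model_keys)
    remapped, leftover = {}, []
    for key, value in ckpt_state.items():
        prefix, _, suffix = key.rpartition(".")
        candidate = f"{prefix}.bn.{suffix}"
        if key in model_keys:
            remapped[key] = value          # name already matches
        elif suffix in BN_SUFFIXES and candidate in model_keys:
            remapped[candidate] = value    # batch-norm tensor, wrapped under ".bn"
        else:
            leftover.append(key)
    still_missing = sorted(model_keys - set(remapped))
    return remapped, leftover, still_missing


# Illustrative keys only; strings stand in for tensors.
model_keys = [
    "stem.0.kernel",
    "stage1.0.net.0.kernel",
    "stage1.0.net.1.bn.weight",
    "stage1.0.net.1.bn.bias",
]
ckpt = {
    "stage1.0.net.0.kernel": "K",
    "stage1.0.net.1.weight": "W",
    "stage1.0.net.1.bias": "B",
    "stage1.0.net.1.num_batches_tracked": 0,
}
remapped, leftover, still_missing = remap_bn_keys(ckpt, model_keys)
print(sorted(remapped), leftover, still_missing)
```

Keys that stay in `still_missing` (in the log above: the `stem.*`, `sem_head.*` and `out_bnorm.*` entries) have no counterpart in the checkpoint and would keep their initial values.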

malicd commented 1 year ago

Hi, training initialized with the official weights from segcontrast (as per this comment) yields the following result:

{'metrics/iou': 0.52241450548172,
 'metrics/pq': 0.5092883110046387,
 'metrics/rq': 0.5975804924964905}
--------------------------------------------------------------------------------
Evaluated 4071 frames. Duplicated frame number: 0
|        |   PQ   |   RQ   |   SQ   |  IoU   |
|all     | 0.5093 | 0.5976 | 0.7055 | 0.5224 |
|nlabeled| 0.0000 | 0.0000 | 0.0000 | 0.0000 |
|car     | 0.8954 | 0.9568 | 0.9358 | 0.9134 |
|bicycle | 0.0000 | 0.0000 | 0.0000 | 0.0002 |
|torcycle| 0.4794 | 0.5087 | 0.9424 | 0.4525 |
|truck   | 0.6470 | 0.6736 | 0.9605 | 0.4909 |
|-vehicle| 0.4458 | 0.4713 | 0.9458 | 0.4570 |
|person  | 0.5407 | 0.6155 | 0.8784 | 0.5120 |
|icyclist| 0.8450 | 0.9072 | 0.9314 | 0.8592 |
|rcyclist| 0.0000 | 0.0000 | 0.0000 | 0.0000 |
|road    | 0.9277 | 0.9998 | 0.9279 | 0.9261 |
|parking | 0.1905 | 0.2750 | 0.6927 | 0.1760 |
|sidewalk| 0.7538 | 0.9147 | 0.8240 | 0.7862 |
|r-ground| 0.0000 | 0.0000 | 0.0000 | 0.0000 |
|building| 0.8368 | 0.9397 | 0.8905 | 0.8572 |
|fence   | 0.1621 | 0.2419 | 0.6703 | 0.2753 |
|getation| 0.8349 | 0.9828 | 0.8495 | 0.8542 |
|trunk   | 0.4716 | 0.6391 | 0.7379 | 0.6130 |
|terrain | 0.5264 | 0.7333 | 0.7178 | 0.6893 |
|pole    | 0.5985 | 0.8111 | 0.7379 | 0.6442 |
|fic-sign| 0.5210 | 0.6835 | 0.7623 | 0.4193 |
pq_mean:        0.5092882834782245
pq_dagger:      0.5312667050067429
sq_mean:        0.7055340863309376
rq_mean:        0.5975804851373808
iou_mean:       0.5224144857864191
pq_stuff:       0.5293838666543684
rq_stuff:       0.6564395114692565
sq_stuff:       0.710073790027907
pq_things:      0.48165685661102653
rq_things:      0.5166493239310516
sq_things:      0.6992919937476048
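As a reading aid for the table above: per class, PQ is the product of SQ and RQ, while the "all" row averages each column over classes, so its PQ is not the product of its SQ and RQ. The short check below uses a few rows copied from the table; the function name `pq_consistent` and the tolerance are assumptions made for illustration.

```python
# Rows copied from the table above as (PQ, RQ, SQ).
rows = {
    "car": (0.8954, 0.9568, 0.9358),
    "truck": (0.6470, 0.6736, 0.9605),
    "person": (0.5407, 0.6155, 0.8784),
    "road": (0.9277, 0.9998, 0.9279),
}


def pq_consistent(pq, rq, sq, tol=1e-3):
    """True if PQ equals SQ * RQ up to the rounding of the printed values."""
    return abs(pq - rq * sq) <= tol


per_class_ok = {name: pq_consistent(*vals) for name, vals in rows.items()}

# The "all" row holds column means, so the product rule does not apply to it.
all_row_ok = pq_consistent(0.5093, 0.5976, 0.7055)

print(per_class_ok, all_row_ok)
```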

This is the checkpoint with the highest validation PQ (epoch 44 in this case). The latest checkpoint performs worse than this. My environment matches requirements.txt exactly.

@rmarcuzzi Is there a way to find out what could be wrong here? Please let me know if you need more information.

rmarcuzzi commented 1 year ago

Hi! That's unexpected, actually. As far as I remember, we tried both training from scratch and loading the pre-trained weights, and the latter gave the best results. We also tried the weights of a backbone trained for semantic segmentation, but the results were no better than with the SegContrast weights. I don't know what else you could try; that is the setup I remember us using to get the results.

If you want to modify the model and train it, you could use the provided pre-trained weights and just set strict=False when loading the weights in train_model.py:

model.load_state_dict(w["state_dict"], strict=False)

This way, the weights will be loaded for the parts of the network that remain the same.
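One caveat worth checking before relying on `strict=False`: PyTorch then ignores missing and unexpected keys, but it still raises an error when a key exists on both sides with a different tensor shape. The sketch below partitions checkpoint entries accordingly before loading. It works on plain name-to-shape mappings so it runs without PyTorch; with real tensors the shapes could be collected as `{k: tuple(v.shape) for k, v in state_dict.items()}`. The helper name `split_loadable` and all parameter names and shapes in the example are hypothetical.

```python
def split_loadable(ckpt_shapes, model_shapes):
    """Partition checkpoint entries by whether name and shape both match the model.

    Both arguments map parameter names to shape tuples.
    Returns (loadable, shape_mismatch, unexpected, missing) as lists of names.
    """
    loadable, shape_mismatch, unexpected = [], [], []
    for name, shape in ckpt_shapes.items():
        if name not in model_shapes:
            unexpected.append(name)        # ignored when strict=False
        elif tuple(shape) != tuple(model_shapes[name]):
            shape_mismatch.append(name)    # still an error when strict=False
        else:
            loadable.append(name)
    missing = [n for n in model_shapes if n not in ckpt_shapes]
    return loadable, shape_mismatch, unexpected, missing


# Hypothetical names and shapes, for illustration only.
model_shapes = {
    "backbone.sem_head.weight": (20, 96),
    "decoder.query.weight": (100, 256),
    "new_head.weight": (8, 256),
}
ckpt_shapes = {
    "backbone.sem_head.weight": (20, 96),
    "decoder.query.weight": (50, 256),
    "old_head.weight": (4, 256),
}
loadable, shape_mismatch, unexpected, missing = split_loadable(ckpt_shapes, model_shapes)
print(loadable, shape_mismatch, unexpected, missing)
```

Dropping the `shape_mismatch` entries from the checkpoint dictionary before calling `load_state_dict(..., strict=False)` is one way to let the remaining weights load.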

I hope that helps you and sorry if my instructions didn't provide the expected results but that's what I remember we did.