open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0

demo/MMSegmentation_Tutorial.ipynb train error #3778

Open zhuxuanya opened 1 month ago

zhuxuanya commented 1 month ago

When I run runner.train() in demo/MMSegmentation_Tutorial.ipynb, the following error occurs:

09/24 13:50:28 - mmengine - WARNING - The prefix is not set in metric class IoUMetric.
09/24 13:50:28 - mmengine - INFO - load model from: open-mmlab://resnet50_v1c
09/24 13:50:28 - mmengine - INFO - Loads checkpoint by openmmlab backend from path: open-mmlab://resnet50_v1c
09/24 13:50:29 - mmengine - WARNING - The model and loaded state dict do not match exactly

unexpected key in source state_dict: fc.weight, fc.bias

Loads checkpoint by local backend from path: pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth
The model and loaded state dict do not match exactly

size mismatch for decode_head.conv_seg.weight: copying a param with shape torch.Size([19, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([8, 512, 1, 1]).
size mismatch for decode_head.conv_seg.bias: copying a param with shape torch.Size([19]) from checkpoint, the shape in current model is torch.Size([8]).
size mismatch for auxiliary_head.conv_seg.weight: copying a param with shape torch.Size([19, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([8, 256, 1, 1]).
size mismatch for auxiliary_head.conv_seg.bias: copying a param with shape torch.Size([19]) from checkpoint, the shape in current model is torch.Size([8]).
09/24 13:50:29 - mmengine - INFO - Load checkpoint from pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth
09/24 13:50:29 - mmengine - WARNING - "FileClient" will be deprecated in future. Please use io functions in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
09/24 13:50:29 - mmengine - WARNING - "HardDiskBackend" is the alias of "LocalBackend" and the former will be deprecated in future.
09/24 13:50:29 - mmengine - INFO - Checkpoints will be saved to demo/work_dirs/tutorial.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:73: nll_loss2d_forward_no_reduce_kernel: block: [295,0,0], thread: [259,0,0] Assertion `cur_target >= 0 && cur_target < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:73: nll_loss2d_forward_no_reduce_kernel: block: [68,0,0], thread: [631,0,0] Assertion `cur_target >= 0 && cur_target < n_classes` failed.

It seems that the number of classes in the model doesn't match the loaded checkpoint. I notice that the config is modified in the tutorial's code block with:

# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 8
cfg.model.auxiliary_head.num_classes = 8

This modification means the checkpoint's state dict no longer matches the config exactly. I then used the following code to check the names and parameter shapes in the checkpoint:

import torch

pth = torch.load('pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth', map_location='cpu')
state_dict = pth['state_dict']
for name, param in state_dict.items():
    print(name, param.size())

The relevant output:

decode_head.conv_seg.weight torch.Size([19, 512, 1, 1])
decode_head.conv_seg.bias torch.Size([19])
...
auxiliary_head.conv_seg.weight torch.Size([19, 256, 1, 1])
auxiliary_head.conv_seg.bias torch.Size([19])

I don't know whether my understanding is correct. If it is, this may trip up people who use MMSegmentation_Tutorial.ipynb, even though it doesn't affect the training framework itself.
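For anyone who wants the warning to go away, one option is to drop the num_classes-dependent keys from the checkpoint before loading, so only the shape-compatible weights are transferred and the 8-class heads train from random init. This is just a sketch; the helper name and the key prefixes are my own (taken from the shapes printed above), not an official MMSegmentation API:

```python
def strip_mismatched_keys(state_dict,
                          prefixes=("decode_head.conv_seg",
                                    "auxiliary_head.conv_seg")):
    """Return a copy of state_dict without the keys whose shape
    depends on num_classes (the final classification convs)."""
    return {k: v for k, v in state_dict.items()
            if not k.startswith(prefixes)}
```

After filtering, you could save the result back with torch.save({'state_dict': filtered}, 'filtered.pth') and point load_from at the new file; the size-mismatch warnings should then disappear.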

FreeWillThorn commented 2 weeks ago

so do you know how to fix this bug? I am stuck in exactly the same situation

zhuxuanya commented 2 weeks ago

> so do you know how to fix this bug? I am stuck in exactly the same situation

I think this can be ignored; there's no need to go to the effort of modifying the model.

MariaMel98 commented 4 days ago

I had the same problem and noticed that some segmentation maps had pixel values outside 0-7, which are the class indexes. So I changed every out-of-range value to 0 (I know that's not technically correct) and it worked fine.
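In case it helps, here is roughly how such masks can be checked and cleaned with NumPy. It's a rough sketch with assumed defaults (8 classes, 255 as the ignore index, which I believe is MMSegmentation's default); adapt it to your dataset:

```python
import numpy as np

def find_invalid_labels(mask, num_classes=8, ignore_index=255):
    """Return the set of pixel values that are neither a valid
    class index (0..num_classes-1) nor the ignore_index."""
    values = set(np.unique(np.asarray(mask)).tolist())
    return values - (set(range(num_classes)) | {ignore_index})

def clamp_invalid_labels(mask, num_classes=8, fill=0, ignore_index=255):
    """Replace every out-of-range pixel value with `fill`
    (setting them to 0 is the workaround described above)."""
    mask = np.asarray(mask)
    valid = (mask < num_classes) | (mask == ignore_index)
    return np.where(valid, mask, fill)
```

You'd load each annotation image (e.g. with PIL) into an array, run find_invalid_labels over it, and rewrite the ones that come back non-empty; any stray values would otherwise trigger the `cur_target >= 0 && cur_target < n_classes` assertion in the CUDA NLLLoss kernel.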