fregu856 / deeplabv3

PyTorch implementation of DeepLabV3, trained on the Cityscapes dataset.
http://www.fregu856.com/
MIT License

size mismatch for aspp when loading pre-trained model #5

Closed: LindaSt closed this issue 5 years ago

LindaSt commented 5 years ago

Hi!

Very nice repo! I'm currently trying to integrate your model into our framework (https://github.com/DIVA-DIA/DeepDIVA, feel free to check it out!). However, when I load the provided weights for deeplabv3 I get the following error:

    size mismatch for aspp.conv_1x1_4.weight: copying a param with shape torch.Size([20, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([8, 256, 1, 1]).
    size mismatch for aspp.conv_1x1_4.bias: copying a param with shape torch.Size([20]) from checkpoint, the shape in current model is torch.Size([8]). (deeplabv3.py:50)

I am using exactly the same ResNet (ResNet18_OS8) and ASPP (no bottleneck) modules that you use in your code. Do you know what could be causing this?

Thank you very much in advance.

Cheers, Linda

fregu856 commented 5 years ago

Hi,

20 is the number of classes I trained the model for (self.num_classes = 20 in https://github.com/fregu856/deeplabv3/blob/master/model/deeplabv3.py). It seems like you have set this to 8 instead?

Regards

// Fredrik
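A quick way to confirm this kind of mismatch is to compare parameter shapes between the checkpoint and the target model before loading. The sketch below makes some assumptions. `ToyHead` is a hypothetical, heavily simplified stand-in that only reuses the `conv_1x1_4` name from the error message; it is not the repo's actual ASPP. `find_shape_mismatches` is an illustrative helper, not part of the repo.

```python
import torch
import torch.nn as nn


class ToyHead(nn.Module):
    """Hypothetical stand-in: only the final 1x1 conv depends on num_classes."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone_conv = nn.Conv2d(3, 256, kernel_size=1)            # class-independent
        self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)     # one output map per class


def find_shape_mismatches(model, state_dict):
    """Return {key: (checkpoint_shape, model_shape)} for shared keys whose shapes differ."""
    model_state = model.state_dict()
    return {
        key: (tuple(value.shape), tuple(model_state[key].shape))
        for key, value in state_dict.items()
        if key in model_state and value.shape != model_state[key].shape
    }


checkpoint = ToyHead(num_classes=20).state_dict()  # plays the role of the 20-class Cityscapes weights
model = ToyHead(num_classes=8)                     # target model with a different class count
for key, (ckpt_shape, model_shape) in find_shape_mismatches(model, checkpoint).items():
    print(key, ckpt_shape, "->", model_shape)
```

Only the class-specific layer is reported: `conv_1x1_4.weight` goes from (20, 256, 1, 1) to (8, 256, 1, 1), and `conv_1x1_4.bias` goes from (20,) to (8,). This mirrors the two lines in the error above.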

LindaSt commented 5 years ago

I think I just realized what the problem is. I am trying to do transfer learning and have specified a different number of classes than in your dataset. That is what is causing the mismatch.

fregu856 commented 5 years ago

Yes, exactly! You could probably load the pre-trained model into a placeholder model with 20 classes, then loop through the parameters (but skip aspp.conv_1x1_4.weight and aspp.conv_1x1_4.bias) and copy their values to your model.
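The suggested approach could look roughly like the sketch below. `ToySegNet`, `ToyASPP`, and `transfer_weights` are hypothetical names used only for illustration. In practice the placeholder would be the repo's own 20-class DeepLabV3 with the provided checkpoint loaded into it. Rather than copying tensors one by one, the sketch filters out the two head keys and loads the rest with `strict=False`. That flag only tolerates missing or unexpected keys; PyTorch still raises on size mismatches, which is why the filtering is needed.

```python
import torch
import torch.nn as nn

# Class-specific parameters that cannot be transferred (names taken from the error message).
SKIP_KEYS = {"aspp.conv_1x1_4.weight", "aspp.conv_1x1_4.bias"}


class ToyASPP(nn.Module):
    """Hypothetical, simplified ASPP: only the final 1x1 conv depends on num_classes."""

    def __init__(self, num_classes):
        super().__init__()
        self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)


class ToySegNet(nn.Module):
    """Hypothetical stand-in for DeepLabV3: a backbone followed by the ASPP head."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Conv2d(3, 256, kernel_size=3, padding=1)
        self.aspp = ToyASPP(num_classes)


def transfer_weights(pretrained, target, skip_keys=SKIP_KEYS):
    """Copy all parameters/buffers except skip_keys; the new head keeps its random init."""
    filtered = {k: v for k, v in pretrained.state_dict().items() if k not in skip_keys}
    return target.load_state_dict(filtered, strict=False)


placeholder = ToySegNet(num_classes=20)
# Here you would load the provided 20-class checkpoint into `placeholder`,
# e.g. placeholder.load_state_dict(torch.load(<checkpoint path>)).
model = ToySegNet(num_classes=8)
result = transfer_weights(placeholder, model)
print(result.missing_keys)  # only the skipped head parameters should be listed
```

Checking `missing_keys` afterwards is a cheap sanity check. If anything other than the two head parameters shows up, the key names differ between the checkpoint and your model.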

LindaSt commented 5 years ago

Thank you so much for your quick response! Yes, my model has a different number of classes.

Somewhat unrelated question: are the pre-trained weights provided for the ResNets also from the Cityscapes model, or from something else?

fregu856 commented 5 years ago

The pre-trained resnets are taken straight from torchvision, i.e., downloaded from the urls at the top of https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py.