pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Ability to change the retinanet model to be modified after training #6614

Open Monk5088 opened 2 years ago

Monk5088 commented 2 years ago

🚀 The feature

After training the RetinaNet model, we are not able to change the number of classes for the next training session. For example, if a RetinaNet model is trained on 56 classes, its classification subnet outputs 56 class scores per anchor. How can we reuse the same model weights for a new dataset that contains only 40 of those classes? My case is similar: I trained my RetinaNet on 3 classes, and now I want to reuse its weights for a new dataset that contains only the first 2 classes of the previous dataset. When I try to load the saved weights with model.load_state_dict(torch.load(PATH)), it fails with a size mismatch in the classifier:


/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1496         if len(error_msgs) > 0:
   1497             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1498                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1499         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1500 

RuntimeError: Error(s) in loading state_dict for RetinaNet:
    size mismatch for classifier.3.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([2, 128, 3, 3]).
    size mismatch for classifier.3.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([2]).
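A common workaround for this size-mismatch error is to load only the parameters whose names and shapes agree, and leave the mismatched classifier layer with the new model's fresh initialization. The sketch below uses plain PyTorch; `ToyDetector` and `load_matching_weights` are hypothetical names, and the toy model only imitates the `classifier.3` naming from the traceback, not the real RetinaNet.

```python
import torch
from torch import nn


class ToyDetector(nn.Module):
    """Hypothetical stand-in for the real detector; only the classifier's
    layer naming (classifier.0 ... classifier.3) mirrors the error above."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = nn.Conv2d(3, 128, kernel_size=3, padding=1)
        self.classifier = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, num_classes, kernel_size=3, padding=1),  # classifier.3
        )


def load_matching_weights(model: nn.Module, checkpoint_state: dict) -> list:
    """Copy every checkpoint tensor whose name and shape match `model`.

    Returns the names of the skipped checkpoint entries so the caller can
    confirm that only the classifier's final layer was left out.
    """
    model_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in checkpoint_state.items()
        if name in model_state and model_state[name].shape == tensor.shape
    }
    skipped = [name for name in checkpoint_state if name not in compatible]
    # strict=False tolerates the keys we dropped; those layers keep the
    # fresh initialization of the new model.
    model.load_state_dict(compatible, strict=False)
    return skipped


old_model = ToyDetector(num_classes=3)
# In practice this would be: checkpoint = torch.load(PATH, map_location="cpu")
checkpoint = old_model.state_dict()

new_model = ToyDetector(num_classes=2)
skipped = load_matching_weights(new_model, checkpoint)
print(skipped)  # ['classifier.3.weight', 'classifier.3.bias']
```

The trade-off is that the final classifier layer has to be learned again while fine-tuning on the new dataset; everything else starts from the trained weights.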

Motivation, pitch

I'm working on object detection with RetinaNet and have run into the issue described above.

Alternatives

No response

Additional context

No response

cc @datumbox

datumbox commented 2 years ago

@Monk5088 there is a similar discussion for FCOS at #5932. You might be able to adapt the solution from this comment to your case: https://github.com/pytorch/vision/issues/5932#issuecomment-1147940944

Monk5088 commented 2 years ago

Yes, the solution in #5932 looks like it would work, but in my case it's a classic RetinaNet architecture with a ResNet-34 backbone, so I need to remove the normal and constant initialization lines from that code. Is there any other modification I need to make, @datumbox? Here is the link to my RetinaNet implementation: https://github.com/ChristianMarzahl/ObjectDetection/blob/master/object_detection_fastai/models/RetinaNet.py . It would be really helpful if you could point out the changes I need to make to adapt the FCOS solution to my custom model. Thanks and regards, Harshit
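Because the new dataset's classes are a subset of the old ones (the first 2 of 3), another option for a custom model is to keep the trained rows of the last classifier layer instead of re-initializing it, which also makes the normal/constant init lines unnecessary for that layer. A minimal sketch, assuming the final layer is the `classifier.3` conv from the traceback, that output channels are laid out anchor-major (channel `a * num_classes + k`), and that there is no separate background channel; the layout in the linked fastai implementation has not been verified, and `slice_class_channels` is a hypothetical helper.

```python
import torch


def slice_class_channels(weight, bias, num_anchors, old_num_classes, keep):
    """Keep only the class indices in `keep` from a conv classifier layer.

    ASSUMPTION: output channels are laid out anchor-major, i.e. channel
    a * old_num_classes + k holds the logit of class k for anchor a, and
    there is no separate background channel.
    """
    idx = torch.tensor(
        [a * old_num_classes + k for a in range(num_anchors) for k in keep]
    )
    return weight[idx].clone(), bias[idx].clone()


# Stand-in for the 3-class checkpoint; real code would use torch.load(PATH).
state = {
    "classifier.3.weight": torch.randn(3, 128, 3, 3),
    "classifier.3.bias": torch.randn(3),
}
new_w, new_b = slice_class_channels(
    state["classifier.3.weight"], state["classifier.3.bias"],
    num_anchors=1, old_num_classes=3, keep=[0, 1],
)
state["classifier.3.weight"], state["classifier.3.bias"] = new_w, new_b
print(new_w.shape)  # torch.Size([2, 128, 3, 3])
```

The edited `state` can then be loaded into the 2-class model with `load_state_dict`, provided the new dataset's label ids map to the same class indices as before (class 0 stays class 0, class 1 stays class 1).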