pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

`fasterrcnn_resnet50_fpn` edits weights on-the-fly #8502

Open eflorico opened 2 months ago

eflorico commented 2 months ago

🐛 Describe the bug

fasterrcnn_resnet50_fpn() includes code to edit weights on-the-fly while loading them: https://github.com/pytorch/vision/blob/bf01bab6125c5f1152e4f336b470399e52a8559d/torchvision/models/detection/faster_rcnn.py#L578-578

Therefore, if you attempt to load weights after model instantiation, you will get wrong (but only slightly wrong!) weights:

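from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_Weights,
    fasterrcnn_resnet50_fpn,
)

# First: build the model without the COCO weights
model = fasterrcnn_resnet50_fpn()

# Later: load the COCO_V1 weights into the existing model
weights = FasterRCNN_ResNet50_FPN_Weights.COCO_V1
model.load_state_dict(weights.get_state_dict(progress=True))

# Resulting model will be (slightly) wrong: it does not match fasterrcnn_resnet50_fpn(weights=weights)!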

Instead, I would suggest applying the weight edits to the weights themselves, so that you can load e.g. FasterRCNN_ResNet50_FPN_Weights.COCO_V1_FIXED using load_state_dict, without having to dig into the torchvision source and manually edit the weights yourself.

Versions

Torchvision 0.18.1

NicolasHug commented 1 month ago

Thanks for the report @eflorico. Instead of publishing new weights, could we override the load_state_dict() method of the respective base classes? There are only a handful of those, so this should be tractable.