oke-aditya opened this issue 4 years ago
Hello @oke-aditya.
Can I work on this?
If yes, can you give me a little more info on what exactly needs to be done? I am not sure what you did for classification.
This is really tricky. Let me explain in a bit more detail.
This is actually something we should look at in the longer run as part of a major refactor.
For classification, torchvision provides CNNs (backbones) trained on ImageNet.
We extended this to use any backbone with trained weights taken from other hub models, e.g. we can now use ssl weights, etc.
For this I created a dictionary in the pretrained folder and simply load these models from URLs: instantiate the model with NO pretrained weights, then load these weights as needed.
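A minimal sketch of that pattern, assuming a hypothetical `WEIGHTS_URLS` dictionary (the real dictionary and its URLs are not shown in this thread, so the entries below are placeholders) and a model object that exposes `load_state_dict`. The downloader is passed in as `fetch`; with PyTorch it would be `torch.hub.load_state_dict_from_url`. The helper name `load_pretrained` is also hypothetical.

```python
from typing import Any, Callable, Dict, Mapping

# Hypothetical registry: keys and URLs are placeholders, not the real quickvision entries.
WEIGHTS_URLS: Dict[str, Dict[str, str]] = {
    "resnet50": {
        "imagenet": "https://example.com/resnet50_imagenet.pth",
        "ssl": "https://example.com/resnet50_ssl.pth",
    },
}


def load_pretrained(
    model: Any,
    arch: str,
    pretrained: str,
    fetch: Callable[[str], Mapping[str, Any]],
    urls: Dict[str, Dict[str, str]] = WEIGHTS_URLS,
) -> Any:
    """Load `pretrained` weights for `arch` into a model built with NO pretrained weights.

    `fetch` maps a URL to a state dict (e.g. torch.hub.load_state_dict_from_url).
    """
    try:
        url = urls[arch][pretrained]
    except KeyError:
        # List what is registered for this architecture to make the error actionable.
        available = sorted(urls.get(arch, {}))
        raise ValueError(
            f"No {pretrained!r} weights registered for {arch!r}; available: {available}"
        ) from None
    model.load_state_dict(fetch(url))
    return model
```

Once the placeholder URLs are replaced with real ones, a call such as `load_pretrained(torchvision.models.resnet50(), "resnet50", "ssl", fetch=torch.hub.load_state_dict_from_url)` would build the network without weights and then load the registered checkpoint, provided its keys match the model.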
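Detection is different. Torchvision provides detection models trained on COCO, with backbones trained on ImageNet, but only for resnet50_fpn. Detection also has a tremendous number of configurations: the backbone architecture, with or without FPN, and the backbone weights (imagenet, ssl, etc.).
This is how the detection API currently works: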
from quickvision.models.detection.faster_rcnn import create_fastercnn_backbone, create_vision_fastercnn

# Non-FPN ResNet-50 backbones with two different sets of pretrained weights
frcnn_bbone1 = create_fastercnn_backbone(backbone="resnet50", fpn=False, pretrained="ssl")
frcnn_bbone2 = create_fastercnn_backbone(backbone="resnet50", fpn=False, pretrained="imagenet")
# Faster R-CNN detectors with 10 classes, one on top of each backbone
frcnn_model1 = create_vision_fastercnn(num_classes=10, backbone=frcnn_bbone1)
frcnn_model2 = create_vision_fastercnn(num_classes=10, backbone=frcnn_bbone2)
Note that this creates a Faster R-CNN model without FPN, but it supports other pretrained backbones.
For FPNs we use torchvision's resnet_fpn_backbone, which creates FPN backbones with "imagenet" weights only.
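# torchvision 0.8-era API; torchvision 0.13+ deprecates `pretrained` in favor of `weights`
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone(backbone, pretrained=True,
                               trainable_layers=trainable_backbone_layers, **kwargs)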
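One possible workaround for the imagenet-only limitation, sketched under assumptions: `resnet_fpn_backbone` returns a `BackboneWithFPN` whose `body` wraps the ResNet up to `layer4`, so a classification checkpoint that uses torchvision's ResNet key names (which the ssl hub checkpoints are assumed to follow) could be loaded into `body` once the classifier (`fc.*`) entries are dropped. The helper name `resnet_to_fpn_body_state_dict` is hypothetical.

```python
from typing import Any, Dict, Iterable


def resnet_to_fpn_body_state_dict(
    state_dict: Dict[str, Any],
    drop_prefixes: Iterable[str] = ("fc.",),
    strip_prefix: str = "module.",
) -> Dict[str, Any]:
    """Adapt a ResNet classification checkpoint for BackboneWithFPN.body (hypothetical helper).

    Assumes torchvision ResNet key names (conv1, bn1, layer1..layer4, fc).
    The FPN body keeps the layers up to layer4 and has no classifier, so
    `fc.*` entries are removed. A `module.` prefix left by DataParallel
    training is stripped first.
    """
    drop = tuple(drop_prefixes)
    out: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if strip_prefix and key.startswith(strip_prefix):
            key = key[len(strip_prefix):]
        if key.startswith(drop):
            continue  # classifier weights have no counterpart in the FPN body
        out[key] = value
    return out
```

With PyTorch available, this might be used by building the FPN backbone with `pretrained=False` and then calling `fpn_backbone.body.load_state_dict(resnet_to_fpn_body_state_dict(ssl_state_dict))`, where `ssl_state_dict` is a hypothetical downloaded checkpoint. The `FrozenBatchNorm2d` layers that torchvision uses by default discard `num_batches_tracked` entries while loading, so those need no special handling.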
To get COCO weights, we create the FRCNN model first and then load the COCO weights into it. Here is the relevant code, copied from torchvision:
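# Abridged excerpt of torchvision's (0.8-era) fasterrcnn_resnet50_fpn
from torch.hub import load_state_dict_from_url
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import FasterRCNN, model_urls
def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=91,
                            pretrained_backbone=True, trainable_backbone_layers=3, **kwargs):
    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
    model = FasterRCNN(backbone, num_classes, **kwargs)
    if pretrained:
        # the COCO checkpoint overwrites every parameter: backbone, RPN and ROI heads
        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model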
That's how we get resnet50_fpn trained on COCO.
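The same two-stage rule could be generalized so that contributed detector weights slot in next to torchvision's single COCO entry. A sketch, assuming a hypothetical `DETECTION_URLS` dictionary that reuses the `fasterrcnn_<backbone>_fpn_<dataset>` key style of torchvision's `model_urls` (the URL is a placeholder) and a hypothetical `plan_fpn_weights` helper that decides which downloads are needed:

```python
from typing import Dict, Optional, Tuple

# Mirrors the key style of torchvision's model_urls ('fasterrcnn_resnet50_fpn_coco').
# Any entry beyond that one would have to be contributed; the URL is a placeholder.
DETECTION_URLS: Dict[str, str] = {
    "fasterrcnn_resnet50_fpn_coco": "https://example.com/fasterrcnn_resnet50_fpn_coco.pth",
}


def plan_fpn_weights(
    backbone: str, pretrained: Optional[str], urls: Dict[str, str] = DETECTION_URLS
) -> Tuple[bool, Optional[str]]:
    """Return (pretrained_backbone, detector_url) for an FPN Faster R-CNN.

    Follows the torchvision logic quoted above: when full-detector weights are
    loaded afterwards, downloading ImageNet backbone weights is wasted work.
    """
    if pretrained is None:
        return False, None
    if pretrained == "imagenet":
        # Backbone-only weights; resnet_fpn_backbone downloads these itself.
        return True, None
    key = f"fasterrcnn_{backbone}_fpn_{pretrained}"
    if key not in urls:
        raise NotImplementedError(f"No detector weights registered under {key!r}")
    return False, urls[key]
```

Adding, say, KITTI weights for another backbone would then only mean registering one more key, while unsupported combinations fail loudly instead of silently falling back.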
In short, we need to support ResNet FPN backbones with other weights (e.g. ssl) as well as COCO-trained weights for all ResNet FPN backbones.
P.S. Let me start an initial refactor; things will become clearer with it.
I got the gist of what needs to be done, and it will be clearer with an initial refactor. Thanks!
@oke-aditya This is what I understood from your previous comment; correct me if I am wrong:
For ResNet FPN backbones with other weights, you would want something like the following, but resnet_fpn_backbone only supports imagenet:
frcnn_bbone = create_fastercnn_backbone(backbone="resnet50", fpn=True, pretrained="ssl")
For COCO-based models for all ResNet FPNs, you would want something like the following, and the code you mentioned will work in this case:
frcnn_bbone = create_fastercnn_backbone(backbone="resnet50", fpn=True, pretrained="coco")
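For `pretrained="coco"` at the backbone level, the backbone weights would have to be cut out of the full detector checkpoint. In torchvision's `FasterRCNN` the backbone is registered as the `backbone` submodule, so its state-dict entries are prefixed with `backbone.` (e.g. `backbone.body...` and `backbone.fpn...`) next to the `rpn.` and `roi_heads.` entries. A minimal sketch with a hypothetical helper name:

```python
from typing import Any, Dict


def extract_submodule_state_dict(
    state_dict: Dict[str, Any], prefix: str = "backbone."
) -> Dict[str, Any]:
    """Keep only the entries under `prefix`, with the prefix removed (hypothetical helper)."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
```

This could be applied as `fpn_backbone.load_state_dict(extract_submodule_state_dict(coco_state_dict))`, with `fpn_backbone` built by `resnet_fpn_backbone` for resnet50 and `coco_state_dict` downloaded from `model_urls['fasterrcnn_resnet50_fpn_coco']`. It only helps for backbones that already have a COCO detector checkpoint.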
Hmm, let me start an initial refactor. This refactor is a little tricky.
These two are super hard to support:
1. ResNet FPN backbones on other weights. This is hard because of the hardcoding by torchvision in the resnet_fpn_backbone code.
2. COCO-based models for all ResNet FPNs. Torchvision provides COCO weights only for resnet50_fpn. If people contribute such models, we can easily add them by modifying the backbone code.
The 2nd feature is quite possible, but it needs training. It would be great if people could provide such weights.
The above PR reduces the urgency of this for some time. There may be a better solution, but we need training for most weights.
🚀 Feature
Similar to what we did for classification, we should probably provide something for detection.
This will allow loading pretrained weights from datasets such as KITTI, COCO, etc.