open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

Issue with 'inference_detector' in MMDetection #6891

Open SourabhRanade opened 2 years ago

SourabhRanade commented 2 years ago

I am working with Mask R-CNN with a Swin Transformer backbone and have applied some changes to the model (quantization, pruning, etc.). All of these work fine and I can see the expected changes in my model. Now I want to run inference with the modified model on a single image.

I initially load/initialize the model using the code below:

!wget -c https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth \
   -O checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth

from mmdet.apis import init_detector, inference_detector

config = '/content/mmdetection/configs/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py'
checkpoint = 'checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth'
model = init_detector(config, checkpoint)

The code I am using to run inference is:

config_file = '/content/mmdetection/configs/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py'
checkpoint_file = '/content/mmdetection/checkpoints/model_complete.pth'

model = init_detector(config_file, checkpoint_file, device='cpu')

img = '/content/mmdetection/Datasets/cocodataset/val2017/val2017/000000100274.jpg'  # or img = mmcv.imread(img), which will only load it once
result = inference_detector(model, img)
model.show_result(img, result)
model.show_result(img, result, out_file='result.jpg')

The above code runs fine when I use the checkpoint file 'mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth' that I used to initialize the model in the first snippet. However, when I use the models that I have created/changed, it does not work. It throws the error: "RuntimeError: No state_dict found in checkpoint file /content/mmdetection/checkpoints/model_complete.pth". I have also tried saving only the state_dict, but then it throws warnings suggesting that it cannot find the classes.

Comparing the checkpoint file downloaded in the first step with the file I get when I save my state_dict or the complete model using PyTorch, the two are totally different. Is there a specific way I should be saving my model, or is there a workaround? Any help with this would be great.

SourabhRanade commented 2 years ago

I found this thread, which is very similar to what I want, but I see no concrete solution in it: https://github.com/open-mmlab/mmdetection/issues/5354

hhaAndroid commented 2 years ago

@SourabhRanade The saved checkpoint should store the parameters under the state_dict field. You can check whether yours does.
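One way to check this is to load the file with plain PyTorch and list its top-level fields. The following is a minimal sketch, not MMDetection's own loading code; `checkpoint_fields` is a hypothetical helper name, and it assumes the file holds a dict-like object rather than a whole pickled model.

```python
import torch


def checkpoint_fields(path):
    """Return the top-level keys of a saved checkpoint (hypothetical helper)."""
    # Load on the CPU so the check also works on a machine without a GPU.
    checkpoint = torch.load(path, map_location="cpu")
    if not isinstance(checkpoint, dict):
        # Something other than a dict was saved, so there are no fields to list.
        return []
    return list(checkpoint.keys())


# Example call, using the path from this thread:
# fields = checkpoint_fields('/content/mmdetection/checkpoints/model_complete.pth')
# print('state_dict' in fields)
```

For a file written with torch.save(model.state_dict(), path), the keys returned are the parameter names themselves, so 'state_dict' will not be among them.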

SourabhRanade commented 2 years ago

@hhaAndroid, yes, I tried saving my model both ways: torch.save(transformed_model, path_to_save) as well as torch.save(transformed_model.state_dict(), path_to_save). However, when I try to run inference on these, I keep getting errors such as 'key mismatch' and 'classes not saved in checkpoints file'. (I have them in my Colab notebook; I can share the exact error messages if you want.)

On checking the checkpoint file that I get from MMDetection, I see that a lot of additional meta information is saved, such as: 'meta': {'mmdet_version': '2.19.19e8b14d', 'CLASSES': ('person',), 'env_info': 'sys.platform: linux\nPython: 3.7.12 and so on.

This is missing when I save the model using the PyTorch methods above. So I was wondering how I can create my own checkpoint file for the model that I have changed/transformed, in a way that MMDetection can use directly for inference.

hhaAndroid commented 2 years ago

@SourabhRanade You can load the original checkpoint (from before the modification), replace its internal state_dict with the weights of your modified model, and leave everything else unchanged.
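The suggestion above can be sketched with plain PyTorch as follows. This is an illustration under assumptions, not an official MMDetection utility: `repack_checkpoint` is a hypothetical helper name, and it assumes the original file is a dict with 'state_dict' and 'meta' fields, as described earlier in the thread. It also assumes the modified model's parameter names and shapes still match the model built from the config; if pruning or quantization changed them, loading may still report key mismatches.

```python
import torch
from torch import nn


def repack_checkpoint(original_ckpt_path, modified_model: nn.Module, out_path):
    """Copy an existing checkpoint and swap in the modified model's weights.

    Hypothetical helper: only the 'state_dict' field is replaced, so the
    'meta' information (CLASSES, version info, ...) is carried over as-is.
    """
    # Load on the CPU so this also works on a machine without a GPU.
    checkpoint = torch.load(original_ckpt_path, map_location="cpu")
    if "state_dict" not in checkpoint:
        raise KeyError("expected a 'state_dict' field in the original checkpoint")
    # Replace the weights and leave every other field untouched.
    checkpoint["state_dict"] = modified_model.state_dict()
    torch.save(checkpoint, out_path)
    return checkpoint


# Example call, using names from this thread (transformed_model is the modified model):
# repack_checkpoint(
#     'checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth',
#     transformed_model,
#     '/content/mmdetection/checkpoints/model_complete.pth',
# )
```

The resulting file can then be passed to init_detector as the checkpoint, in the same way as the downloaded one.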

SourabhRanade commented 2 years ago

@hhaAndroid So if I understand correctly, I should take my modified weights after the various transformations and put them in place of the state_dict in the checkpoint file provided by MMDetection? Is that the approach you are suggesting?

SourabhRanade commented 2 years ago

@hhaAndroid Could you please confirm whether my understanding of your comment is correct? Thank you.