open-mmlab / mmdetection3d

OpenMMLab's next-generation platform for general 3D object detection.
https://mmdetection3d.readthedocs.io/en/latest/
Apache License 2.0

Transferring a Checkpoint to A New, Larger Model and Loading Part of a Model From a Checkpoint #1439

Closed. yaniv-f closed this issue 2 years ago

yaniv-f commented 2 years ago

Hi. I find mmdetection3d very useful! I have two questions:

  1. I’m trying to train a large model which includes the following components in the config:

         img_backbone=dict(
             type='ResNet',
             depth=50,
             num_stages=4,
             out_indices=(0, 1, 2, 3),
             frozen_stages=1,
             norm_cfg=dict(type='BN', requires_grad=True),
             norm_eval=True,
             style='pytorch'),
         img_neck=dict(
             type='FPN',
             in_channels=[256, 512, 1024, 2048],
             out_channels=256,
             num_outs=5),

As initialization for the weights of the above parts of my model, I would like to use the checkpoint from the mmdetection3d training on nuImages. The model config is at https://github.com/open-mmlab/mmdetection3d/blob/master/configs/nuimages/mask_rcnn_r50_fpn_1x_nuim.py (linked from the first row of the instance segmentation results at https://github.com/open-mmlab/mmdetection3d/blob/master/configs/nuimages/README.md ), and the checkpoint is https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/mask_rcnn_r50_fpn_1x_nuim/mask_rcnn_r50_fpn_1x_nuim_20201008_195238-e99f5182.pth

Inside the above .pth checkpoint file, there are model weights such as

    import mmcv.runner as runner
    mmdet_pret = runner.checkpoint.load_from_http('https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/mask_rcnn_r50_fpn_1x_nuim/mask_rcnn_r50_fpn_1x_nuim_20201008_195238-e99f5182.pth')
    mmdet_pret['state_dict'].keys()
    odict_keys(['backbone.conv1.weight', 'backbone.bn1.weight', 'backbone.bn1.bias', 'backbone.bn1.running_mean', … 'neck.lateral_convs.0.conv.weight'

while in my model the weights will be called 'img_backbone.conv1.weight', 'img_backbone.bn1.weight', … 'img_neck.lateral_convs.0.conv.weight'. In other words, an 'img_' prefix needs to be added to all of the state_dict keys.

My question is: Is there a way to load the weights from your checkpoint into my model without writing specific code that iterates over all of the state_dict keys and copies them to the corresponding keys in my model? (Something similar to what is done in https://github.com/open-mmlab/mmdetection3d/blob/master/tools/model_converters/regnet2mmdet.py and in https://github.com/open-mmlab/mmdetection3d/blob/master/tools/model_converters/convert_votenet_checkpoints.py )

  2. As part of multi-stage training, I first want to train part of the model and then load the trained part into a larger model. If, while training the larger model, I use load_from and provide a checkpoint for the trained part of the model, will it work? That is, will mmdetection3d take the weights from the checkpoint file for the first part of the model and initialize the rest of the model in the usual way (e.g. using Kaiming initialization)?

Thanks,

Yaniv

Tai-Wang commented 2 years ago

Sorry for the late reply. You can try to set load_from in the config to achieve your goal and check the initialization in the log file.
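To make the "check the initialization" step concrete, one option is to compare the parameter names of the larger model with those stored in the partial checkpoint before training starts. The sketch below is only an illustration: it assumes load_from performs a non-strict load (matching keys are copied, the rest keep their regular initialization), which should be confirmed in the training log, and the key names and the helper `compare_keys` are hypothetical.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Split parameter names into those a checkpoint covers, misses, or adds."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    return {
        # Present in both: these would be copied from the checkpoint.
        'loaded': sorted(model_keys & checkpoint_keys),
        # Only in the model: these would keep their usual initialization.
        'missing_in_checkpoint': sorted(model_keys - checkpoint_keys),
        # Only in the checkpoint: these would be ignored by a non-strict load.
        'unexpected_in_checkpoint': sorted(checkpoint_keys - model_keys),
    }


# Hypothetical key names for a two-stage setup (not taken from a real model).
larger_model_keys = [
    'img_backbone.conv1.weight',
    'img_neck.lateral_convs.0.conv.weight',
    'pts_bbox_head.conv_cls.weight',
]
stage1_checkpoint_keys = [
    'img_backbone.conv1.weight',
    'img_neck.lateral_convs.0.conv.weight',
]

report = compare_keys(larger_model_keys, stage1_checkpoint_keys)
for name, keys in report.items():
    print(name, keys)
```

In practice the two lists would come from `model.state_dict().keys()` and from the `'state_dict'` entry of the loaded checkpoint. An empty `loaded` list is a sign that the prefixes do not line up.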

yaniv-f commented 2 years ago

Hi Tai-Wang. Thanks for your feedback! So in summary, for question #1 I guess I can iterate over the state_dict and transfer the weights to the corresponding part of the new model, and for question #2 you confirm that load_from on part of the model will work. Thanks and best regards,

Yaniv
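For reference, the key iteration mentioned for question #1 can be kept short. The following is a minimal sketch and not an mmdetection3d utility: the helper `rename_state_dict_keys` and the rule list are hypothetical, the prefix mapping (backbone to img_backbone, neck to img_neck) is taken from the question, and plain strings stand in for the weight tensors so the example runs without downloading the checkpoint.

```python
import re
from collections import OrderedDict


def rename_state_dict_keys(state_dict, rules):
    """Return a new OrderedDict with keys rewritten by the first matching rule.

    `rules` is a list of (regex_pattern, replacement) pairs. Keys that match
    no rule are dropped, so weights of unrelated modules are not carried over.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        for pattern, replacement in rules:
            new_key, n_subs = re.subn(pattern, replacement, key, count=1)
            if n_subs:
                renamed[new_key] = value
                break
    return renamed


# Prefix mapping from the question: add 'img_' to backbone and neck keys.
RULES = [
    (r'^backbone\.', 'img_backbone.'),
    (r'^neck\.', 'img_neck.'),
]

# Stand-in for mmdet_pret['state_dict']; real values would be tensors.
example_state_dict = OrderedDict([
    ('backbone.conv1.weight', 'w0'),
    ('neck.lateral_convs.0.conv.weight', 'w1'),
    ('roi_head.bbox_head.fc_cls.weight', 'w2'),  # illustrative key, dropped
])

converted = rename_state_dict_keys(example_state_dict, RULES)
print(list(converted.keys()))
```

With the real checkpoint, the converted dictionary could then be written out, for example with `torch.save({'state_dict': converted}, path)`, and that file passed to load_from. Whether the resulting keys match the new model should still be checked in the log.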