open-mmlab / mmrazor

OpenMMLab Model Compression Toolbox and Benchmark.
https://mmrazor.readthedocs.io/en/latest/
Apache License 2.0

Missing keys after RTMDET knowledge distillation #650

Closed (Ramzes30765 closed this issue 2 months ago)

Ramzes30765 commented 2 months ago

Describe the bug

I successfully ran distillation with the config ./mmyolo/configs/rtmdet/distillation/kd_l_rtmdet_x_neck_phones.py. When I then try to load the resulting checkpoint, I get the following error:

unexpected key in source state_dict: architecture.backbone.stem.0.conv.weight, architecture.backbone.stem.0.bn.weight, architecture.backbone.stem.0.bn.bias, .... etc

missing keys in source state_dict: backbone.stem.0.conv.weight, backbone.stem.0.bn.weight, backbone.stem.0.bn.bias, backbone.stem.0.bn.running_mean, .... etc
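
The two lists differ only by the leading `architecture.` prefix: the distillation checkpoint apparently stores the student weights under the wrapper's `architecture` attribute, while the plain detector expects unprefixed names. A minimal sketch of renaming the keys, assuming the checkpoint's `state_dict` is an ordinary mapping; the helper name `strip_prefix` and the toy values are hypothetical and not part of mmrazor:

```python
def strip_prefix(state_dict, prefix="architecture."):
    """Keep only keys that start with `prefix` and remove that prefix.

    Keys without the prefix are dropped from the result.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy stand-in for checkpoint['state_dict']; real values would be tensors.
kd_state = {
    "architecture.backbone.stem.0.conv.weight": 1,
    "architecture.backbone.stem.0.bn.weight": 2,
    "some_other_module.weight": 3,  # hypothetical non-student entry
}

student_state = strip_prefix(kd_state)
print(sorted(student_state))
```

After such a rename, the keys would line up with the names listed under "missing keys" above.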

To Reproduce

Environment

import mmcv
import mmdet
import mmrazor
import mmyolo

print(mmdet.__version__)
print(mmcv.__version__)
print(mmyolo.__version__)
print(mmrazor.__version__)

3.0.0
2.0.1
0.6.0
1.0.0

My config files

rtmdet_l_phones.py

_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/phones/'
# Path of train annotation file
train_ann_file = 'annotations/train_phones.json'
train_data_prefix = 'train/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/valid_phones.json'
val_data_prefix = 'valid/'  # Prefix of val image path
# Path of test annotation file
test_ann_file = 'annotations/test_phones.json'
test_data_prefix = 'test/'  # Prefix of test image path

num_classes = 1  # Number of classes for classification
class_name = ('phone', )
num_classes = len(class_name)
metainfo = dict(classes=class_name, palette=[(20, 220, 60)])

# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 10
# persistent_workers must be False if num_workers is 0.
persistent_workers = True

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
base_lr = 0.004
max_epochs = 300  # Maximum training epochs
# Change train_pipeline for final 20 epochs (stage 2)
num_epochs_stage2 = 20

model_test_cfg = dict(
    # The config of multi-label for multi-class prediction.
    multi_label=True,
    # The number of boxes before NMS
    nms_pre=30000,
    score_thr=0.001,  # Threshold to filter out boxes.
    nms=dict(type='nms', iou_threshold=0.65),  # NMS type and threshold
    max_per_img=300)  # Max number of detections of each image

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# ratio range for random resize
random_resize_ratio_range = (0.1, 2.0)
# Cached images number in mosaic
mosaic_max_cached_images = 40
# Number of cached images in mixup
mixup_max_cached_images = 20
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 35
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 10

# Config of batch shapes. Only on val.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch_size_per_gpu,
    img_size=img_scale[0],
    size_divisor=32,
    extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 1.0
# The scaling factor that controls the width of the network structure
widen_factor = 1.0
# Strides of multi-scale prior box
strides = [8, 16, 32]

norm_cfg = dict(type='BN')  # Normalization config

# -----train val related-----
lr_start_factor = 1.0e-5
dsl_topk = 13  # Number of bbox selected in each level
loss_cls_weight = 1.0
loss_bbox_weight = 2.0
qfl_beta = 2.0  # beta of QualityFocalLoss
weight_decay = 0.05

# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# validation intervals in stage 2
val_interval_stage2 = 1
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

# ===============================Unmodified in most cases====================
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='YOLOv5DetDataPreprocessor',
        mean=[103.53, 116.28, 123.675],
        std=[57.375, 57.12, 58.395],
        bgr_to_rgb=False),
    backbone=dict(
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        channel_attention=True,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    neck=dict(
        type='CSPNeXtPAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=[256, 512, 1024],
        out_channels=256,
        num_csp_blocks=3,
        expand_ratio=0.5,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='RTMDetHead',
        head_module=dict(
            type='RTMDetSepBNHeadModule',
            num_classes=num_classes,
            in_channels=256,
            stacked_convs=2,
            feat_channels=256,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='SiLU', inplace=True),
            share_conv=True,
            pred_kernel_size=1,
            featmap_strides=strides),
        prior_generator=dict(
            type='mmdet.MlvlPointGenerator', offset=0, strides=strides),
        bbox_coder=dict(type='DistancePointBBoxCoder'),
        loss_cls=dict(
            type='mmdet.QualityFocalLoss',
            use_sigmoid=True,
            beta=qfl_beta,
            loss_weight=loss_cls_weight),
        loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=loss_bbox_weight)),
    train_cfg=dict(
        assigner=dict(
            type='BatchDynamicSoftLabelAssigner',
            num_classes=num_classes,
            topk=dsl_topk,
            iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=model_test_cfg,
)

train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Mosaic',
        img_scale=img_scale,
        use_cached=True,
        max_cached_images=mosaic_max_cached_images,
        pad_val=114.0),
    dict(
        type='mmdet.RandomResize',
        # img_scale is (width, height)
        scale=(img_scale[0] * 2, img_scale[1] * 2),
        ratio_range=random_resize_ratio_range,
        resize_type='mmdet.Resize',
        keep_ratio=True),
    dict(type='mmdet.RandomCrop', crop_size=img_scale),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
    dict(
        type='YOLOv5MixUp',
        use_cached=True,
        max_cached_images=mixup_max_cached_images),
    dict(type='mmdet.PackDetInputs')
]

train_pipeline_stage2 = [
    dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='mmdet.RandomResize',
        scale=img_scale,
        ratio_range=random_resize_ratio_range,
        resize_type='mmdet.Resize',
        keep_ratio=True),
    dict(type='mmdet.RandomCrop', crop_size=img_scale),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
    dict(type='mmdet.PackDetInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    collate_fn=dict(type='yolov5_collate'),
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        metainfo=metainfo,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        metainfo=metainfo,
        ann_file=val_ann_file,
        data_prefix=dict(img=val_data_prefix),
        test_mode=True,
        batch_shapes_cfg=batch_shapes_cfg,
        pipeline=test_pipeline))

test_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        metainfo=metainfo,
        ann_file=test_ann_file,
        data_prefix=dict(img=test_data_prefix),
        test_mode=True,
        batch_shapes_cfg=batch_shapes_cfg,
        pipeline=test_pipeline))

# test_dataloader = val_dataloader

# Reduce evaluation time
val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox')

test_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + test_ann_file,
    metric='bbox')
# test_evaluator = val_evaluator

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=weight_decay),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=lr_start_factor,
        by_epoch=False,
        begin=0,
        end=1000),
    dict(
        # use cosine lr from 150 to 300 epoch
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# hooks
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        save_best='coco/bbox_mAP',
        rule="greater",
        interval=save_checkpoint_intervals,
        max_keep_ckpts=max_keep_ckpts  # only keep latest 3 checkpoints
    ),
    early_stopping=dict(
        type="EarlyStoppingHook",
        monitor="coco/bbox_mAP",
        patience=20,
        min_delta=0.005
    ),
)

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - num_epochs_stage2,
        switch_pipeline=train_pipeline_stage2)
]

log_config = dict(
    interval=100,
    hooks=[
        dict(
            type='ClearMLLoggerHook',
            init_kwargs=dict(
                project_name='mmdet_phones',
                task_name='OpenMMLab rmdet_l',
                output_uri=True
            )
        ),
    ]
)

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_checkpoint_intervals,
    dynamic_intervals=[(max_epochs - num_epochs_stage2, val_interval_stage2)])

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
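
Several schedule values in this config are derived rather than written out. The sketch below only repeats the arithmetic from the config above, using the same variable names; nothing here is an mmyolo API:

```python
base_lr = 0.004
max_epochs = 300
num_epochs_stage2 = 20

# Epoch at which PipelineSwitchHook swaps in train_pipeline_stage2 and
# dynamic_intervals changes the validation interval to val_interval_stage2.
switch_epoch = max_epochs - num_epochs_stage2

# CosineAnnealingLR uses this value for both `begin` and `T_max`.
cosine_begin = max_epochs // 2

# Learning rate at the end of the cosine schedule.
eta_min = base_lr * 0.05

print(switch_epoch, cosine_begin, eta_min)
```

So the augmentation pipeline switches at epoch 280, and the cosine decay runs over the second half of training, from epoch 150 to 300.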

rtmdet_x_phones.py

_base_ = './rtmdet_l_phones.py'

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/phones/'
# Path of train annotation file
train_ann_file = 'annotations/train_phones.json'
train_data_prefix = 'train/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/valid_phones.json'
val_data_prefix = 'valid/'  # Prefix of val image path
# Path of test annotation file
test_ann_file = 'annotations/test_phones.json'
test_data_prefix = 'test/'  # Prefix of test image path

num_classes = 1  # Number of classes for classification
class_name = ('phone', )
num_classes = len(class_name)
metainfo = dict(classes=class_name, palette=[(20, 220, 60)])

# ========================modified parameters======================
deepen_factor = 1.33
widen_factor = 1.25

# =======================Unmodified in most cases==================
model = dict(
    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    bbox_head=dict(head_module=dict(widen_factor=widen_factor)))

kd_l_rtmdet_x_neck_phones.py

_base_ = '../rtmdet_l_phones.py'

teacher_ckpt = 'work_dirs/rtmdet_x_phones/epoch_300.pth'

norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='FpnTeacherDistill',
    architecture=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_l_phones.py'),
    teacher=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_x_phones.py'),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        # `recorders` are used to record various intermediate results during
        # the model forward.
        student_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv'),
        ),
        teacher_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
        # `connectors` are adaptive layers which usually map teacher's and
        # students features to the same dimension.
        connectors=dict(
            fpn0_s=dict(
                type='ConvModuleConnector',
                in_channel=256,
                out_channel=320,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn0_t=dict(
                type='NormConnector', in_channels=320, norm_cfg=norm_cfg),
            fpn1_s=dict(
                type='ConvModuleConnector',
                in_channel=256,
                out_channel=320,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn1_t=dict(
                type='NormConnector', in_channels=320, norm_cfg=norm_cfg),
            fpn2_s=dict(
                type='ConvModuleConnector',
                in_channel=256,
                out_channel=320,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn2_t=dict(
                type='NormConnector', in_channels=320, norm_cfg=norm_cfg)),
        distill_losses=dict(
            loss_fpn0=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn1=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn2=dict(type='ChannelWiseDivergence', loss_weight=1)),
        # `loss_forward_mappings` are mappings between distill loss forward
        # arguments and records.
        loss_forward_mappings=dict(
            loss_fpn0=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn0', connector='fpn0_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn0', connector='fpn0_t')),
            loss_fpn1=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn1', connector='fpn1_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn1', connector='fpn1_t')),
            loss_fpn2=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn2', connector='fpn2_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn2',
                    connector='fpn2_t')))))

find_unused_parameters = True

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=_base_.max_epochs - _base_.num_epochs_stage2,
        switch_pipeline=_base_.train_pipeline_stage2),
    # stop distillation after the 280th epoch
    dict(type='mmrazor.StopDistillHook', stop_epoch=280)
]
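
The connector sizes in this config follow from the two width multipliers: the student neck is declared with `out_channels=256` and `widen_factor=1.0`, the teacher with `widen_factor=1.25`. A small sketch of that arithmetic, assuming the neck scales its declared output channels by `widen_factor` and truncates to an integer; the helper `scaled_channels` is hypothetical:

```python
def scaled_channels(base_channels, widen_factor):
    """Channel count after applying a width multiplier (assumed rule)."""
    return int(base_channels * widen_factor)


student_out = scaled_channels(256, 1.0)   # rtmdet_l neck output channels
teacher_out = scaled_channels(256, 1.25)  # rtmdet_x neck output channels

print(student_out, teacher_out)
```

Under that assumption the results match the `in_channel=256`, `out_channel=320` pairs on the `ConvModuleConnector` entries and `in_channels=320` on the `NormConnector` entries.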

Loading the model

from mmdet.apis import init_detector

config_file = './mmyolo/configs/rtmdet/rtmdet_l_phones.py'

checkpoint_file = './mmyolo/work_dirs/kd_l_rtmdet_x_neck_phones/best_coco_bbox_mAP_epoch_297.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:2')
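
Here the plain `rtmdet_l_phones.py` config is paired with a checkpoint written by the distillation run, which is where the key mismatch appears. One quick way to see what a checkpoint contains is to count its keys by top-level module name. The sketch below uses a toy list of keys; with a real file the keys would presumably come from `torch.load(checkpoint_file, map_location='cpu')['state_dict']`. The helper name `count_top_level` and the sample keys other than the `architecture.*` ones are hypothetical:

```python
from collections import Counter


def count_top_level(keys):
    """Count keys by the part before the first dot."""
    return Counter(key.split(".", 1)[0] for key in keys)


# Toy stand-in for the key names of a distillation checkpoint.
sample_keys = [
    "architecture.backbone.stem.0.conv.weight",
    "architecture.backbone.stem.0.bn.weight",
    "architecture.backbone.stem.0.bn.bias",
    "other_module.layer.weight",  # hypothetical extra entry
]

counts = count_top_level(sample_keys)
print(dict(counts))
```

If every student weight sits under a single wrapper name such as `architecture`, a config that builds the bare detector will report all of them as unexpected.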

Training command:

CUDA_VISIBLE_DEVICES=3 python tools/train.py configs/rtmdet/distillation/kd_l_rtmdet_x_neck_phones.py

Ramzes30765 commented 2 months ago

Problem resolved by loading the model with different code, pointing at the distillation config instead of the plain detector config:

from mmdet.apis import DetInferencer

inferencer = DetInferencer(
    model='/mmyolo/configs/rtmdet/distillation/kd_l_rtmdet_x_neck_phones.py',
    weights='mmyolo/work_dirs/kd_l_rtmdet_x_neck_phones/best_coco_bbox_mAP_epoch_297.pth',
    scope='mmyolo',
    device='cuda:2')