open-mmlab / mmrazor

OpenMMLab Model Compression Toolbox and Benchmark.
https://mmrazor.readthedocs.io/en/latest/
Apache License 2.0

The effect was not improved after distillation #276

Open xuhao-anhe opened 1 year ago

xuhao-anhe commented 1 year ago

Describe the question you meet

I am using the CWD method to distill ResNet-50 into ResNet-18. The teacher network reaches 80% accuracy in training, but the student network reaches only 46% after distillation. What could be the reason?

Post related information

  1. The output of pip list | grep "mmcv\|mmrazor\|^torch" [here]
  2. Your config file if you modified it or created a new one.
_base_ = [
    '../../_base_/datasets/mmdet/coco_instance.py',
    '../../_base_/schedules/mmdet/schedule_1x.py',
    '../../_base_/mmdet_runtime.py'
]

# model settings
student = dict(
    type='mmdet.PointRend',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth')
    ),
    neck=dict(
        type='FPN',
        in_channels=[64, 128, 256, 512],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='PointRendRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=4,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='GenericRoIExtractor',
            aggregation='concat',
            roi_layer=dict(
                type='SimpleRoIAlign', output_size=14),
            out_channels=256,
            featmap_strides=[4]),
        mask_head=dict(
            type='CoarseMaskHead',
            num_fcs=2,
            in_channels=256,
            conv_out_channels=256,
            fc_out_channels=1024,
            num_classes=4,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
        point_head=dict(
            type='MaskPointHead',
            num_fcs=3,
            in_channels=256,
            fc_channels=256,
            num_classes=4,
            coarse_pred_each_layer=True,
            loss_point=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=7,
            num_points=14 * 14,
            oversample_ratio=3,
            importance_sample_ratio=0.75,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,  # number of top-scoring proposals kept before NMS
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5,
            subdivision_steps=5,
            subdivision_num_points=28 * 28,
            scale_factor=2)))

checkpoint = '/media/jidong/code/xuhao/mmdetection/load/point_rend_r50_caffe_fpn_mstrain_3x_coco-e0ebb6b7.pth'

teacher = dict(
    type='mmdet.PointRend',
    init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='open-mmlab://detectron2/resnet50_caffe')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='PointRendRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=4,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='GenericRoIExtractor',
            aggregation='concat',
            roi_layer=dict(
                type='SimpleRoIAlign', output_size=14),
            out_channels=256,
            featmap_strides=[4]),
        mask_head=dict(
            type='CoarseMaskHead',
            num_fcs=2,
            in_channels=256,
            conv_out_channels=256,
            fc_out_channels=1024,
            num_classes=4,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
        point_head=dict(
            type='MaskPointHead',
            num_fcs=3,
            in_channels=256,
            fc_channels=256,
            num_classes=4,
            coarse_pred_each_layer=True,
            loss_point=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=7,
            num_points=14 * 14,
            oversample_ratio=3,
            importance_sample_ratio=0.75,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,  # number of top-scoring proposals kept before NMS
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            # nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05),
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5,
            subdivision_steps=5,
            subdivision_num_points=28 * 28,
            scale_factor=2)))

algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMDetArchitecture',
        model=student,
    ),
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=teacher,
        teacher_trainable=False,
        components=[
            dict(
                student_module='rpn_head.rpn_reg',
                teacher_module='rpn_head.rpn_reg',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_point_rend',
                        tau=1,
                        loss_weight=5,
                    )
                ])
        ]),
)

find_unused_parameters = True

# python tools/mmdet/train_mmdet.py configs/distill/cwd/pointrend.py --work-dir pointrend-18 --cfg-options algorithm.distiller.teacher.init_cfg.type=Pretrained algorithm.distiller.teacher.init_cfg.checkpoint=https://download.openmmlab.com/mmdetection/v2.0/point_rend/point_rend_r50_caffe_fpn_mstrain_3x_coco/point_rend_r50_caffe_fpn_mstrain_3x_coco-e0ebb6b7.pth
  3. Your train log file if you meet the problem during training.

Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.311
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.460
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.350
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.613
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.615
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.615
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.615
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.615
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = -1.000

2022-09-02 09:53:28,087 - mmdet - INFO - Evaluating segm...
2022-09-02 09:53:30,272 - mmdet - INFO -
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.372
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.467
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.417
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.743
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.746
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.746
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.746
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.746
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = -1.000

  4. Other code you modified in the mmrazor folder. [here]

Davidgzx commented 1 year ago

ChannelWiseDivergence is a response-based KD loss designed for semantic segmentation. Please use kl_div for the cls head instead; it should work. Besides, if you want to distill the reg heads, you can try SmoothL1 Loss.
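
To make the two suggestions concrete, here is a minimal, framework-free sketch of what each loss computes. It is an illustration only, not mmrazor's implementation: the function names (`softmax`, `kd_kl_div`, `smooth_l1`), the `tau` and `beta` defaults, and all numeric values are assumptions made up for this example.

```python
import math


def softmax(logits, tau=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / tau for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_kl_div(student_logits, teacher_logits, tau=1.0):
    """KL(teacher || student) on softened class distributions, scaled by tau^2."""
    p_teacher = softmax(teacher_logits, tau)
    p_student = softmax(student_logits, tau)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student) if p > 0)
    return tau * tau * kl


def smooth_l1(student_out, teacher_out, beta=1.0):
    """Mean smooth L1 distance between student and teacher regression outputs."""
    total = 0.0
    for s, t in zip(student_out, teacher_out):
        diff = abs(s - t)
        # Quadratic near zero, linear for large differences.
        total += 0.5 * diff * diff / beta if diff < beta else diff - 0.5 * beta
    return total / len(student_out)


# Hypothetical values for one sample: 4 class logits and 4 box deltas.
teacher_cls, student_cls = [4.0, 1.0, 0.5, -2.0], [2.0, 1.5, 0.0, -1.0]
teacher_reg, student_reg = [0.10, -0.20, 0.05, 0.30], [0.40, -0.10, 1.50, 0.30]

print("cls distillation loss:", kd_kl_div(student_cls, teacher_cls, tau=1.0))
print("reg distillation loss:", smooth_l1(student_reg, teacher_reg))
```

The KL term compares class probability distributions, which suits a classification output, while smooth L1 compares raw values, which suits box regression outputs that are not distributions. In the config posted above, a distillation loss of this kind belongs in the `losses` list of a `components` entry under `algorithm.distiller`; the `loss_bbox` entries inside the student's heads are the ordinary task losses.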

xuhao-anhe commented 1 year ago

Thank you very much for your answer. Is there a config file I could use for reference?

xuhao-anhe commented 1 year ago

Hello, I changed the loss function of loss_bbox to SmoothL1 Loss, but the final result did not improve. Is loss_bbox the right place to make this modification? I look forward to your reply. Thank you very much.

smntjugithub commented 11 months ago

How did you modify your configuration file? Could you show me?