open-mmlab / mmrazor

OpenMMLab Model Compression Toolbox and Benchmark.
https://mmrazor.readthedocs.io/en/latest/
Apache License 2.0

How to get the network structure after slimming? #294

Open VJeee opened 2 years ago

VJeee commented 2 years ago

Describe the question you meet

I used the official configuration files to run the AutoSlim pipeline on the CIFAR-100 dataset. But after using split_checkpoint.py to split the retrained weight file, the obtained weight files are all the same size. And when running the test, it reports the errors 'The model and loaded state dict do not match exactly' and 'AttributeError: 'MMDataParallel' object has no attribute 'CLASSES''.
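For reference, the split was run roughly like this (a sketch: the script path follows the mmrazor 0.x layout and the --channel-cfgs flag is taken from the AutoSlim docs; the checkpoint and yaml paths are placeholders):

python tools/model_converters/split_checkpoint.py \
    autoslim_mbv2_subnet_8xb32_in100_cifar100_test.py \
    path/to/retrained_subnet.pth \
    --channel-cfgs subnet_13978536.yaml subnet_12989328.yaml subnet_11942370.yaml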

Post related information

  1. The output of pip list | grep "mmcv\|mmrazor\|^torch"

     mmcv-full     1.5.0
     torch         1.10.0
     torchsummary  1.5.1
     torchvision   0.11.1
     mmrazor       0.3.1
  2. Your config file if you modified it or created a new one.
# autoslim_mbv2_supernet_8xb32_in100_cifar100_test.py
_base_ = [
    '/home/wenjie/PycharmProjects/mmrazor_demo/configs/_base_/datasets/mmcls/cifar100_bs32_autoslim.py',
    '/home/wenjie/PycharmProjects/mmrazor_demo/configs/_base_/schedules/mmcls/cifar100_bs2048_autoslim.py',
    '/home/wenjie/PycharmProjects/mmrazor_demo/configs/_base_/mmcls_runtime.py'
]

model = dict(
    type='mmcls.ImageClassifier',
    backbone=dict(type='MobileNetV2', widen_factor=1.5),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=100,
        in_channels=1920,
        loss=dict(
            type='LabelSmoothLoss',
            mode='original',
            label_smooth_val=0.1,
            loss_weight=1.0),
        topk=(1, 5),
    ))

algorithm = dict(
    type='AutoSlim',
    architecture=dict(type='MMClsArchitecture', model=model),
    distiller=dict(
        type='SelfDistiller',
        components=[
            dict(
                student_module='head.fc',
                teacher_module='head.fc',
                losses=[
                    dict(
                        type='KLDivergence',
                        name='loss_kd',
                        tau=1,
                        loss_weight=1,
                    )
                ]),
        ]),
    pruner=dict(
        type='RatioPruner',
        ratios=(2 / 12, 3 / 12, 4 / 12, 5 / 12, 6 / 12, 7 / 12, 8 / 12, 9 / 12,
                10 / 12, 11 / 12, 1.0)),
    retraining=False,
    bn_training_mode=True,
    input_shape=None)

runner = dict(type='EpochBasedRunner', max_epochs=50)

use_ddp_wrapper = True

# autoslim_mbv2_search_8xb32_in100_cifar100_test.py
_base_ = [
    './autoslim_mbv2_supernet_8xb32_in100_cifar100_test.py',
]

algorithm = dict(distiller=None, input_shape=(1, 32, 32))

searcher = dict(
    type='GreedySearcher',
    target_flops=[14000000, 13000000, 12000000],
    max_channel_bins=12,
    metrics='accuracy')

data = dict(samples_per_gpu=1024, workers_per_gpu=4)

# autoslim_mbv2_subnet_8xb32_in100_cifar100_test.py
_base_ = [
    './autoslim_mbv2_supernet_8xb32_in100_cifar100_test.py',
]

model = dict(
    head=dict(
        loss=dict(
            type='LabelSmoothLoss',
            mode='original',
            label_smooth_val=0.1,
            loss_weight=1.0)))

# FIXME: you may replace this with the channel_cfg searched by yourself
channel_cfg = [
    '/home/wenjie/PycharmProjects/mmrazor_demo/autoslim_test/search/subnet_13978536.yaml',  # noqa: E501
    '/home/wenjie/PycharmProjects/mmrazor_demo/autoslim_test/search/subnet_12989328.yaml',  # noqa: E501
    '/home/wenjie/PycharmProjects/mmrazor_demo/autoslim_test/search/subnet_11942370.yaml',  # noqa: E501
]

algorithm = dict(
    architecture=dict(type='MMClsArchitecture', model=model),
    distiller=None,
    retraining=True,
    bn_training_mode=False,
    channel_cfg=channel_cfg)

runner = dict(type='EpochBasedRunner', max_epochs=300)

find_unused_parameters = True

# cifar100_bs32_autoslim.py
# dataset settings
dataset_type = 'CIFAR100'
img_norm_cfg = dict(
    mean=[129.304, 124.070, 112.434],
    std=[68.170, 65.392, 70.418],
    to_rgb=False)
train_pipeline = [
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=64,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='data/cifar100',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/cifar100',
        pipeline=test_pipeline,
        test_mode=True),
    test=dict(
        type=dataset_type,
        data_prefix='data/cifar100',
        pipeline=test_pipeline,
        test_mode=True))
evaluation = dict(interval=1, metric='accuracy')

# cifar100_bs2048_autoslim.py
# optimizer
paramwise_cfg = dict(
    bias_decay_mult=0.0, norm_decay_mult=0.0, dwconv_decay_mult=0.0)
optimizer = dict(
    type='SGD',
    lr=0.1,
    momentum=0.9,
    nesterov=True,
    weight_decay=0.0001,
    paramwise_cfg=paramwise_cfg)

optimizer_config = None

# learning policy
lr_config = dict(policy='poly', power=1.0, min_lr=0.0, by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=300)
  3. Your train log file if you meet the problem during training. [here]
  4. Other code you modified in the mmrazor folder. [here]
HIT-cwh commented 1 year ago

Hi! Sorry for the inconvenience. We split the checkpoint based on algorithm.pruner.deploy_subnet (refer to here). When overwriting the original parameter of an nn.Module with the sliced parameter, it is necessary to use a copy of the slice, like module.weight = nn.Parameter(temp_weight.data.clone()). With that change, the file sizes of the three checkpoints will indeed be different. This problem is fixed in the dev-1.x branch.
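Here is a minimal plain-PyTorch sketch (no mmrazor involved) of why the split files came out the same size: a sliced weight is a view, and torch.save serializes the whole underlying storage, so the file only shrinks once the slice is cloned.

import os

import torch
import torch.nn as nn

conv_view = nn.Conv2d(64, 64, 3)
conv_clone = nn.Conv2d(64, 64, 3)

# "Prune" to 8 output channels; the slice is a view on the full 64-channel storage.
temp_weight = conv_view.weight.data[:8]

conv_view.weight = nn.Parameter(temp_weight)                # keeps the full storage alive
conv_clone.weight = nn.Parameter(temp_weight.data.clone())  # compact, standalone copy

torch.save(conv_view.state_dict(), 'view.pth')
torch.save(conv_clone.state_dict(), 'clone.pth')

# view.pth stays roughly as large as the unpruned layer; clone.pth is ~8x smaller.
print(os.path.getsize('view.pth'), os.path.getsize('clone.pth'))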

As for the state dict mismatch, channel_cfg should be a single path, e.g.

channel_cfg = '/home/wenjie/PycharmProjects/mmrazor_demo/autoslim_test/search/subnet_13978536.yaml' 

rather than the whole list

channel_cfg = [
    '/home/wenjie/PycharmProjects/mmrazor_demo/autoslim_test/search/subnet_13978536.yaml',  # noqa: E501
    '/home/wenjie/PycharmProjects/mmrazor_demo/autoslim_test/search/subnet_12989328.yaml',  # noqa: E501
    '/home/wenjie/PycharmProjects/mmrazor_demo/autoslim_test/search/subnet_11942370.yaml',  # noqa: E501
]

when the checkpoint to be loaded is the split one corresponding to subnet_13978536.
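As a quick sanity check before testing (a sketch; the checkpoint path is hypothetical), you can inspect a split checkpoint and confirm the pruned layers actually have reduced channel dimensions:

import torch

ckpt = torch.load('split/subnet_13978536.pth', map_location='cpu')  # hypothetical path
state = ckpt.get('state_dict', ckpt)
# Pruned conv weights should show fewer output channels than the supernet.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))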

For the 'AttributeError: 'MMDataParallel' object has no attribute 'CLASSES'' problem, could you please provide more error information (ideally the full traceback) so that we can locate the bug?