open-mmlab / mmrazor

OpenMMLab Model Compression Toolbox and Benchmark.
https://mmrazor.readthedocs.io/en/latest/
Apache License 2.0

[Bug] test_pruner.py run errors #171

Open smallccn opened 2 years ago

smallccn commented 2 years ago
  1. Running test_pruner.py with mmrazor 0.3.1 raises the error below.
python tests/test_models/test_pruner.py
Traceback (most recent call last):
  File "tests/test_models/test_pruner.py", line 339, in <module>
    test_ratio_pruner()
  File "tests/test_models/test_pruner.py", line 38, in test_ratio_pruner
    _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, False)
  File "tests/test_models/test_pruner.py", line 315, in _test_reset_bn_running_stats
    pruner1.prepare_from_supernet(architecture1)
  File "/root/mmlab/mmrazor/mmrazor/models/pruners/ratio_pruning.py", line 49, in prepare_from_supernet
    super(RatioPruner, self).prepare_from_supernet(supernet)
  File "/root/mmlab/mmrazor/mmrazor/models/pruners/structure_pruning.py", line 185, in prepare_from_supernet
    var2module, norm_conv_links, visited)
  File "/root/mmlab/mmrazor/mmrazor/models/pruners/structure_pruning.py", line 732, in trace_norm_conv_links
    visited)
  File "/root/mmlab/mmrazor/mmrazor/models/pruners/structure_pruning.py", line 732, in trace_norm_conv_links
    visited)
  File "/root/mmlab/mmrazor/mmrazor/models/pruners/structure_pruning.py", line 732, in trace_norm_conv_links
    visited)
  [Previous line repeated 5 more times]
  File "/root/mmlab/mmrazor/mmrazor/models/pruners/structure_pruning.py", line 699, in trace_norm_conv_links
    conv_grad_fn = conv_grad_fn.next_functions[0][0]
AttributeError: 'NoneType' object has no attribute 'next_functions'
  2. After downgrading to v0.3.0, the error no longer reproduces, but set_min_channel() does not seem to work: after calling set_min_channel(), export_subnet() still returns the original network's channel info.
-----------------------------code in test_pruner.py-------------------------------------------------

    model_cfg = dict(
        type='mmcls.ImageClassifier',
        backbone=dict(
            type='mmcls.ResNet',
            depth=18,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='mmcls.GlobalAveragePooling'),
        head=dict(
            type='mmcls.LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        ))

    architecture_cfg = dict(
        type='MMClsArchitecture',
        model=model_cfg,
    )

    pruner_cfg = dict(
        type='RatioPruner',
        ratios=[1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0])

    _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, False)
    with pytest.raises(AssertionError):
        _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, True)

    imgs = torch.randn(16, 3, 224, 224)
    label = torch.randint(0, 1000, (16, ))

    architecture = ARCHITECTURES.build(architecture_cfg)
    pruner = PRUNERS.build(pruner_cfg)

    pruner.prepare_from_supernet(architecture)
    assert hasattr(pruner, 'channel_spaces')

    # test set_min_channel
    pruner_cfg_ = deepcopy(pruner_cfg)
    #pruner_cfg_['ratios'].insert(0, 0)
    pruner_ = PRUNERS.build(pruner_cfg_)
    architecture_ = ARCHITECTURES.build(architecture_cfg)
    pruner_.prepare_from_supernet(architecture_)
    #with pytest.raises(AssertionError):
        # Output channels should be a positive integer not zero
    pruner_.set_min_channel()

    subnet_dict = pruner_.export_subnet()
    print(subnet_dict)
    assert isinstance(subnet_dict, dict)
    pruner_.deploy_subnet(architecture, subnet_dict)
-----------------------------output for subnet_dict----------------------------------

{'backbone.conv1': {'in_channels': 3, 'raw_in_channels': 3, 'out_channels': 64, 'raw_out_channels': 64}, 'backbone.bn1': {'out_channels': 64, 'raw_out_channels': 64}, 'backbone.layer1.0.conv1': {'in_channels': 64, 'raw_in_channels': 64, 'out_channels': 64, 'raw_out_channels': 64}, 'backbone.layer1.0.bn1': {'out_channels': 64, 'raw_out_channels': 64}, 'backbone.layer1.0.conv2': {'in_channels': 64, 'raw_in_channels': 64, 'out_channels': 64, 'raw_out_channels': 64}, 'backbone.layer1.0.bn2': {'out_channels': 64, 'raw_out_channels': 64}, 'backbone.layer1.1.conv1': {'in_channels': 64, 'raw_in_channels': 64, 'out_channels': 64, 'raw_out_channels': 64}, 'backbone.layer1.1.bn1': {'out_channels': 64, 'raw_out_channels': 64}, 'backbone.layer1.1.conv2': {'in_channels': 64, 'raw_in_channels': 64, 'out_channels': 64, 'raw_out_channels': 64}, 'backbone.layer1.1.bn2': {'out_channels': 64, 'raw_out_channels': 64}, 'backbone.layer2.0.conv1': {'in_channels': 64, 'raw_in_channels': 64, 'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer2.0.bn1': {'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer2.0.conv2': {'in_channels': 128, 'raw_in_channels': 128, 'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer2.0.bn2': {'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer2.0.downsample.0': {'in_channels': 64, 'raw_in_channels': 64, 'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer2.0.downsample.1': {'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer2.1.conv1': {'in_channels': 128, 'raw_in_channels': 128, 'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer2.1.bn1': {'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer2.1.conv2': {'in_channels': 128, 'raw_in_channels': 128, 'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer2.1.bn2': {'out_channels': 128, 'raw_out_channels': 128}, 'backbone.layer3.0.conv1': {'in_channels': 128, 'raw_in_channels': 128, 'out_channels': 256, 'raw_out_channels': 
256}, 'backbone.layer3.0.bn1': {'out_channels': 256, 'raw_out_channels': 256}, 'backbone.layer3.0.conv2': {'in_channels': 256, 'raw_in_channels': 256, 'out_channels': 256, 'raw_out_channels': 256}, 'backbone.layer3.0.bn2': {'out_channels': 256, 'raw_out_channels': 256}, 'backbone.layer3.0.downsample.0': {'in_channels': 128, 'raw_in_channels': 128, 'out_channels': 256, 'raw_out_channels': 256}, 'backbone.layer3.0.downsample.1': {'out_channels': 256, 'raw_out_channels': 256}, 'backbone.layer3.1.conv1': {'in_channels': 256, 'raw_in_channels': 256, 'out_channels': 256, 'raw_out_channels': 256}, 'backbone.layer3.1.bn1': {'out_channels': 256, 'raw_out_channels': 256}, 'backbone.layer3.1.conv2': {'in_channels': 256, 'raw_in_channels': 256, 'out_channels': 256, 'raw_out_channels': 256}, 'backbone.layer3.1.bn2': {'out_channels': 256, 'raw_out_channels': 256}, 'backbone.layer4.0.conv1': {'in_channels': 256, 'raw_in_channels': 256, 'out_channels': 512, 'raw_out_channels': 512}, 'backbone.layer4.0.bn1': {'out_channels': 512, 'raw_out_channels': 512}, 'backbone.layer4.0.conv2': {'in_channels': 512, 'raw_in_channels': 512, 'out_channels': 512, 'raw_out_channels': 512}, 'backbone.layer4.0.bn2': {'out_channels': 512, 'raw_out_channels': 512}, 'backbone.layer4.0.downsample.0': {'in_channels': 256, 'raw_in_channels': 256, 'out_channels': 512, 'raw_out_channels': 512}, 'backbone.layer4.0.downsample.1': {'out_channels': 512, 'raw_out_channels': 512}, 'backbone.layer4.1.conv1': {'in_channels': 512, 'raw_in_channels': 512, 'out_channels': 512, 'raw_out_channels': 512}, 'backbone.layer4.1.bn1': {'out_channels': 512, 'raw_out_channels': 512}, 'backbone.layer4.1.conv2': {'in_channels': 512, 'raw_in_channels': 512, 'out_channels': 512, 'raw_out_channels': 512}, 'backbone.layer4.1.bn2': {'out_channels': 512, 'raw_out_channels': 512}, 'head.fc': {'in_channels': 512, 'raw_in_channels': 512, 'out_channels': 1000, 'raw_out_channels': 1000}}
----------------------------------------------------------------------------------------
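A quick way to confirm the observation above (that nothing was pruned) is to compare each layer's channel counts with its raw counts in the exported dict. The sketch below is only an illustration: the helper names `unpruned_layers` and `is_fully_unpruned` are hypothetical and not part of mmrazor, and the only assumption about the dict is the key layout visible in the printed output.

```python
def unpruned_layers(subnet_dict):
    """Return the names of layers whose channel counts equal their raw counts."""
    unchanged = []
    for name, info in subnet_dict.items():
        same_out = info["out_channels"] == info["raw_out_channels"]
        # BN entries have no in_channels keys; .get() yields None for both,
        # so a missing pair is treated as unchanged.
        same_in = info.get("in_channels") == info.get("raw_in_channels")
        if same_out and same_in:
            unchanged.append(name)
    return unchanged


def is_fully_unpruned(subnet_dict):
    """True when every layer still has its original channel counts."""
    return len(unpruned_layers(subnet_dict)) == len(subnet_dict)


# Two entries copied from the output above.
sample = {
    "backbone.conv1": {"in_channels": 3, "raw_in_channels": 3,
                       "out_channels": 64, "raw_out_channels": 64},
    "backbone.bn1": {"out_channels": 64, "raw_out_channels": 64},
}
print(is_fully_unpruned(sample))
```

Passing the real `subnet_dict` from the test to `is_fully_unpruned` would return True for the output shown here, which matches the report that the exported subnet is identical to the original network.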
HIT-cwh commented 2 years ago

Hi! Thank you for reporting this issue. Could you provide your PyTorch version?

smallccn commented 2 years ago

My full environment info is below; torch is 1.11.0.

Package            Version     Location
------------------ ----------- -------------------------------------
addict             2.4.0
attrs              21.4.0
brotlipy           0.7.0
certifi            2022.5.18.1
cffi               1.15.0
charset-normalizer 2.0.4
click              7.1.2
colorama           0.4.4
cryptography       37.0.1
cycler             0.11.0
Cython             0.29.30
fonttools          4.33.3
idna               3.3
importlib-metadata 4.11.4
iniconfig          1.1.1
kiwisolver         1.4.2
Markdown           3.3.7
matplotlib         3.5.2
mkl-fft            1.3.1
mkl-random         1.2.2
mkl-service        2.4.0
mmcls              0.23.0
mmcv-full          1.5.0
mmdet              2.24.1
mmrazor            0.3.1       /root/mmlab/mmrazor
model-index        0.1.11
numpy              1.21.5
opencv-python      4.5.5.64
openmim            0.1.5
ordered-set        4.1.0
packaging          21.3
pandas             1.3.5
Pillow             9.0.1
pip                21.2.2
pluggy             1.0.0
py                 1.11.0
pycocotools        2.0.4
pycparser          2.21
pyOpenSSL          22.0.0
pyparsing          3.0.9
PySocks            1.7.1
pytest             7.1.2
python-dateutil    2.8.2
pytz               2022.1
PyYAML             6.0
requests           2.27.1
setuptools         61.2.0
six                1.16.0
tabulate           0.8.9
terminaltables     3.1.10
tomli              2.0.1
torch              1.11.0
torchaudio         0.11.0
torchvision        0.12.0
typing_extensions  4.1.1
urllib3            1.26.9
wheel              0.37.1
yapf               0.32.0
zipp               3.8.0
HIT-cwh commented 2 years ago

Sorry for the inconvenience. The pruner's auto-trace fails with PyTorch 1.11.0, and we will fix the code as soon as possible. In the meantime, could you avoid PyTorch 1.8.1, 1.10.2, and 1.11.0? All other versions should be OK.
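The AttributeError in the traceback comes from following `next_functions[0][0]` onto a `None` entry. The sketch below reproduces that pattern with a stand-in graph and shows a guarded walk. It is not the mmrazor fix: `Node`, `first_parent_chain`, and the node names are hypothetical, and the only thing borrowed from PyTorch is the shape of `next_functions` (a tuple of `(node, index)` pairs whose node may be `None`).

```python
class Node:
    """Hypothetical stand-in for an autograd grad_fn node (not a PyTorch class)."""

    def __init__(self, name, next_functions=()):
        self.name = name
        # Same shape as grad_fn.next_functions: a tuple of (node, index) pairs.
        self.next_functions = tuple(next_functions)


def first_parent_chain(node):
    """Follow next_functions[0][0] repeatedly, stopping at None instead of raising."""
    names = []
    while node is not None:
        names.append(node.name)
        if not node.next_functions:
            break
        node = node.next_functions[0][0]
    return names


# Assumed layout: the conv's first input has no grad_fn, so its slot holds None.
weight = Node("weight_accumulate_grad")
conv = Node("conv_backward", [(None, 0), (weight, 0)])
bn = Node("bn_backward", [(conv, 0)])

print(first_parent_chain(bn))
```

An unguarded `conv.next_functions[0][0].next_functions` on this graph raises the same `'NoneType' object has no attribute 'next_functions'` error as in the report, while the guarded walk simply stops at the `None` slot.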

smallccn commented 2 years ago

Thank you for your help. After downgrading PyTorch to 1.10.1, everything works.
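For anyone hitting the same error, a small pre-flight check against the versions named in this thread can save a debugging session. This is a minimal sketch: `KNOWN_BAD`, `parse_release`, and `is_known_bad` are hypothetical names, and the version list is only what was stated above, not an official compatibility matrix.

```python
import re

# Versions the maintainer asked to avoid in this thread.
KNOWN_BAD = {(1, 8, 1), (1, 10, 2), (1, 11, 0)}


def parse_release(version):
    """Extract (major, minor, patch) from strings like '1.11.0' or '1.10.2+cu113'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def is_known_bad(version):
    """True if the version is one of those reported as problematic here."""
    return parse_release(version) in KNOWN_BAD


print(is_known_bad("1.11.0"))
print(is_known_bad("1.10.1"))
```

In practice the argument would be `torch.__version__`; the regex ignores any local build suffix such as `+cu113`, so only the release numbers are compared.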