microsoft / nni

An open-source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression, and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License
14.02k stars 1.81k forks source link

assert len(set(num_channels_list)) == 1 #4583

Open HappyPeanuts opened 2 years ago

HappyPeanuts commented 2 years ago

Describe the issue: When I try to speed up my model, I run into the following problem:

```
/home/user/.conda/envs/yang_pytorch/lib/python3.7/site-packages/nni/compression/pytorch/utils/mask_conflict.py:124: UserWarning: This overload of nonzero is deprecated:
nonzero()
Consider using one of the following signatures instead:
nonzero(*, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.)
  all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
[2022-02-24 15:16:57] INFO (FixMaskConflict/MainThread) dim0 sparsity: 0.395664
[2022-02-24 15:16:57] INFO (FixMaskConflict/MainThread) dim1 sparsity: 0.000000
[2022-02-24 15:16:57] INFO (FixMaskConflict/MainThread) Dectected conv prune dim" 0
Traceback (most recent call last):
  File "pruning.py", line 117, in
    pruning.function()
  File "pruning.py", line 74, in function
    m_speedup.speedup_model()
  File "/home/user/.conda/envs/yang_pytorch/lib/python3.7/site-packages/nni/compression/pytorch/speedup/compressor.py", line 506, in speedup_model
    fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
  File "/home/user/.conda/envs/yang_pytorch/lib/python3.7/site-packages/nni/compression/pytorch/utils/mask_conflict.py", line 54, in fix_mask_conflict
    masks = fix_channel_mask.fix_mask()
  File "/home/user/.conda/envs/yang_pytorch/lib/python3.7/site-packages/nni/compression/pytorch/utils/mask_conflict.py", line 263, in fix_mask
    assert len(set(num_channels_list)) == 1
AssertionError
```

So, I would like to know how to resolve this. Thanks!

Environment:

Log message: [2022-02-24 15:16:55] INFO (nni.compression.pytorch.speedup.compressor/MainThread) start to speed up the model [2022-02-24 15:16:57] INFO (FixMaskConflict/MainThread) {'base.base_layer.0': 1, 'base.level0.0': 1, 'base.level1.0': 1, 'base.level2.project.0': 1, 'base.level2.tree1.squeeze_conv1x1_1': 1, 'base.level2.tree1.expand_conv1x1_1': 1, 'base.level2.tree1.expand_conv3x3_1': 1, 'base.level2.tree1.squeeze_conv1x1_2': 1, 'base.level2.tree1.expand_conv1x1_2': 1, 'base.level2.tree1.expand_conv3x3_2': 1, 'base.level2.tree2.squeeze_conv1x1_1': 1, 'base.level2.tree2.expand_conv1x1_1': 1, 'base.level2.tree2.expand_conv3x3_1': 1, 'base.level2.tree2.squeeze_conv1x1_2': 1, 'base.level2.tree2.expand_conv1x1_2': 1, 'base.level2.tree2.expand_conv3x3_2': 1, 'base.level2.root.conv': 1, 'base.level3.tree1.project.0': 1, 'base.level3.tree1.tree1.squeeze_conv1x1_1': 1, 'base.level3.tree1.tree1.expand_conv1x1_1': 1, 'base.level3.tree1.tree1.expand_conv3x3_1': 1, 'base.level3.tree1.tree1.squeeze_conv1x1_2': 1, 'base.level3.tree1.tree1.expand_conv1x1_2': 1, 'base.level3.tree1.tree1.expand_conv3x3_2': 1, 'base.level3.tree1.tree2.squeeze_conv1x1_1': 1, 'base.level3.tree1.tree2.expand_conv1x1_1': 1, 'base.level3.tree1.tree2.expand_conv3x3_1': 1, 'base.level3.tree1.tree2.squeeze_conv1x1_2': 1, 'base.level3.tree1.tree2.expand_conv1x1_2': 1, 'base.level3.tree1.tree2.expand_conv3x3_2': 1, 'base.level3.tree1.root.conv': 1, 'base.level3.tree2.tree1.squeeze_conv1x1_1': 1, 'base.level3.tree2.tree1.expand_conv1x1_1': 1, 'base.level3.tree2.tree1.expand_conv3x3_1': 1, 'base.level3.tree2.tree1.squeeze_conv1x1_2': 1, 'base.level3.tree2.tree1.expand_conv1x1_2': 1, 'base.level3.tree2.tree1.expand_conv3x3_2': 1, 'base.level3.tree2.tree2.squeeze_conv1x1_1': 1, 'base.level3.tree2.tree2.expand_conv1x1_1': 1, 'base.level3.tree2.tree2.expand_conv3x3_1': 1, 'base.level3.tree2.tree2.squeeze_conv1x1_2': 1, 'base.level3.tree2.tree2.expand_conv1x1_2': 1, 'base.level3.tree2.tree2.expand_conv3x3_2': 
1, 'base.level3.tree2.root.conv': 1, 'base.level4.tree1.project.0': 1, 'base.level4.tree1.tree1.squeeze_conv1x1_1': 1, 'base.level4.tree1.tree1.expand_conv1x1_1': 1, 'base.level4.tree1.tree1.expand_conv3x3_1': 1, 'base.level4.tree1.tree1.squeeze_conv1x1_2': 1, 'base.level4.tree1.tree1.expand_conv1x1_2': 1, 'base.level4.tree1.tree1.expand_conv3x3_2': 1, 'base.level4.tree1.tree2.squeeze_conv1x1_1': 1, 'base.level4.tree1.tree2.expand_conv1x1_1': 1, 'base.level4.tree1.tree2.expand_conv3x3_1': 1, 'base.level4.tree1.tree2.squeeze_conv1x1_2': 1, 'base.level4.tree1.tree2.expand_conv1x1_2': 1, 'base.level4.tree1.tree2.expand_conv3x3_2': 1, 'base.level4.tree1.root.conv': 1, 'base.level4.tree2.tree1.squeeze_conv1x1_1': 1, 'base.level4.tree2.tree1.expand_conv1x1_1': 1, 'base.level4.tree2.tree1.expand_conv3x3_1': 1, 'base.level4.tree2.tree1.squeeze_conv1x1_2': 1, 'base.level4.tree2.tree1.expand_conv1x1_2': 1, 'base.level4.tree2.tree1.expand_conv3x3_2': 1, 'base.level4.tree2.tree2.squeeze_conv1x1_1': 1, 'base.level4.tree2.tree2.expand_conv1x1_1': 1, 'base.level4.tree2.tree2.expand_conv3x3_1': 1, 'base.level4.tree2.tree2.squeeze_conv1x1_2': 1, 'base.level4.tree2.tree2.expand_conv1x1_2': 1, 'base.level4.tree2.tree2.expand_conv3x3_2': 1, 'base.level4.tree2.root.conv': 1, 'base.self_attention.query_conv': 1, 'base.self_attention.key_conv': 1, 'base.self_attention.value_conv': 1, 'base.level5.project.0': 1, 'base.level5.tree1.squeeze_conv1x1_1': 1, 'base.level5.tree1.expand_conv1x1_1': 1, 'base.level5.tree1.expand_conv3x3_1': 1, 'base.level5.tree1.squeeze_conv1x1_2': 1, 'base.level5.tree1.expand_conv1x1_2': 1, 'base.level5.tree1.expand_conv3x3_2': 1, 'base.level5.tree2.squeeze_conv1x1_1': 1, 'base.level5.tree2.expand_conv1x1_1': 1, 'base.level5.tree2.expand_conv3x3_1': 1, 'base.level5.tree2.squeeze_conv1x1_2': 1, 'base.level5.tree2.expand_conv1x1_2': 1, 'base.level5.tree2.expand_conv3x3_2': 1, 'base.level5.root.conv': 1, 'dla_up.ida_0.proj_1.conv.conv_offset_mask': 1, 
'dla_up.ida_0.up_1': 1, 'dla_up.ida_0.node_1.conv.conv_offset_mask': 1, 'dla_up.ida_1.proj_1.conv.conv_offset_mask': 1, 'dla_up.ida_1.up_1': 1, 'dla_up.ida_1.node_1.conv.conv_offset_mask': 1, 'dla_up.ida_1.proj_2.conv.conv_offset_mask': 1, 'dla_up.ida_1.up_2': 1, 'dla_up.ida_1.node_2.conv.conv_offset_mask': 1, 'dla_up.ida_2.proj_1.conv.conv_offset_mask': 1, 'dla_up.ida_2.up_1': 1, 'dla_up.ida_2.node_1.conv.conv_offset_mask': 1, 'dla_up.ida_2.proj_2.conv.conv_offset_mask': 1, 'dla_up.ida_2.up_2': 1, 'dla_up.ida_2.node_2.conv.conv_offset_mask': 1, 'dla_up.ida_2.proj_3.conv.conv_offset_mask': 1, 'dla_up.ida_2.up_3': 1, 'dla_up.ida_2.node_3.conv.conv_offset_mask': 1, 'ida_up.proj_1.conv.conv_offset_mask': 1, 'ida_up.up_1': 1, 'ida_up.node_1.conv.conv_offset_mask': 1, 'ida_up.proj_2.conv.conv_offset_mask': 1, 'ida_up.up_2': 1, 'ida_up.node_2.conv.conv_offset_mask': 1, 'hm.0': 1, 'hm.2': 1, 'hm.4': 1, 'wh.0': 1, 'wh.2': 1, 'wh.4': 1, 'reg.0': 1, 'reg.2': 1, 'reg.4': 1, 'id.0': 1, 'id.2': 1} /home/user/.conda/envs/yang_pytorch/lib/python3.7/site-packages/nni/compression/pytorch/utils/mask_conflict.py:124: UserWarning: This overload of nonzero is deprecated: nonzero() Consider using one of the following signatures instead: nonzero(*, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.) 
all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist() [2022-02-24 15:16:57] INFO (FixMaskConflict/MainThread) dim0 sparsity: 0.395664 [2022-02-24 15:16:57] INFO (FixMaskConflict/MainThread) dim1 sparsity: 0.000000 [2022-02-24 15:16:57] INFO (FixMaskConflict/MainThread) Dectected conv prune dim" 0 Traceback (most recent call last): File "pruning.py", line 117, in pruning.function() File "pruning.py", line 74, in function m_speedup.speedup_model() File "/home/user/.conda/envs/yang_pytorch/lib/python3.7/site-packages/nni/compression/pytorch/speedup/compressor.py", line 506, in speedup_model fix_mask_conflict(self.masks, self.bound_model, self.dummy_input) File "/home/user/.conda/envs/yang_pytorch/lib/python3.7/site-packages/nni/compression/pytorch/utils/mask_conflict.py", line 54, in fix_mask_conflict masks = fix_channel_mask.fix_mask() File "/home/user/.conda/envs/yang_pytorch/lib/python3.7/site-packages/nni/compression/pytorch/utils/mask_conflict.py", line 263, in fix_mask assert len(set(num_channels_list)) == 1 AssertionError

How to reproduce it?:

Lijiaoa commented 2 years ago

Related issue: #4160