microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression, and hyperparameter tuning.
https://nni.readthedocs.io
MIT License

Model speedup fails due to Attribute Error #5794

Open krteyu opened 3 months ago

krteyu commented 3 months ago

Describe the bug: I am trying to run this pruned mirror detection model [https://github.com/memgonzales/mirror-segmentation]. When I call `speedup_model()`, I get the following error:

  File "C:\Users\pillai.k\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\pillai.k\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\pillai.k\PMDLite\mirror-segmentation-main\prune.py", line 178, in <module>
    main()
  File "C:\Users\pillai.k\PMDLite\mirror-segmentation-main\prune.py", line 127, in main
    ModelSpeedup(net, dummy, masks).speedup_model()
  File "C:\Users\pillai.k\MirrorEnv\lib\site-packages\nni\compression\speedup\model_speedup.py", line 429, in speedup_model
    self.fix_mask_conflict()
  File "C:\Users\pillai.k\MirrorEnv\lib\site-packages\nni\compression\speedup\model_speedup.py", line 243, in fix_mask_conflict
    fix_channel_mask_conflict(self.graph_module, self.masks)
  File "C:\Users\pillai.k\MirrorEnv\lib\site-packages\nni\compression\speedup\mask_conflict.py", line 229, in fix_channel_mask_conflict
    prune_axis = detect_mask_prune_dim(graph_module, masks)
  File "C:\Users\pillai.k\MirrorEnv\lib\site-packages\nni\compression\speedup\mask_conflict.py", line 400, in detect_mask_prune_dim
    sub_module = graph_module.get_submodule(layer_name)
  File "C:\Users\pillai.k\MirrorEnv\lib\site-packages\torch\nn\modules\module.py", line 686, in get_submodule
    raise AttributeError(mod._get_name() + " has no "
AttributeError: BFE_Module has no attribute `cbam`

This error occurs for the mask element `edge_extract.cbam.ChannelGate.mlp.1`. The model description for this module is given below:

  (edge_extract): BFE_Module(
    (cbam): CBAM(
      (ChannelGate): ChannelGate(
        (mlp): Sequential(
          (0): Flatten()
          (1): Linear(in_features=56, out_features=3, bias=True)
          (2): ReLU(inplace=True)
          (3): Linear(in_features=3, out_features=56, bias=True)
        )
      )
      (SpatialGate): SpatialGate(
        (compress): ChannelPool()
        (spatial): BasicConv(
          (conv): Conv2d(1, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
          (bn): BatchNorm2d(1, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (edge_layer1): Sequential(
      (0): Conv2d(56, 28, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (edge_layer2): Sequential(
      (0): Conv2d(56, 28, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (edge_layer3): Sequential(
      (0): Conv2d(56, 28, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
      (1): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (edge_layer4): Sequential(
      (0): Conv2d(56, 28, kernel_size=(3, 3), stride=(1, 1), dilation=(4, 4))
      (1): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  ) 
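
The error names `BFE_Module` and `cbam` rather than the full mask key, which is consistent with a dotted-name lookup that resolves the path one attribute at a time and stops at the first missing segment. Below is a minimal pure-Python sketch of that walk. It uses stand-in objects instead of real `torch.nn.Module` instances: the `Node` class, the `resolve` helper, and the root name `Net` are hypothetical, and the traced view is assumed to lack `cbam`.

```python
class Node:
    """Stand-in for a module that holds named children (hypothetical)."""

    def __init__(self, type_name, **children):
        self.type_name = type_name
        self.children = children


def resolve(root, dotted_name):
    """Walk a dotted path one segment at a time, like a submodule lookup."""
    current = root
    for part in dotted_name.split("."):
        if part not in current.children:
            # Report the parent reached so far and the segment it lacks.
            raise AttributeError(f"{current.type_name} has no attribute `{part}`")
        current = current.children[part]
    return current


# Assumed traced view of the model: `cbam` is absent under `edge_extract`.
traced = Node("Net", edge_extract=Node("BFE_Module", edge_layer1=Node("Sequential")))

try:
    resolve(traced, "edge_extract.cbam.ChannelGate.mlp.1")
except AttributeError as err:
    message = str(err)
    print(message)
```

In this sketch the lookup reaches `edge_extract` and then fails on `cbam`, so the message mentions only the parent type and the first missing segment, not the `mlp.1` layer that the mask key actually refers to.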

Environment:

How to reproduce: set the path to the pruned weights file and run `prune.py`.

krteyu commented 3 months ago

This problem appears to be caused by `cbam` being instantiated in `edge_extract` but never used in its forward function. Commenting out `cbam` appears to fix the issue.
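
If that is the cause, the mask dictionary still has keys for a module that the traced graph does not contain. An alternative to commenting out `cbam` might be to drop those keys before speedup. Here is a minimal sketch of that filtering in plain Python. `split_masks` is a hypothetical helper, not part of NNI. It assumes the set of traced module names has been collected separately, for example from `named_modules()` on the traced module.

```python
def split_masks(masks, traced_names):
    """Separate mask entries whose module name exists in the traced graph
    from those that refer to modules the trace never saw."""
    kept, dropped = {}, {}
    for name, mask in masks.items():
        (kept if name in traced_names else dropped)[name] = mask
    return kept, dropped


# Illustrative data only: mask values are placeholders, not real tensors.
masks = {
    "edge_extract.cbam.ChannelGate.mlp.1": {"weight": "..."},
    "edge_extract.edge_layer1.0": {"weight": "..."},
}
traced_names = {
    "edge_extract",
    "edge_extract.edge_layer1",
    "edge_extract.edge_layer1.0",
}

kept, dropped = split_masks(masks, traced_names)
print(sorted(dropped))
```

Printing the dropped keys first helps confirm that only masks for unused modules are discarded.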

However, I am now getting an `AssertionError` from `assert len(set(num_channels_list)) == 1`, similar to #4160, which is also still open. I have not been able to find the model dependency that causes it; any help would be appreciated.
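
Read literally, the failing expression requires every entry of `num_channels_list` to be the same value. It therefore fails when the list mixes different channel counts, and also when it is empty. Below is a small sketch for narrowing down which group of layers disagrees. It assumes the layers can be grouped by hand into sets that must share a channel count, for example layers whose outputs are added together. The function, the layer names, and the counts are hypothetical and are not taken from NNI or from this model.

```python
def inconsistent_groups(groups, num_channels):
    """Return the groups whose members do not all share one channel count."""
    bad = []
    for group in groups:
        counts = {
            name: num_channels[name]
            for name in sorted(group)
            if name in num_channels
        }
        # Mirrors the asserted condition: exactly one distinct channel count.
        if len(set(counts.values())) != 1:
            bad.append(counts)
    return bad


# Hypothetical groups and channel counts, for illustration only.
groups = [{"branch_a.conv", "branch_b.conv"}, {"head.conv1", "head.conv2"}]
num_channels = {
    "branch_a.conv": 28,
    "branch_b.conv": 56,
    "head.conv1": 28,
    "head.conv2": 28,
}

for counts in inconsistent_groups(groups, num_channels):
    print(counts)
```

A group with no known channel counts is also reported, as an empty dict, matching the empty-list case of the assertion.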