microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

ModelSpeedup error: assert len(set(num_channels_list)) == 1, possible incorrect layers in dependency set #5736

Open saravanabalagi opened 5 months ago

saravanabalagi commented 5 months ago

ModelSpeedup fails to modify a model that has three successive conv blocks.

Environment:

Reproduce the problem

Minimal Code:

```python
import torch
import torch.nn as nn
from nni.compression.pruning import L1NormPruner
from nni.compression.utils import auto_set_denpendency_group_ids
from nni.compression.speedup import ModelSpeedup


# Three successive conv blocks; the last conv has a single output channel.
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 40, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(40)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(40, 80, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(80)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(80, 1, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x


model = ConvNet()
num_params_unpruned = sum(p.numel() for p in model.parameters())
dummy_input = torch.randn(1, 3, 32, 32)
dummy_output = model(dummy_input)
print(dummy_output.shape)

# Prune all Conv2d layers at 50% sparsity, then speed up the model.
sparsity_ratio = 0.5
config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': sparsity_ratio,
}]
config_list = auto_set_denpendency_group_ids(model, config_list, [dummy_input])

pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner.unwrap_model()
model = ModelSpeedup(model, [dummy_input], masks, garbage_collect_values=False).speedup_model()

# Compare parameter counts before and after pruning.
num_params_pruned = sum(p.numel() for p in model.parameters())
print(f'Number of parameters before pruning: {num_params_unpruned}')
print(f'Number of parameters after pruning: {num_params_pruned}')
num_params_diff = num_params_unpruned - num_params_pruned
prune_ratio = num_params_diff / num_params_unpruned
print(f'Number of parameters pruned: {num_params_diff}')
print(f'Parameter ratio: {(1-prune_ratio)*100:.2f}%')
```

Error:

`AssertionError` raised by `fix_channel_mask_conflict`: the number of channels in the same dependency set should be identical.

Error Trace:

```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[108], line 1
----> 1 model = ModelSpeedup(model, [dummy_input], masks, garbage_collect_values=False).speedup_model()

File /usr/local/lib/python3.8/dist-packages/nni/compression/speedup/model_speedup.py:429, in ModelSpeedup.speedup_model(self)
    427 self.logger.info('Resolve the mask conflict before mask propagate...')
    428 # fix_mask_conflict(self.masks, self.graph_module, self.dummy_input)
--> 429 self.fix_mask_conflict()
    430 self.logger.info('Infer module masks...')
    431 self.initialize_propagate(self.dummy_input)

File /usr/local/lib/python3.8/dist-packages/nni/compression/speedup/model_speedup.py:243, in ModelSpeedup.fix_mask_conflict(self)
    241 def fix_mask_conflict(self):
    242     fix_group_mask_conflict(self.graph_module, self.masks)
--> 243     fix_channel_mask_conflict(self.graph_module, self.masks)
    244     fix_weight_sharing_mask_conflict(self.graph_module, self.masks)

File /usr/local/lib/python3.8/dist-packages/nni/compression/speedup/mask_conflict.py:296, in fix_channel_mask_conflict(graph_module, masks)
    294 num_channels_list = [len(x) for x in channel_masks if x is not None]
    295 # number of channels in same set should be identical
--> 296 assert len(set(num_channels_list)) == 1
    297 num_channels = num_channels_list[0]
    299 for i, dim_mask in enumerate(channel_masks):

AssertionError:
```
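For triage, a quick way to see which layers end up with mismatched channel counts is to dump the mask shapes before calling `ModelSpeedup`. A minimal sketch, assuming `masks` maps module names to per-target tensors as returned by `pruner.compress()` above:

```python
# Sketch: print the shape of every pruning mask so the dependency set with
# inconsistent channel counts becomes visible.
# Assumes masks has the form {module_name: {target_name: Tensor}}.
for module_name, target_masks in masks.items():
    for target_name, mask in target_masks.items():
        print(f'{module_name}.{target_name}: {tuple(mask.shape)}')
```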

The same code works fine without `self.conv3` and `self.bn3`.
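A possible workaround, if pruning the 1-channel head is not required, is to scope the config to the first two convs so `conv3`/`bn3` never join a pruning dependency set. A minimal sketch; the `op_names` values assume the attribute names of the `ConvNet` above:

```python
# Sketch: prune only conv1 and conv2; the 1-channel conv3/bn3 pair is left
# untouched, which avoids the mixed-channel-count dependency set.
config_list = [{
    'op_names': ['conv1', 'conv2'],
    'sparse_ratio': 0.5,
}]
```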

saravanabalagi commented 5 months ago

The error is thrown specifically when the last layer has 1 output channel, even with only 2 successive conv blocks:

```python
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 1, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x
```
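To confirm that the single output channel is the trigger, the same pipeline can be swept over the last conv's output width. A sketch reusing only the calls from the repro above; the expected result (an assumption, not verified here) is that only `last_out_channels=1` raises the `AssertionError`:

```python
import torch
import torch.nn as nn
from nni.compression.pruning import L1NormPruner
from nni.compression.utils import auto_set_denpendency_group_ids
from nni.compression.speedup import ModelSpeedup

def try_prune(last_out_channels: int) -> bool:
    """Run prune + speedup on a 2-block conv net; True if speedup succeeds."""
    model = nn.Sequential(
        nn.Conv2d(3, 6, kernel_size=3, padding=1),
        nn.BatchNorm2d(6),
        nn.ReLU(inplace=True),
        nn.Conv2d(6, last_out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(last_out_channels),
    )
    dummy_input = torch.randn(1, 3, 32, 32)
    config_list = [{'op_types': ['Conv2d'], 'sparse_ratio': 0.5}]
    config_list = auto_set_denpendency_group_ids(model, config_list, [dummy_input])
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()
    pruner.unwrap_model()
    try:
        ModelSpeedup(model, [dummy_input], masks).speedup_model()
        return True
    except AssertionError:
        return False

# Expectation (assumed): False for 1, True for wider heads.
for channels in (1, 2, 4):
    print(f'last_out_channels={channels}: speedup ok = {try_prune(channels)}')
```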