microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

Bug: batch_size parameter in ModelSpeedup does not alter the model when set to 1 #5735

Open · saravanabalagi opened this issue 10 months ago

saravanabalagi commented 10 months ago

The batch_size parameter in ModelSpeedup does not prune or alter the model when it is set to 1.

Environment:

Reproduce the problem

Minimal Code:

```python
# %%
import torch
import torch.nn as nn

from nni.compression.pruning import L1NormPruner
from nni.compression.utils import auto_set_denpendency_group_ids
from nni.compression.speedup import ModelSpeedup

# %%
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 40, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(40)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(40, 80, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(80)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        return x

model = ConvNet()
num_params_unpruned = sum(p.numel() for p in model.parameters())
dummy_input = torch.randn(1, 3, 32, 32)
dummy_output = model(dummy_input)
print(dummy_output.shape)

# %%
sparsity_ratio = 0.5
config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': sparsity_ratio,
}]
config_list = auto_set_denpendency_group_ids(model, config_list, [dummy_input])

pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner.unwrap_model()

model = ModelSpeedup(model, [dummy_input], masks, batch_size=1, garbage_collect_values=False).speedup_model()

# %%
num_params_pruned = sum(p.numel() for p in model.parameters())
print(f'Number of parameters before pruning: {num_params_unpruned}')
print(f'Number of parameters after pruning: {num_params_pruned}')
num_params_diff = num_params_unpruned - num_params_pruned
prune_ratio = num_params_diff / num_params_unpruned
print(f'Number of parameters pruned: {num_params_diff}')
print(f'Parameter ratio: {(1-prune_ratio)*100:.2f}%')
```

Output:

Regardless of the sparsity ratio, the model is not altered:

```
Number of parameters before pruning: 30240
Number of parameters after pruning: 30240
Number of parameters pruned: 0
Parameter ratio: 100.00%
```
Full Log:

```
[2024-01-17 16:55:43] Start to speedup the model...
[2024-01-17 16:55:43] Resolve the mask conflict before mask propagate...
[2024-01-17 16:55:43] dim0 sparsity: 0.500000
[2024-01-17 16:55:43] dim1 sparsity: 0.000000
0 Filter
[2024-01-17 16:55:43] dim0 sparsity: 0.500000
[2024-01-17 16:55:43] dim1 sparsity: 0.000000
[2024-01-17 16:55:43] Infer module masks...
[2024-01-17 16:55:43] Propagate original variables
[2024-01-17 16:55:43] Propagate variables for placeholder: x, output mask: 0.0000
[2024-01-17 16:55:43] Propagate variables for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: bn1, , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: relu1, , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: bn2, , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: relu2, , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for output: output, output mask: 0.0000
[2024-01-17 16:55:44] Update direct sparsity...
[2024-01-17 16:55:45] Update direct mask for placeholder: x, output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: bn1, , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: relu1, , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: bn2, , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: relu2, , output mask: 0.0000
[2024-01-17 16:55:46] Update direct mask for output: output, output mask: 0.0000
[2024-01-17 16:55:46] Update indirect sparsity...
[2024-01-17 16:55:46] Update indirect mask for output: output, output mask: 0.0000
[2024-01-17 16:55:46] Update indirect mask for call_module: relu2, , output mask: 0.0000
[2024-01-17 16:55:46] Update indirect mask for call_module: bn2, , output mask: 0.0000
[2024-01-17 16:55:46] Update indirect mask for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:47] Update indirect mask for call_module: relu1, , output mask: 0.0000
[2024-01-17 16:55:47] Update indirect mask for call_module: bn1, , output mask: 0.0000
[2024-01-17 16:55:47] Update indirect mask for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:47] Update indirect mask for placeholder: x, output mask: 0.0000
[2024-01-17 16:55:47] Resolve the mask conflict after mask propagate...
[2024-01-17 16:55:47] dim0 sparsity: 0.500000
[2024-01-17 16:55:47] dim1 sparsity: 0.000000
0 Filter
[2024-01-17 16:55:47] dim0 sparsity: 0.500000
[2024-01-17 16:55:47] dim1 sparsity: 0.000000
[2024-01-17 16:55:47] Replace compressed modules...
[2024-01-17 16:55:47] replace module (name: conv1, op_type: Conv2d)
[2024-01-17 16:55:47] replace conv2d with in_channels: 3, out_channels: 40
[2024-01-17 16:55:47] replace module (name: bn1, op_type: BatchNorm2d)
[2024-01-17 16:55:47] replace batchnorm2d with num_features: 40
[2024-01-17 16:55:47] replace module (name: relu1, op_type: ReLU)
[2024-01-17 16:55:47] replace module (name: conv2, op_type: Conv2d)
[2024-01-17 16:55:47] replace conv2d with in_channels: 40, out_channels: 80
[2024-01-17 16:55:47] replace module (name: bn2, op_type: BatchNorm2d)
[2024-01-17 16:55:47] replace batchnorm2d with num_features: 80
[2024-01-17 16:55:47] replace module (name: relu2, op_type: ReLU)
[2024-01-17 16:55:47] Speedup done.
```

Expected Output:

If any batch_size above 1 is specified, or if the argument is not specified at all, the model is altered as expected:

```
Number of parameters before pruning: 30240
Number of parameters after pruning: 7920
Number of parameters pruned: 22320
Parameter ratio: 26.19%
```
Full Log:

```
[2024-01-17 17:07:12] Start to speedup the model...
[2024-01-17 17:07:12] Resolve the mask conflict before mask propagate...
[2024-01-17 17:07:12] dim0 sparsity: 0.500000
[2024-01-17 17:07:12] dim1 sparsity: 0.000000
0 Filter
[2024-01-17 17:07:12] dim0 sparsity: 0.500000
[2024-01-17 17:07:12] dim1 sparsity: 0.000000
[2024-01-17 17:07:12] Infer module masks...
[2024-01-17 17:07:12] Propagate original variables
[2024-01-17 17:07:12] Propagate variables for placeholder: x, output mask: 0.0000
[2024-01-17 17:07:12] Propagate variables for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 17:07:12] Propagate variables for call_module: bn1, , output mask: 0.0000
[2024-01-17 17:07:12] Propagate variables for call_module: relu1, , output mask: 0.0000
[2024-01-17 17:07:12] Propagate variables for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 17:07:13] Propagate variables for call_module: bn2, , output mask: 0.0000
[2024-01-17 17:07:13] Propagate variables for call_module: relu2, , output mask: 0.0000
[2024-01-17 17:07:13] Propagate variables for output: output, output mask: 0.0000
[2024-01-17 17:07:13] Update direct sparsity...
[2024-01-17 17:07:13] Update direct mask for placeholder: x, output mask: 0.0000
[2024-01-17 17:07:13] Update direct mask for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.5000
[2024-01-17 17:07:13] Update direct mask for call_module: bn1, , output mask: 0.5000
[2024-01-17 17:07:13] Update direct mask for call_module: relu1, , output mask: 0.5000
[2024-01-17 17:07:14] Update direct mask for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.5000
[2024-01-17 17:07:14] Update direct mask for call_module: bn2, , output mask: 0.5000
[2024-01-17 17:07:14] Update direct mask for call_module: relu2, , output mask: 0.5000
[2024-01-17 17:07:14] Update direct mask for output: output, output mask: 0.5000
[2024-01-17 17:07:14] Update indirect sparsity...
[2024-01-17 17:07:14] Update indirect mask for output: output, output mask: 0.5000
[2024-01-17 17:07:14] Update indirect mask for call_module: relu2, , output mask: 0.5000
[2024-01-17 17:07:15] Update indirect mask for call_module: bn2, , output mask: 0.5000
[2024-01-17 17:07:15] Update indirect mask for call_module: conv2, weight: 0.7500 bias: 0.5000 , output mask: 0.5000
[2024-01-17 17:07:15] Update indirect mask for call_module: relu1, , output mask: 0.5000
[2024-01-17 17:07:16] Update indirect mask for call_module: bn1, , output mask: 0.5000
[2024-01-17 17:07:16] Update indirect mask for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.5000
[2024-01-17 17:07:16] Update indirect mask for placeholder: x, output mask: 0.0000
[2024-01-17 17:07:16] Resolve the mask conflict after mask propagate...
[2024-01-17 17:07:16] dim0 sparsity: 0.500000
[2024-01-17 17:07:16] dim1 sparsity: 0.465116
[2024-01-17 17:07:16] WARNING: both dim0 and dim1 masks found.
0 Filter
[2024-01-17 17:07:16] dim0 sparsity: 0.500000
[2024-01-17 17:07:16] dim1 sparsity: 0.465116
[2024-01-17 17:07:16] WARNING: both dim0 and dim1 masks found.
[2024-01-17 17:07:16] Replace compressed modules...
[2024-01-17 17:07:16] replace module (name: conv1, op_type: Conv2d)
[2024-01-17 17:07:16] replace conv2d with in_channels: 3, out_channels: 20
[2024-01-17 17:07:16] replace module (name: bn1, op_type: BatchNorm2d)
[2024-01-17 17:07:16] replace batchnorm2d with num_features: 20
[2024-01-17 17:07:16] replace module (name: relu1, op_type: ReLU)
[2024-01-17 17:07:16] replace module (name: conv2, op_type: Conv2d)
[2024-01-17 17:07:16] replace conv2d with in_channels: 20, out_channels: 40
[2024-01-17 17:07:16] replace module (name: bn2, op_type: BatchNorm2d)
[2024-01-17 17:07:16] replace batchnorm2d with num_features: 40
[2024-01-17 17:07:16] replace module (name: relu2, op_type: ReLU)
[2024-01-17 17:07:16] Speedup done.
```
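For reference, the expected behaviour above comes from the same reproduction script with only the ModelSpeedup call changed; a minimal sketch reusing the model, dummy_input and masks variables defined in the script:

```python
# Either omit batch_size entirely or set it to 2 or more; both produce the
# pruned model shown above (conv1 -> 20 out_channels, conv2 -> 40 out_channels).
model = ModelSpeedup(model, [dummy_input], masks, garbage_collect_values=False).speedup_model()
# or:
# model = ModelSpeedup(model, [dummy_input], masks, batch_size=2, garbage_collect_values=False).speedup_model()
```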
saravanabalagi commented 8 months ago

The direct sparsity update during speedup includes a step that computes the standard deviation of each module's output along the batch dimension, i.e., it measures how much each output position varies across the samples in the batch, and every position whose standard deviation falls below a predefined epsilon (1e-6) is then treated as masked. This process therefore requires a batch size of at least 2 to have anything to compare.

torch.std returns NaN if only one sample is given (it is impossible to compute a standard deviation from a single sample).
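This is easy to confirm with plain PyTorch; the snippet below is a standalone illustration of the std/epsilon check described above, not NNI's actual implementation:

```python
import torch

# Simulated layer output with batch size 1: std along the batch dimension is NaN,
# because the (unbiased) standard deviation of a single sample is undefined.
out_b1 = torch.randn(1, 40, 32, 32)
print(out_b1.std(dim=0).isnan().all())  # tensor(True)

# With a batch size of at least 2 the std is well defined, so a threshold like
# the epsilon (1e-6) mentioned above can actually select positions.
out_b2 = torch.randn(2, 40, 32, 32)
std = out_b2.std(dim=0)
print((std < 1e-6).sum())  # tensor(0) for random data; nothing gets masked
```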

saravanabalagi commented 8 months ago

While it makes sense for speedup to mask weights corresponding to output positions whose standard deviation is close to 0, this behaviour is not mentioned or explained in the docs. It can also lead to parameters/channels being removed even when absolutely no masks are given as input to speedup to begin with.
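A toy example of that second point (purely illustrative, using the same kind of std-based check, not NNI code): a channel whose output happens to be constant across the batch would be flagged as prunable even though no mask was ever supplied for it.

```python
import torch

# Pretend this is a layer's output for a batch of 2; channel 0 is constant
# across the batch (e.g. a saturated activation), the other channels are random.
out = torch.randn(2, 4, 8, 8)
out[:, 0] = 1.0

std = out.std(dim=0)                           # std over the batch dimension, shape (4, 8, 8)
prunable = (std < 1e-6).flatten(1).all(dim=1)  # channels that are constant at every position
print(prunable)  # tensor([ True, False, False, False])
```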