saravanabalagi opened this issue 10 months ago
The direct sparsity update during speedup includes a step that calculates the standard deviation of each module's output along the batch dimension, i.e., it measures how much each output position varies across the samples in a batch, and every position whose std falls below a predefined epsilon (1e-6) is used to fill mask positions. This process therefore requires a batch size of at least 2 to have anything to compare.
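For illustration, here is a minimal sketch of that idea in PyTorch (my own paraphrase of the description above, not NNI's actual implementation; `infer_output_mask` and `EPSILON` are hypothetical names):

```python
import torch

EPSILON = 1e-6  # the predefined threshold mentioned above

def infer_output_mask(batched_output: torch.Tensor) -> torch.Tensor:
    # Standard deviation over the batch dimension. With a single sample this
    # is NaN, and comparisons against NaN are always False, which breaks the
    # masking logic entirely.
    std = batched_output.std(dim=0)
    return (std > EPSILON).float()  # 0 where the output never varies

outputs = torch.randn(4, 8)        # batch of 4 samples, 8 output positions
outputs[:, 0] = 0.0                # position 0 is constant across the batch
print(infer_output_mask(outputs))  # position 0 gets mask value 0
```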
`torch.std` returns NaN if only one sample is given, since the (unbiased) standard deviation is undefined for a single observation.
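This is easy to confirm directly:

```python
import torch

x = torch.randn(1, 5)                # a batch with a single sample
print(torch.std(x, dim=0))           # all NaN: unbiased std needs n >= 2
print(torch.randn(2, 5).std(dim=0))  # finite values once the batch has 2 samples
```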
While it makes sense for speedup to mask weights corresponding to output positions whose std is close to 0, this behaviour is not mentioned or explained in the docs. It can also remove params/channels even when absolutely no masks are given as input to speedup to begin with (presumably because some positions, e.g. ones a ReLU zeroes out for every sample in the batch, can have near-zero std anyway).
The `batch_size` parameter in `ModelSpeedup` does not prune or alter the model when set to 1.

Environment:
Reproduce the problem

Set any `sparsity_ratio`, run `L1NormPruner`, then run `ModelSpeedup` with the `batch_size` parameter set to 1.

Minimal Code
```python
# %%
import torch
import torch.nn as nn
from nni.compression.pruning import L1NormPruner
from nni.compression.utils import auto_set_denpendency_group_ids
from nni.compression.speedup import ModelSpeedup

# %%
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 40, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(40)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(40, 80, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(80)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        return x

model = ConvNet()
num_params_unpruned = sum(p.numel() for p in model.parameters())
dummy_input = torch.randn(1, 3, 32, 32)
dummy_output = model(dummy_input)
print(dummy_output.shape)

# %%
sparsity_ratio = 0.5
config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': sparsity_ratio,
}]
config_list = auto_set_denpendency_group_ids(model, config_list, [dummy_input])
pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner.unwrap_model()
model = ModelSpeedup(model, [dummy_input], masks, batch_size=1,
                     garbage_collect_values=False).speedup_model()

# %%
num_params_pruned = sum(p.numel() for p in model.parameters())
print(f'Number of parameters before pruning: {num_params_unpruned}')
print(f'Number of parameters after pruning: {num_params_pruned}')
num_params_diff = num_params_unpruned - num_params_pruned
prune_ratio = num_params_diff / num_params_unpruned
print(f'Number of parameters pruned: {num_params_diff}')
print(f'Parameter ratio: {(1-prune_ratio)*100:.2f}%')
```

Output:
Regardless of the sparsity ratio, the model is not altered.
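For reference, the unpruned ConvNet has 30,240 parameters (worked out by hand from the model definition: conv1 = 3·40·3·3 + 40 = 1,120; bn1 = 40 + 40 = 80; conv2 = 40·80·3·3 + 80 = 28,880; bn2 = 80 + 80 = 160), so with `batch_size=1` the final prints should look like:

```
Number of parameters before pruning: 30240
Number of parameters after pruning: 30240
Number of parameters pruned: 0
Parameter ratio: 100.00%
```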
Full Log
```
[2024-01-17 16:55:43] Start to speedup the model...
[2024-01-17 16:55:43] Resolve the mask conflict before mask propagate...
[2024-01-17 16:55:43] dim0 sparsity: 0.500000
[2024-01-17 16:55:43] dim1 sparsity: 0.000000
0 Filter
[2024-01-17 16:55:43] dim0 sparsity: 0.500000
[2024-01-17 16:55:43] dim1 sparsity: 0.000000
[2024-01-17 16:55:43] Infer module masks...
[2024-01-17 16:55:43] Propagate original variables
[2024-01-17 16:55:43] Propagate variables for placeholder: x, output mask: 0.0000
[2024-01-17 16:55:43] Propagate variables for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: bn1, , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: relu1, , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: bn2, , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for call_module: relu2, , output mask: 0.0000
[2024-01-17 16:55:44] Propagate variables for output: output, output mask: 0.0000
[2024-01-17 16:55:44] Update direct sparsity...
[2024-01-17 16:55:45] Update direct mask for placeholder: x, output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: bn1, , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: relu1, , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: bn2, , output mask: 0.0000
[2024-01-17 16:55:45] Update direct mask for call_module: relu2, , output mask: 0.0000
[2024-01-17 16:55:46] Update direct mask for output: output, output mask: 0.0000
[2024-01-17 16:55:46] Update indirect sparsity...
[2024-01-17 16:55:46] Update indirect mask for output: output, output mask: 0.0000
[2024-01-17 16:55:46] Update indirect mask for call_module: relu2, , output mask: 0.0000
[2024-01-17 16:55:46] Update indirect mask for call_module: bn2, , output mask: 0.0000
[2024-01-17 16:55:46] Update indirect mask for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:47] Update indirect mask for call_module: relu1, , output mask: 0.0000
[2024-01-17 16:55:47] Update indirect mask for call_module: bn1, , output mask: 0.0000
[2024-01-17 16:55:47] Update indirect mask for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 16:55:47] Update indirect mask for placeholder: x, output mask: 0.0000
[2024-01-17 16:55:47] Resolve the mask conflict after mask propagate...
[2024-01-17 16:55:47] dim0 sparsity: 0.500000
[2024-01-17 16:55:47] dim1 sparsity: 0.000000
0 Filter
[2024-01-17 16:55:47] dim0 sparsity: 0.500000
[2024-01-17 16:55:47] dim1 sparsity: 0.000000
[2024-01-17 16:55:47] Replace compressed modules...
[2024-01-17 16:55:47] replace module (name: conv1, op_type: Conv2d)
[2024-01-17 16:55:47] replace conv2d with in_channels: 3, out_channels: 40
[2024-01-17 16:55:47] replace module (name: bn1, op_type: BatchNorm2d)
[2024-01-17 16:55:47] replace batchnorm2d with num_features: 40
[2024-01-17 16:55:47] replace module (name: relu1, op_type: ReLU)
[2024-01-17 16:55:47] replace module (name: conv2, op_type: Conv2d)
[2024-01-17 16:55:47] replace conv2d with in_channels: 40, out_channels: 80
[2024-01-17 16:55:47] replace module (name: bn2, op_type: BatchNorm2d)
[2024-01-17 16:55:47] replace batchnorm2d with num_features: 80
[2024-01-17 16:55:47] replace module (name: relu2, op_type: ReLU)
[2024-01-17 16:55:47] Speedup done.
```

Expected Output:
If any `batch_size` above 1 is specified, or if the argument is left unspecified, the model is altered as expected.

Full Log
```
[2024-01-17 17:07:12] Start to speedup the model...
[2024-01-17 17:07:12] Resolve the mask conflict before mask propagate...
[2024-01-17 17:07:12] dim0 sparsity: 0.500000
[2024-01-17 17:07:12] dim1 sparsity: 0.000000
0 Filter
[2024-01-17 17:07:12] dim0 sparsity: 0.500000
[2024-01-17 17:07:12] dim1 sparsity: 0.000000
[2024-01-17 17:07:12] Infer module masks...
[2024-01-17 17:07:12] Propagate original variables
[2024-01-17 17:07:12] Propagate variables for placeholder: x, output mask: 0.0000
[2024-01-17 17:07:12] Propagate variables for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 17:07:12] Propagate variables for call_module: bn1, , output mask: 0.0000
[2024-01-17 17:07:12] Propagate variables for call_module: relu1, , output mask: 0.0000
[2024-01-17 17:07:12] Propagate variables for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.0000
[2024-01-17 17:07:13] Propagate variables for call_module: bn2, , output mask: 0.0000
[2024-01-17 17:07:13] Propagate variables for call_module: relu2, , output mask: 0.0000
[2024-01-17 17:07:13] Propagate variables for output: output, output mask: 0.0000
[2024-01-17 17:07:13] Update direct sparsity...
[2024-01-17 17:07:13] Update direct mask for placeholder: x, output mask: 0.0000
[2024-01-17 17:07:13] Update direct mask for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.5000
[2024-01-17 17:07:13] Update direct mask for call_module: bn1, , output mask: 0.5000
[2024-01-17 17:07:13] Update direct mask for call_module: relu1, , output mask: 0.5000
[2024-01-17 17:07:14] Update direct mask for call_module: conv2, weight: 0.5000 bias: 0.5000 , output mask: 0.5000
[2024-01-17 17:07:14] Update direct mask for call_module: bn2, , output mask: 0.5000
[2024-01-17 17:07:14] Update direct mask for call_module: relu2, , output mask: 0.5000
[2024-01-17 17:07:14] Update direct mask for output: output, output mask: 0.5000
[2024-01-17 17:07:14] Update indirect sparsity...
[2024-01-17 17:07:14] Update indirect mask for output: output, output mask: 0.5000
[2024-01-17 17:07:14] Update indirect mask for call_module: relu2, , output mask: 0.5000
[2024-01-17 17:07:15] Update indirect mask for call_module: bn2, , output mask: 0.5000
[2024-01-17 17:07:15] Update indirect mask for call_module: conv2, weight: 0.7500 bias: 0.5000 , output mask: 0.5000
[2024-01-17 17:07:15] Update indirect mask for call_module: relu1, , output mask: 0.5000
[2024-01-17 17:07:16] Update indirect mask for call_module: bn1, , output mask: 0.5000
[2024-01-17 17:07:16] Update indirect mask for call_module: conv1, weight: 0.5000 bias: 0.5000 , output mask: 0.5000
[2024-01-17 17:07:16] Update indirect mask for placeholder: x, output mask: 0.0000
[2024-01-17 17:07:16] Resolve the mask conflict after mask propagate...
[2024-01-17 17:07:16] dim0 sparsity: 0.500000
[2024-01-17 17:07:16] dim1 sparsity: 0.465116
[2024-01-17 17:07:16] WARNING: both dim0 and dim1 masks found.
0 Filter
[2024-01-17 17:07:16] dim0 sparsity: 0.500000
[2024-01-17 17:07:16] dim1 sparsity: 0.465116
[2024-01-17 17:07:16] WARNING: both dim0 and dim1 masks found.
[2024-01-17 17:07:16] Replace compressed modules...
[2024-01-17 17:07:16] replace module (name: conv1, op_type: Conv2d)
[2024-01-17 17:07:16] replace conv2d with in_channels: 3, out_channels: 20
[2024-01-17 17:07:16] replace module (name: bn1, op_type: BatchNorm2d)
[2024-01-17 17:07:16] replace batchnorm2d with num_features: 20
[2024-01-17 17:07:16] replace module (name: relu1, op_type: ReLU)
[2024-01-17 17:07:16] replace module (name: conv2, op_type: Conv2d)
[2024-01-17 17:07:16] replace conv2d with in_channels: 20, out_channels: 40
[2024-01-17 17:07:16] replace module (name: bn2, op_type: BatchNorm2d)
[2024-01-17 17:07:16] replace batchnorm2d with num_features: 40
[2024-01-17 17:07:16] replace module (name: relu2, op_type: ReLU)
[2024-01-17 17:07:16] Speedup done.
```
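Note that the `batch_size` passed to `ModelSpeedup` only affects the mask-inference forward passes, not how the compressed model is used afterwards. So, as far as I can tell, a workaround is to run speedup with `batch_size` of at least 2 and still deploy with single-sample batches (a sketch based on the repro script above, not an official fix):

```python
# Run speedup with a batch size of at least 2 so the std over the batch
# dimension is well defined.
model = ModelSpeedup(model, [dummy_input], masks, batch_size=2,
                     garbage_collect_values=False).speedup_model()

# The compressed model can then still be used with single-sample batches.
single_sample = torch.randn(1, 3, 32, 32)
print(model(single_sample).shape)  # per the log above: torch.Size([1, 40, 32, 32])
```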