microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

AssertionError, while speeding up Convnext model. #5417

Open ankitknitj opened 1 year ago

ankitknitj commented 1 year ago

Describe the issue: AssertionError, while speeding up Convnext model.

Environment:

Log message:

[2023-03-01 00:14:15] start to speedup the model
[2023-03-01 00:14:16] infer module masks...
[2023-03-01 00:14:16] Update mask for features.0.0
[2023-03-01 00:14:16] Update mask for features.0.1
[2023-03-01 00:14:16] Update mask for features.1.0.block.0
[2023-03-01 00:14:16] Update mask for features.1.0.block.1
[2023-03-01 00:14:16] Update mask for features.1.0.block.2
[2023-03-01 00:14:16] Update mask for features.1.0.block.3
[2023-03-01 00:14:16] Update mask for features.1.0.block.4
[2023-03-01 00:14:16] Update mask for features.1.0.block.5
[2023-03-01 00:14:16] Update mask for features.1.0.block.6
[2023-03-01 00:14:16] Update mask for features.1.0.aten::mul.138
Traceback (most recent call last):
  File "/scratch/project_2005599/Ankit/convnext_try/DeeperCompression/extract_metadata.py", line 559, in <module>
    main_prepare_metadata(dataset_list, model_dict, combo,
  File "/scratch/project_2005599/Ankit/convnext_try/DeeperCompression/extract_metadata.py", line 331, in main_prepare_metadata
    model_copy, complexity_dict = prune_tune_quant_tune('dorefa', 'taylorfo', model_name, model_copy, sparsity, quant_prec, tune_set, tune_epochs_taylor,
  File "/scratch/project_2005599/Ankit/convnext_try/DeeperCompression/compression_methods/compress.py", line 163, in prune_tune_quant_tune
    pruned_model, pruned_inf_time, pruner_obj = gradual_prune_model(pruning_method, model, model_name, target_sparsity, dataset,
  File "/scratch/project_2005599/Ankit/convnext_try/DeeperCompression/compression_methods/compress.py", line 61, in gradual_prune_model
    pruned_model, pruner_obj = pruner.process()
  File "/scratch/project_2005599/Ankit/convnext_try/DeeperCompression/compression_methods/nni_compression.py", line 176, in process
    ModelSpeedup(self.model, dummy_input=(self.dummy_input), masks_file=masks).speedup_model()
  File "/users/khatrian/.local/lib/python3.9/site-packages/nni/compression/pytorch/speedup/compressor.py", line 546, in speedup_model
    self.infer_modules_masks()
  File "/users/khatrian/.local/lib/python3.9/site-packages/nni/compression/pytorch/speedup/compressor.py", line 383, in infer_modules_masks
    self.update_direct_sparsity(curnode)
  File "/users/khatrian/.local/lib/python3.9/site-packages/nni/compression/pytorch/speedup/compressor.py", line 237, in update_direct_sparsity
    _auto_infer = AutoMaskInference(
  File "/users/khatrian/.local/lib/python3.9/site-packages/nni/compression/pytorch/speedup/infer_mask.py", line 80, in __init__
    self.output = self.module(*dummy_input)
  File "/users/khatrian/.local/lib/python3.9/site-packages/nni/compression/pytorch/speedup/jit_translate.py", line 227, in __call__
    assert len(args) >= len(self.undetermined)
AssertionError

How to reproduce it?: This is how we load the ConvNeXt model:

from torchvision import models as torch_models
model = torch_models.convnext_tiny(weights='IMAGENET1K_V1', progress=True).to(device)

Pruning uses the Level pruner. We have also tried the TaylorFO, Slim, and L1 pruners; nothing works.

if self.pruning_method == 'level':
    config_list = [{'op_types': ['default'], 'sparsity': self.sparsity}]
    pruner = LevelPruner(self.model, config_list)
_, masks = pruner.compress()
print("Completed pruner.compress()")
criterion = torch.nn.CrossEntropyLoss()
print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
optimizer, scheduler = optimizer_scheduler_generator(self.model, _lr=1e-3, total_epoch=20)
self.train_func(self.model, optimizer, criterion, scheduler, max_epochs=20)
print("Pruned model accuracy on tuning dataset: ", self.eval_func(self.model))
pruner._unwrap_model()
ModelSpeedup(self.model, dummy_input=(self.dummy_input), masks_file=masks).speedup_model()
return self.model, pruner

ankitknitj commented 1 year ago

@J-shang @Louis-J

J-shang commented 1 year ago

Hello @ankitknitj, we are working on a new speedup version; we will test it on your model and let you know whether the new version works.

J-shang commented 1 year ago

Hello @ankitknitj, please try the branch in https://github.com/microsoft/nni/pull/5403. Because ConvNeXt contains customized modules, some additional customized logic is needed:

import torch
from torchvision.models import convnext_tiny
from torchvision.models.convnext import LayerNorm2d

from nni.contrib.compression.pruning import L1NormPruner
from nni.contrib.compression.utils import auto_set_denpendency_group_ids
from nni.common.concrete_trace_utils import concrete_trace
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
from nni.compression.pytorch.speedup.v2.mask_updater import NoChangeMaskUpdater
from nni.compression.pytorch.speedup.v2.replacer import Replacer, DefaultReplacer

class CNBlockReplacer(Replacer):
    def replace_modules(self, speedup: 'ModelSpeedup'):
        # replace layer_scale in CNBlock
        from torchvision.models.convnext import CNBlock

        name2node_info = {node_info.node.target: node_info for node_info in speedup.node_infos.values() if node_info.node.op == 'call_module'}
        for module_name, module in speedup.bound_model.named_modules():
            if isinstance(module, CNBlock):
                block_5_name = module_name + '.block.5'
                output_mask = name2node_info[block_5_name].output_masks
                reduce_dims = list(range(len(output_mask.shape)))
                reduce_dims.pop(-1)
                module.layer_scale.data = module.layer_scale.data[output_mask.sum(reduce_dims) != 0].clone().contiguous()

class LayerNorm2dReplacer(DefaultReplacer):
    def __init__(self):
        replace_module_func_dict = {'LayerNorm2d': self.replace_layernorm2d}
        super().__init__(replace_module_func_dict)

    @classmethod
    def replace_layernorm2d(cls, layernorm, masks):
        in_masks, out_masks, _ = masks
        assert isinstance(layernorm, LayerNorm2d)
        in_mask = in_masks[0]
        out_mask = out_masks

        reduce_dims = list(range(len(in_mask.shape)))
        reduce_dims.pop(1)
        in_idxes = in_mask.sum(reduce_dims) != 0

        reduce_dims = list(range(len(out_mask.shape)))
        reduce_dims.pop(1)
        out_idxes = out_mask.sum(reduce_dims) != 0

        assert torch.equal(in_idxes, out_idxes)

        new_normalized_shape = [in_idxes.sum().item()]
        new_layernorm = LayerNorm2d(tuple(new_normalized_shape), layernorm.eps, layernorm.elementwise_affine)

        if new_layernorm.elementwise_affine:
            new_layernorm.to(layernorm.weight.device)
            with torch.no_grad():
                new_layernorm.weight.data = layernorm.weight.data[in_idxes].clone().contiguous()
                new_layernorm.bias.data = layernorm.bias.data[in_idxes].clone().contiguous()
        return new_layernorm

if __name__ == '__main__':
    model = convnext_tiny(weights='IMAGENET1K_V1', progress=True)
    config_list = [{
        'op_types': ['Conv2d', 'Linear'],
        'op_names_re': ['features.*'],
        'sparse_ratio': 0.5
    }]
    dummy_input = torch.rand(8, 3, 224, 224)
    config_list = auto_set_denpendency_group_ids(model, config_list, dummy_input)
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()
    pruner.unwrap_model()

    cnb_replacer = CNBlockReplacer()
    layernorm2d_replacer = LayerNorm2dReplacer()
    mask_updater = NoChangeMaskUpdater(customized_no_change_act_module=(LayerNorm2d,))

    graph_module = concrete_trace(model, (dummy_input,), leaf_module=(LayerNorm2d,))

    model = ModelSpeedup(model, dummy_input, masks,
                         customized_mask_updaters=[mask_updater], customized_replacers=[cnb_replacer, layernorm2d_replacer],
                         graph_module=graph_module).speedup_model()

    model(dummy_input)
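Both replacers in the script above rely on the same index-selection trick: sum the mask over every dimension except the channel dimension, then keep the channels whose sum is nonzero. A minimal pure-Python sketch of that logic (plain nested lists stand in for tensors, so it runs without torch; the shapes are hypothetical):

```python
# Toy output mask of shape (batch=1, channels=3, h=2, w=2).
# Channel 1 is fully masked out; channels 0 and 2 survive.
mask = [
    [  # batch item 0
        [[1, 1], [1, 0]],  # channel 0: partially kept
        [[0, 0], [0, 0]],  # channel 1: fully pruned
        [[1, 0], [0, 1]],  # channel 2: partially kept
    ]
]

def surviving_channels(mask):
    """Indices of channels whose mask is not all-zero, mirroring
    `mask.sum(reduce_dims) != 0` with reduce_dims = all dims except
    the channel dim (dim 1)."""
    num_channels = len(mask[0])
    kept = []
    for c in range(num_channels):
        total = sum(
            v
            for item in mask      # reduce over batch dim
            for row in item[c]    # reduce over height
            for v in row          # reduce over width
        )
        if total != 0:
            kept.append(c)
    return kept

print(surviving_channels(mask))  # -> [0, 2]
```

The replacers then index the surviving positions into the weight (or layer_scale) tensor and rebuild the module with the smaller shape.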
ankitknitj commented 1 year ago

I am getting this error, although nni 2.10 is already installed: ModuleNotFoundError: No module named 'nni.contrib'
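For reference, a quick way to check whether a package such as nni.contrib is actually importable in the current environment is the standard-library find_spec helper (this is a generic sketch, not part of NNI):

```python
import importlib.util

def has_module(name: str) -> bool:
    """Return True if `name` resolves to an importable module."""
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # The parent package itself is not installed.
        return False

# On a release install of nni 2.10 this returns False; the contrib
# package only exists on the master branch at this point in the thread.
print(has_module("nni.contrib"))
```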

J-shang commented 1 year ago

Hello @ankitknitj, please install from the master branch; these features have not been released yet: https://nni.readthedocs.io/en/stable/notes/build_from_source.html

funnym0nk3y commented 1 year ago

I'm having the same issue, also with version 2.10. Unfortunately I couldn't get the dev version from the master branch to work in my Python environment.

J-shang commented 1 year ago

Hello @funnym0nk3y, what issue did you run into while building from source?

funnym0nk3y commented 1 year ago

I got the dev version to build, but it did not change the original issue.

J-shang commented 1 year ago

Could you show the error message? Is your original issue ModuleNotFoundError: No module named 'nni.contrib'? If so, I think you did not install from source successfully.

funnym0nk3y commented 1 year ago

This is the trace:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[7], line 44
     41 _, masks = pruner.compress()
     42 print('Done creating masks')
---> 44 ModelSpeedup(best_person_only_model, dummy_input=best_person_only_model.example_input_array, masks_file=masks).speedup_model()
     46 # show the masks sparsity
     47 for name, mask in masks.items():

File ~/nni/nni/compression/pytorch/speedup/compressor.py:488, in ModelSpeedup.speedup_model(self)
    485 fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
    487 _logger.info("infer module masks...")
--> 488 self.infer_modules_masks()
    489 _logger.info('resolve the mask conflict')
    490 # sometimes, mask conflict will happen during infer masks
    491 # fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
    492 
    493 # load the original stat dict before replace the model

File ~/nni/nni/compression/pytorch/speedup/compressor.py:398, in ModelSpeedup.infer_modules_masks(self)
    396 curnode = visit_queue.get()
    397 # forward mask inference for curnode
--> 398 self.update_direct_sparsity(curnode)
    399 successors = self.torch_graph.find_successors(curnode.unique_name)
    400 for successor in successors:

File ~/nni/nni/compression/pytorch/speedup/compressor.py:252, in ModelSpeedup.update_direct_sparsity(self, node)
    250         return
    251     # function doesn't have weights
--> 252     _auto_infer = AutoMaskInference(
    253         func, dummy_input, self, in_masks, in_constants=in_constants)
    254 else:
    255     weight_mask = None

File ~/nni/nni/compression/pytorch/speedup/infer_mask.py:80, in AutoMaskInference.__init__(self, module, dummy_input, speedup, in_masks, weight_mask, output_mask, name, in_constants, state_dict)
     76         self.in_masks[in_id] = torch.ones_like(self.dummy_input[in_id])
     77         # ones_like will put the created mask on the same device with the dummy_input
     78 
     79 # Initialize the mask for output tensors
---> 80 self.output = self.module(*dummy_input)
     81 # self.output.requires_grad_()
     82 if output_mask is not None:
     83     # assume the given output mask is right

File ~/nni/nni/compression/pytorch/speedup/jit_translate.py:227, in FuncAdapter.__call__(self, *args)
    226 def __call__(self, *args):
--> 227     assert len(args) >= len(self.undetermined)
    228     if len(args) > len(self.undetermined):
    229         logger.warning('throw some args away when calling the function "%s"', self.func.__name__)

AssertionError:

And this is the output of nni up to that point:

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Done creating masks
[2023-03-10 10:13:11] start to speedup the model
no multi-dimension masks found.
[2023-03-10 10:13:13] infer module masks...
[2023-03-10 10:13:13] Update mask for conv1.aten::mul.32

This seems to be the same error as in the initial issue. The dev version is working, since nni.__version__ shows 999.dev0.

J-shang commented 1 year ago

Hello @funnym0nk3y, please import ModelSpeedup from v2, i.e. from nni.compression.pytorch.speedup.v2 import ModelSpeedup; it seems that you are using the old speedup version.

ankitknitj commented 1 year ago

Hi @funnym0nk3y, @J-shang, I am getting this error while building from source, when running python setup.py develop:

File "setup.py", line 88
    sys.exit(f'ERROR: To build a jupyter lab extension, run "JUPYTER_LAB_VERSION={jupyter_lab_version}", current: {environ_version} ')
                                                                                                                                    ^
SyntaxError: invalid syntax

J-shang commented 1 year ago

Hello @ankitknitj, please try pip install jupyterlab==3.0.9

ankitknitj commented 1 year ago

Hi @J-shang, which version should I mention here, in bash:

export NNI_RELEASE=2.0
python setup.py build_ts
python setup.py bdist_wheel

J-shang commented 1 year ago

In fact, I think python setup.py develop is enough. If you want to build a wheel, you can set any version number you like, e.g. 3.0a.

@liuzhe-lz could you please help how to fix this error?

File "setup.py", line 88
sys.exit(f'ERROR: To build a jupyter lab extension, run "JUPYTER_LAB_VERSION={jupyter_lab_version}", current: {environ_version} ')
^
SyntaxError: invalid syntax
ankitknitj commented 1 year ago

@J-shang I am able to run model speedup for ConvNeXt, but somehow the number of parameters and FLOPs are not reduced after the speedup compared to the original model. What might be happening there?
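One way to sanity-check this is to compare raw parameter counts before and after speedup. A minimal duck-typed helper (this assumes the model exposes parameters() yielding tensors with numel(), as torch.nn.Module does; it is a generic sketch, not an NNI API):

```python
def count_parameters(model) -> int:
    """Sum of element counts over all parameters.

    Works for any object whose parameters() yields items with a
    numel() method, e.g. a torch.nn.Module.
    """
    return sum(p.numel() for p in model.parameters())
```

If this reports the same number before and after speedup_model(), the masked channels were likely never physically removed, e.g. because the speedup silently fell back or the masks were all-ones.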

ankitknitj commented 1 year ago

@J-shang I am getting this error in pruner.compress: RuntimeError: output with shape [768, 1, 1, 1] doesn't match the broadcast shape [768, 1, 768, 1]