lix19937 / tensorrt-insight

Deep insight into TensorRT, including but not limited to QAT, PTQ, plugins, triton_inference, CUDA

ASP for mm-x (mmcv mmdet mmseg) KeyError: <class 'mmcv.cnn.bricks.wrappers.Linear'> #49

Open lix19937 opened 3 weeks ago

lix19937 commented 3 weeks ago

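The mmcv/mmdet/mmseg/detectron2 wrapper module types (e.g. mmcv.cnn.bricks.wrappers.Linear) are not keys in ASP's default sparse_parameter_list. These wrappers subclass torch.nn.Linear / torch.nn.Conv2d, so they pass the whitelist check, but the per-type lookup of which parameters to prune fails, which is the KeyError in the title. You can add the required module->parameter mapping with the custom_layer_dict argument to init_model_for_pruning().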
You can see how this argument is used here.
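The failure mode can be reproduced without apex. The sketch below is a simplified, hypothetical model of what init_model_for_pruning() does internally, assuming it filters modules with isinstance() against the whitelist and then indexes the parameter table by the exact type(module). `Linear` and `MMCVLinear` are stand-in classes, not the real torch/mmcv types, and `sparse_params_for` is an illustrative helper, not an apex function.

```python
# Hypothetical stand-ins: only the subclass relationship matters here.
class Linear:
    """Plays the role of torch.nn.Linear."""

class MMCVLinear(Linear):
    """Plays the role of mmcv.cnn.bricks.wrappers.Linear (a torch.nn.Linear subclass)."""

DEFAULT_SPARSE_PARAMS = {Linear: ["weight"]}


def sparse_params_for(modules, whitelist, custom_layer_dict=None):
    """Simplified model of ASP's eligibility check plus parameter lookup (an assumption)."""
    table = dict(DEFAULT_SPARSE_PARAMS)
    if custom_layer_dict:
        # custom entries extend both the parameter table and the whitelist
        table.update(custom_layer_dict)
        whitelist = list(whitelist) + list(custom_layer_dict)
    result = []
    for name, mod in modules:
        if isinstance(mod, tuple(whitelist)):  # a subclass passes this check...
            params = table[type(mod)]          # ...but this lookup uses the exact type
            result.append((name, params))
    return result


modules = [("fc", Linear()), ("mm_fc", MMCVLinear())]

try:
    sparse_params_for(modules, whitelist=[Linear])
except KeyError as err:
    print("KeyError:", err)  # the wrapper type is missing from the table

print(sparse_params_for(modules, [Linear], custom_layer_dict={MMCVLinear: ["weight"]}))
```

The first call raises a KeyError naming the wrapper class; the second succeeds because custom_layer_dict registers the exact wrapper type.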

To do this, call the three functions that prune_trained_model() wraps (init_model_for_pruning(), init_optimizer_for_pruning() and compute_sparse_masks()) yourself instead of using this convenience function, so that you can pass your custom_layer_dict to init_model_for_pruning() directly.
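Rather than listing every wrapper class by hand, the mapping could be derived from the model itself. The sketch below is a hypothetical helper (not part of apex) that maps any module type which subclasses a known base type to that base's parameter names. It only relies on iterating module objects, so it works with `model.modules()`; the demo uses stand-in classes instead of real torch/mmcv/detectron2 types.

```python
def build_custom_layer_dict(modules, base_table):
    """Map each subclass type found in `modules` to its base type's parameter names.

    modules:    any iterable of module objects (e.g. model.modules() in PyTorch)
    base_table: {base_type: [param_name, ...]}, e.g. {torch.nn.Linear: ['weight']}
    Hypothetical helper for illustration; not an apex API.
    """
    custom = {}
    for mod in modules:
        cls = type(mod)
        if cls in base_table or cls in custom:
            continue  # exact type already covered
        for base, params in base_table.items():
            if issubclass(cls, base):
                custom[cls] = list(params)  # first matching base wins
                break
    return custom


# Stand-in classes for the demo only.
class Linear: pass
class Conv2d: pass
class WrappedLinear(Linear): pass  # like mmcv.cnn.bricks.wrappers.Linear
class WrappedConv2d(Conv2d): pass  # like detectron2.layers.wrappers.Conv2d
class ReLU: pass

mods = [Linear(), WrappedLinear(), WrappedConv2d(), ReLU(), WrappedLinear()]
print(build_custom_layer_dict(mods, {Linear: ["weight"], Conv2d: ["weight"]}))
```

With PyTorch and apex installed, the result could be passed as `custom_layer_dict=build_custom_layer_dict(model.modules(), {torch.nn.Linear: ['weight'], torch.nn.Conv2d: ['weight']})` to `ASP.init_model_for_pruning()`, followed by `ASP.init_optimizer_for_pruning(optimizer)` and `ASP.compute_sparse_masks()`.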

Example: a prune_trained_model_v2() variant added to the ASP class in apex/contrib/sparsity/asp.py

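import torch
import mmcv.cnn.bricks.wrappers    # import the submodule explicitly; "import mmcv" alone may not load mmcv.cnn
import detectron2.layers.wrappers  # likewise, "import detectron2" alone may not load detectron2.layers

class ASP:
    # ... existing ASP members (init_model_for_pruning, prune_trained_model, etc.) ...

    @classmethod
    def prune_trained_model_v2(cls, model, optimizer):
        # Same steps as prune_trained_model(), but registers the wrapper types via
        # custom_layer_dict so their 'weight' tensors get masks instead of raising KeyError.
        # 1) add mask buffers to the model (init_model_for_pruning)
        cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2,
            whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention],
            allow_recompute_mask=False,
            custom_layer_dict={mmcv.cnn.bricks.wrappers.Linear: ['weight'],
                               detectron2.layers.wrappers.Conv2d: ['weight']})
        # 2) augment the optimizer so pruned weights stay masked after each step
        cls.init_optimizer_for_pruning(optimizer)
        # 3) compute the sparse masks from the trained weights
        cls.compute_sparse_masks()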