microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

How to prune the Attention matrix using `nni`. #5589

Closed: hobbitlzy closed this issue 1 year ago

hobbitlzy commented 1 year ago

Describe the issue: I am currently working on sparsity in attention, where the attention matrix is defined as $P=QK^T$. Many works have proposed sparse attention to reduce the quadratic memory and computation cost. In these methods, a binary mask matrix is generated that indicates which elements of $P$ need to be computed. This kind of sparsity differs from what the currently available nni API supports, which prunes the weights of a module. Therefore, I am wondering whether nni supports such sparse computation.
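The binary-mask idea can be illustrated with a dense PyTorch sketch (PyTorch is an assumption here, and `masked_attention` is a hypothetical helper). It computes the full $P$ and only then masks it, so it shows the semantics of sparse attention, not the memory or compute savings, which require kernels that skip masked entries. The $\sqrt{d}$ scaling follows standard scaled dot-product attention rather than the bare $P=QK^T$ above.

```python
import math

import torch


def masked_attention(q, k, v, mask):
    """Dense sketch of attention with a binary mask over P = Q K^T.

    mask has shape (..., L_q, L_k): 1 = compute/keep the entry, 0 = drop it.
    Every row needs at least one kept entry, otherwise softmax yields NaN.
    """
    # Attention scores P = Q K^T, scaled by sqrt(d) as in standard attention.
    p = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # Dropped entries become -inf, so softmax gives them exactly zero weight.
    p = p.masked_fill(mask == 0, float("-inf"))
    attn = torch.softmax(p, dim=-1)
    return attn @ v, attn


# Example: a causal (lower-triangular) pattern on a length-4 sequence.
torch.manual_seed(0)
x = torch.randn(1, 4, 8)
mask = torch.tril(torch.ones(4, 4))
out, attn = masked_attention(x, x, x, mask)
```

Any binary pattern (local windows, strided blocks, a learned mask) can be passed as `mask`; the kept entries of each row are renormalized by the softmax.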

A possible way I am considering is to define a new nn.Module that holds an all-one matrix as its parameter and whose forward is the Hadamard product of this all-one matrix and the post-softmax attention. The pruner could then be applied to the all-one matrix (the metric used to generate the mask would also need to be modified). I think the problem may arise because the pruner's mask is generated before the module is forwarded, whereas generating the mask for the Softmax requires the post-softmax attention. What do you think of the feasibility of this implementation, and is there a more elegant way?
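A minimal sketch of the proposed module, assuming PyTorch and a fixed sequence length (`AttentionMaskLayer` is a hypothetical name, and freezing the all-ones parameter is an assumption so that fine-tuning does not alter it). A pruner would zero entries of `weight`; here that is emulated by hand.

```python
import torch
from torch import nn


class AttentionMaskLayer(nn.Module):
    """Hypothetical module: an all-ones matrix multiplied element-wise
    (Hadamard product) with the post-softmax attention."""

    def __init__(self, seq_len):
        super().__init__()
        # All-ones "weight" that a pruner would zero out entry by entry.
        # Frozen here (assumption) so fine-tuning does not change its values.
        self.weight = nn.Parameter(torch.ones(seq_len, seq_len), requires_grad=False)

    def forward(self, attn):
        # attn: (batch, heads, seq_len, seq_len); weight broadcasts over batch/heads.
        return attn * self.weight


layer = AttentionMaskLayer(seq_len=4)
attn = torch.softmax(torch.randn(2, 3, 4, 4), dim=-1)
unchanged = layer(attn)  # all-ones mask: identical to attn
with torch.no_grad():
    layer.weight[0, 1] = 0.0  # emulate a pruner zeroing one query-key pair
pruned = layer(attn)
```

Two properties of this design are worth noting: the mask is the same for every input and is tied to one `seq_len`, and after the multiplication the rows of the attention no longer sum to 1.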

Environment: Not applicable.

Configuration: Not applicable.

Log message: Not applicable.

How to reproduce it?: Not applicable.

J-shang commented 1 year ago

Hello @hobbitlzy, this is a good question. Your concern that the pruner's mask is generated before the module is forwarded applies only to post-training pruners. You could try a training-aware pruner such as MovementPruner; we have tested attention pruning with this pruner in an internal nni version. Alternatively, if you have your own pruning algorithm, you could customize a pruner.

I think your solution for pruning Softmax is the fastest practical approach, so you could try it first.

If you want to use nni's native functionality to achieve attention pruning, the following docs may help you sparsify the Softmax. Note, however, that the pruners in nni 3.0rc1 have not yet been tested on attention pruning, so you may need to modify part of the pruners' implementation logic.

If you want to mask the input and output of torch.nn.Softmax, you need to add a setting for torch.nn.Softmax (see https://nni.readthedocs.io/en/v3.0rc1/compression/setting.html):

```python
customized_setting = {
    # Mask the Softmax input (the attention scores); 'add' applies the mask additively.
    '_input_': {
        'sparse_ratio': None,
        'max_sparse_ratio': None,
        'min_sparse_ratio': None,
        'sparse_threshold': None,
        'global_group_id': None,
        'dependency_group_id': None,
        'granularity': [-1, 1, 1],
        'internal_metric_block': None,
        'apply_method': 'add',
    },
    # Mask the Softmax output, aligned with the mask generated for input 0;
    # 'mul' applies it multiplicatively.
    '_output_': {
        'align': {
            'module_name': None,
            'target_name': '_input_0',
            'dims': [0, 1, 2],
        },
        'apply_method': 'mul',
        'granularity': [-1, 1, 1]
    }
}
```

You could try setting the compression target in the config list: https://nni.readthedocs.io/en/v3.0rc1/compression/config_list.html#target-names
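A config list along these lines might target the Softmax input. This is a sketch under assumptions: the keys `op_types`, `target_names` and `sparse_ratio`, and the target name `'_input_'` (taken from the setting keys above), should be checked against the linked nni 3.0rc1 page before use.

```python
# Hypothetical nni 3.0rc1 config list: prune the input of every torch.nn.Softmax.
config_list = [{
    'op_types': ['Softmax'],       # match modules by type (assumed key)
    'target_names': ['_input_'],   # compress the module input, not a weight (assumed name)
    'sparse_ratio': 0.5,           # fraction of input entries to mask (assumed key)
}]
```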

For Softmax input pruning, the mask's apply_method should be changed from mul to add: https://nni.readthedocs.io/en/v3.0rc1/compression/config_list.html#apply-method
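Why the input mask should be additive can be seen in plain PyTorch (not nni). The constant `-1e9` is an illustrative assumption, not necessarily the value nni uses for additive masks.

```python
import torch

scores = torch.tensor([2.0, 1.0, 0.5])  # pre-softmax attention scores of one row
keep = torch.tensor([1.0, 0.0, 1.0])    # 0 = prune this position

# 'mul' on the Softmax input: the pruned score becomes 0, but exp(0) = 1,
# so the position still receives probability mass after normalization.
mul_probs = torch.softmax(scores * keep, dim=-1)

# 'add' on the Softmax input: add a large negative value where keep == 0,
# so the position ends up with (numerically) zero probability.
add_probs = torch.softmax(scores + (1.0 - keep) * -1e9, dim=-1)

# 'mul' on the Softmax output zeroes the position directly,
# but the row then no longer sums to 1.
out_probs = torch.softmax(scores, dim=-1) * keep
```

The additive input mask keeps each row a proper distribution over the unpruned positions, which is why it pairs naturally with a multiplicative mask on the output.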

hobbitlzy commented 1 year ago

Thank you for your response. It helped me discover many nni functionalities that I had not been aware of. It seems feasible to apply input or output pruning to the SoftMax module directly, which should allow me to achieve exactly the attention pruning I am aiming for. I will test this approach and see how it works.

hobbitlzy commented 1 year ago

Hi, @J-shang. I read the MovementPruner paper and found that it is designed for model weights. I am not sure about its performance on attention pruning (did you get satisfying results?). Therefore, in my early attempt, I want to use a simple magnitude metric to prune unnecessary attention values, as the LevelPruner or the L1NormPruner do. Thanks to your setting snippet, I successfully wrapped the SoftMax module like this:


```
(self): BertSelfAttention(
    (query): Linear(in_features=768, out_features=768, bias=True)
    (key): Linear(in_features=768, out_features=768, bias=True)
    (value): Linear(in_features=768, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (softmax): MySoftMax(
        dim=-1
        (_nni_wrapper): ModuleWrapper(module=MySoftMax(dim=-1), module_name=bert.encoder.layer.11.attention.self.softmax)
    )
)
```
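The thread does not show `MySoftMax` itself. A module along these lines would reproduce the printed `MySoftMax(dim=-1)`; this is a hypothetical reconstruction. A module form is presumably needed because the stock BERT self-attention calls the functional softmax, which a module-based pruner cannot wrap; the built-in `torch.nn.Softmax(dim=-1)` would be an alternative.

```python
import torch
from torch import nn


class MySoftMax(nn.Module):
    """Hypothetical module form of softmax, so a pruner can wrap it."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.softmax(x, dim=self.dim)

    def extra_repr(self):
        # Makes print(module) show "MySoftMax(dim=-1)" as in the printout above.
        return f"dim={self.dim}"


softmax = MySoftMax(dim=-1)
print(softmax)  # MySoftMax(dim=-1)
scores = torch.randn(2, 12, 8, 8)  # (batch, heads, seq, seq)
probs = softmax(scores)
```

Inside the attention layer, the functional call would then be replaced by `self.softmax(scores)` so that the wrapped module sees the scores.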

However, when I dug into the implementation of input/output pruning, I found that the pruner can hardly generate masks according to the attention values, since it does not receive any data when generating masks. Do you have any suggestions about this?

I guess the MovementPruner can do this since it intervenes in each step of the training loop and can accept the module input for generating masks. Maybe I need to define the pruning logic based on the MovementPruner?
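Outside nni, a data-dependent mask can be recomputed on every forward call with a plain PyTorch forward hook on the softmax module. This sketch assumes a per-row top-k magnitude criterion (a swapped-in choice, not nni's LevelPruner); the helper names are hypothetical, and how such a hook interacts with nni's `ModuleWrapper` is untested.

```python
import torch
from torch import nn


def topk_keep_mask(attn, keep_ratio):
    """Binary mask keeping the largest `keep_ratio` fraction of each row."""
    k = max(1, int(attn.size(-1) * keep_ratio))
    idx = attn.topk(k, dim=-1).indices
    return torch.zeros_like(attn).scatter_(-1, idx, 1.0)


def attach_attention_pruning(softmax_module, keep_ratio=0.5):
    """Register a hook that recomputes the mask from the attention values
    on every forward call and applies it to the module output."""

    def hook(module, inputs, output):
        # A tensor returned from a forward hook replaces the module's output.
        return output * topk_keep_mask(output, keep_ratio)

    return softmax_module.register_forward_hook(hook)


softmax = nn.Softmax(dim=-1)  # stands in for a wrapped module such as MySoftMax
handle = attach_attention_pruning(softmax, keep_ratio=0.5)
probs = softmax(torch.randn(2, 12, 8, 8))  # 4 of 8 keys kept per query row
handle.remove()  # detach the hook to restore plain softmax
```

Because the mask is applied multiplicatively to the output, the kept probabilities are not renormalized.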

J-shang commented 1 year ago

Hello @hobbitlzy, yes, MovementPruner is implemented for pruning weights, so you need to define the pruning logic for inputs yourself. You could call `pruner.track_forward` before `pruner.compress` to record the input/output shapes; they will be stored in `TargetSpace.shape`. Please pay attention to variable sequence lengths.
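Why variable sequence length matters can be shown with a plain PyTorch forward pre-hook that records input shapes. This only illustrates the idea of tracking shapes; it is not how `pruner.track_forward` is implemented, and `record_input_shapes` is a hypothetical helper.

```python
import torch
from torch import nn


def record_input_shapes(module):
    """Record the shape of the first input on every forward call."""
    shapes = []

    def pre_hook(mod, inputs):
        shapes.append(tuple(inputs[0].shape))

    handle = module.register_forward_pre_hook(pre_hook)
    return shapes, handle


softmax = nn.Softmax(dim=-1)
shapes, handle = record_input_shapes(softmax)
for seq_len in (8, 12):  # two batches with different sequence lengths
    softmax(torch.randn(2, 12, seq_len, seq_len))
handle.remove()
print(shapes)  # [(2, 12, 8, 8), (2, 12, 12, 12)]
```

A mask sized from a single traced batch (8 x 8 here) would not fit the 12 x 12 attention map of the next batch, so the mask must either be padded to a maximum length or rebuilt per batch.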

As for MovementPruner's performance on attention, we have tried it on BERT before, and the results were not bad.

hobbitlzy commented 1 year ago

Thanks for the reply. Sorry, but I do not quite understand why `pruner.track_forward` is needed here. As I understand it, it records the shapes (or similar information) once before compression, but I want to mask the attention based on its values at every fine-tuning step.

I read the MovementPruner logic and found that what it registers cannot intervene inside the forward pass. Could you explain how it is applied to attention pruning?

hobbitlzy commented 1 year ago

Update: I wrote the attention pruning logic inside the forward function of the model, and it works well for me.
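The author's forward-function code is not shown in the thread. A minimal sketch of what such logic could look like, assuming a single attention head and per-row top-k selection as the magnitude criterion (`PrunedSelfAttention` and `keep_ratio` are hypothetical names):

```python
import math

import torch
from torch import nn


class PrunedSelfAttention(nn.Module):
    """Hypothetical single-head self-attention that prunes its attention
    map inside forward(), based on the current attention values."""

    def __init__(self, hidden, keep_ratio=0.5):
        super().__init__()
        self.query = nn.Linear(hidden, hidden)
        self.key = nn.Linear(hidden, hidden)
        self.value = nn.Linear(hidden, hidden)
        self.keep_ratio = keep_ratio

    def forward(self, x):
        q, k, v = self.query(x), self.key(x), self.value(x)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        # Recomputed every call, so the sequence length may vary per batch.
        n_keep = max(1, int(scores.size(-1) * self.keep_ratio))
        kept = scores.topk(n_keep, dim=-1).indices
        # Additive mask: 0 for kept positions, -inf for pruned ones.
        mask = torch.full_like(scores, float("-inf")).scatter_(-1, kept, 0.0)
        attn = torch.softmax(scores + mask, dim=-1)
        return attn @ v, attn


torch.manual_seed(0)
layer = PrunedSelfAttention(hidden=16, keep_ratio=0.25)
out, attn = layer(torch.randn(2, 8, 16))  # each query keeps 2 of 8 keys
```

Since softmax is monotonic within a row, choosing the top-k pre-softmax scores selects the same entries as choosing the top-k post-softmax values, and the additive mask renormalizes the kept entries.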