VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

MultiheadAttentionPruner implementation does "random pruning" in inner dimensions of MHA #196

Open LukasHedegaard opened 1 year ago

LukasHedegaard commented 1 year ago

Hi, thanks for contributing a great library!

I've been doing a close-up study of the MultiheadAttentionPruner implementation, and I have some concerns.

The pruning of the output channels in out_proj makes sense, but why should the same idxs be used for its input channels? They aren't "grouped" (you wouldn't prune a stand-alone nn.Linear that way). The same issue applies to in_proj_weight: while it makes sense to prune its input dimension, the pruning of its output dimension seems random.

At this point, the fix isn't obvious, as the torch implementation of nn.MultiheadAttention doesn't distinguish between an "inner" embed_dim, an "input" embed_dim, and an "output" embed_dim (see this issue). Any pruning of out_proj.weight and in_proj_weight must abide by the same embed_dim to satisfy the (too restrictive) assertions in nn.MultiheadAttention. This means we essentially can't fix the issue for nn.MultiheadAttention. However, I still think it is worth noting that the current pruning implementation performs a "pseudo-random" pruning of out_proj's input features and in_proj_weight's output features.
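To make the coupling concrete, here is a minimal sketch (plain PyTorch, not the Torch-Pruning code) of the parameter shapes involved and of what pruning the same indices everywhere looks like; the kept indices are arbitrary placeholders:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)

# Packed q/k/v projection: (3 * embed_dim, embed_dim)
print(mha.in_proj_weight.shape)   # torch.Size([24, 8])
# Output projection: (embed_dim, embed_dim)
print(mha.out_proj.weight.shape)  # torch.Size([8, 8])

# The forward pass checks that the last dim of the query matches embed_dim and
# derives head_dim = embed_dim // num_heads, so any structural pruning has to
# remove the *same* channel indices from:
#   - out_proj.weight dim 0 (its output channels),
#   - out_proj.weight dim 1 (its input channels, i.e. the "inner" dim),
#   - each of the three (embed_dim, embed_dim) blocks of in_proj_weight,
# which is what makes the "inner" pruning look arbitrary w.r.t. importance.
keep = torch.tensor([0, 1, 2, 3, 4, 5])  # e.g. keep 6 of 8 channels
new_in_proj = torch.cat(
    [blk[keep][:, keep] for blk in mha.in_proj_weight.chunk(3, dim=0)], dim=0
)
new_out_proj = mha.out_proj.weight[keep][:, keep]
print(new_in_proj.shape, new_out_proj.shape)  # torch.Size([18, 6]) torch.Size([6, 6])
```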

Moreover, there is a typo in line 402 where k_proj_weight is assigned to q_proj_weight. Also, to keep the implementation consistent with the pruning of in_proj_weight, shouldn't both the input and output dimensions be pruned in q_proj_weight, k_proj_weight, and v_proj_weight?
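For illustration, here is a rough sketch of what I mean for the separate projection weights. This is not the actual MultiheadAttentionPruner code, and the embed_dim/head_dim bookkeeping is omitted:

```python
import torch
import torch.nn as nn

def prune_separate_qkv(mha: nn.MultiheadAttention, keep: torch.Tensor) -> None:
    # Prune the separate q/k/v projections consistently with in_proj_weight,
    # and make sure each assignment targets its own attribute
    # (avoiding the k_proj_weight / q_proj_weight mix-up).
    embed_dim = mha.embed_dim
    for name in ("q_proj_weight", "k_proj_weight", "v_proj_weight"):
        w = getattr(mha, name)
        if w is None:          # only populated when kdim/vdim differ from embed_dim
            continue
        w = w[keep]            # rows (output channels) live in embed_dim for all three
        if w.shape[1] == embed_dim:
            w = w[:, keep]     # columns too, when the input dim is also embed_dim
        setattr(mha, name, nn.Parameter(w))
    # NOTE: updating mha.embed_dim, mha.head_dim, biases, etc. is omitted here.
```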

VainF commented 1 year ago

Hi @LukasHedegaard, thank you for providing the information.

Indeed, we manually group the input channels and output channels to simplify the implementation. This is allowed in DepGraph, as long as the importance score is estimated on all the grouped parameters.
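As a rough illustration of what "grouped" means here (a simplified sketch, not the DepGraph code), the per-channel importance is accumulated over every coupled tensor in the group before any index is selected:

```python
import torch
import torch.nn as nn

def grouped_channel_importance(mha: nn.MultiheadAttention) -> torch.Tensor:
    # Accumulate a magnitude-based score per channel index over all grouped
    # parameters, so the "inner" dimensions are not pruned by an independent
    # (effectively arbitrary) criterion.
    embed_dim = mha.embed_dim
    score = torch.zeros(embed_dim)
    # out_proj: output channels (rows) and input channels (columns)
    score += mha.out_proj.weight.abs().sum(dim=1)
    score += mha.out_proj.weight.abs().sum(dim=0)
    # each (embed_dim, embed_dim) block of the packed q/k/v projection
    for blk in mha.in_proj_weight.chunk(3, dim=0):
        score += blk.abs().sum(dim=1)  # rows
        score += blk.abs().sum(dim=0)  # columns
    return score  # lower score -> better candidate for removal

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
imp = grouped_channel_importance(mha)
prune_idxs = imp.argsort()[:2].tolist()  # e.g. drop the 2 least important channels
```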

So far, TP has limited support for vision transformers because there are so many customized / re-implemented MHA layers in the community, e.g., the Attention layers in SAM. Most people use their own implementation instead of torch.nn.MultiheadAttention. I think we need a more fine-grained implementation for MHA pruning to avoid defining case-by-case pruners for different transformers.
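For example (a sketch only; the Attention module below is a hypothetical timm/SAM-style layer, not an existing TP pruner), a fine-grained MHA pruner has to slice each of the q/k/v sections of a fused projection separately, rather than treating it as a plain Linear:

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    # Stand-in for a community Attention block with a fused qkv Linear
    # and an output projection; the names/layout are assumptions.
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

def prune_attention_dim(attn: Attention, keep: torch.Tensor) -> None:
    dim = attn.proj.in_features
    # qkv output channels are three stacked dim-sized sections (q, k, v),
    # so the kept indices must be repeated per section with the right offsets.
    qkv_rows = torch.cat([keep + i * dim for i in range(3)])
    attn.qkv.weight = nn.Parameter(attn.qkv.weight[qkv_rows][:, keep])
    attn.qkv.bias = nn.Parameter(attn.qkv.bias[qkv_rows])
    attn.qkv.in_features, attn.qkv.out_features = len(keep), 3 * len(keep)
    # proj input channels are coupled to the q/k/v sections above; its output
    # channels belong to the residual stream and may sit in another group.
    attn.proj.weight = nn.Parameter(attn.proj.weight[:, keep])
    attn.proj.in_features = len(keep)

attn = Attention(dim=8, num_heads=2)
prune_attention_dim(attn, keep=torch.tensor([0, 1, 2, 3, 4, 5]))
```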