Open · LukasHedegaard opened this issue 1 year ago
Hi @LukasHedegaard, thank you for providing the information.

Indeed, we manually group the input channels and output channels to simplify the implementation. This is allowed in DepGraph, as long as the importance score is estimated over all of the grouped parameters.

So far, TP has limited support for vision transformers because there are so many customized / re-implemented MHA layers in the community, e.g., the `Attention` layers in SAM; most people use their own implementation instead of `torch.nn.MultiheadAttention`. I think we need a more fine-grained implementation of MHA pruning to avoid defining case-by-case pruners for different transformers.
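For readers skimming the thread, here is a minimal sketch of what "estimating importance over all grouped parameters" looks like in practice (layer sizes and the norm-based score are made up for illustration; this is not Torch-Pruning's actual API): a channel's score is accumulated across every parameter slice that shares it, and the whole group is then pruned with the same indices.

```python
import torch
import torch.nn as nn

# Hypothetical DepGraph-style grouped importance (sizes made up):
# fc1's output channels and fc2's input channels form one group, so the
# importance of channel i is summed over BOTH weight slices before pruning.
fc1 = nn.Linear(16, 32)
fc2 = nn.Linear(32, 8)

imp = fc1.weight.norm(dim=1) + fc2.weight.norm(dim=0)  # per-channel group score
keep = imp.topk(24).indices                            # keep the 24 strongest

fc1.weight = nn.Parameter(fc1.weight[keep])            # rows: fc1 output channels
fc1.bias   = nn.Parameter(fc1.bias[keep])
fc2.weight = nn.Parameter(fc2.weight[:, keep])         # cols: fc2 input channels
fc1.out_features = fc2.in_features = keep.numel()
```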
Hi, thanks for contributing a great library!

I've been doing a close-up study of the `MultiheadAttentionPruner` implementation, and I have some concerns.

The pruning of output channels in `out_proj` makes sense, but why should the same `idxs` be used for its input channels? They aren't "grouped" (you wouldn't do it for a stand-alone `nn.Linear`; see the sketch below). The same issue applies to `in_proj_weight`: while it would make sense to prune the input dimension, the pruning of the output dimension seems arbitrary.
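To make the `nn.Linear` point concrete, here is a toy example (sizes and index sets invented, not the pruner's code): the rows and columns of a linear layer's weight can be pruned with entirely unrelated index sets.

```python
import torch
import torch.nn as nn

# For a stand-alone nn.Linear, output channels (weight rows) and input
# channels (weight columns) are independent; nothing forces the same idxs.
fc = nn.Linear(8, 8)
out_keep = torch.tensor([0, 1, 2, 3, 4, 5])   # prune outputs 6, 7
in_keep  = torch.tensor([1, 2, 3, 4, 6, 7])   # prune inputs 0, 5 (different idxs!)

fc.weight = nn.Parameter(fc.weight[out_keep][:, in_keep])
fc.bias   = nn.Parameter(fc.bias[out_keep])
fc.out_features, fc.in_features = 6, 6

print(fc(torch.randn(2, 6)).shape)            # torch.Size([2, 6])
```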
At this point, the fix isn't obvious: the torch implementation of `nn.MultiheadAttention` doesn't allow differentiating an "inner" `embed_dim`, an "input" `embed_dim`, and an "output" `embed_dim` (see this issue). Any pruning of `out_proj.weight` and `in_proj_weight` must abide by the same `embed_dim` to satisfy the (too restrictive) assertions in `nn.MultiheadAttention`. This means that we essentially can't fix the issue for `nn.MultiheadAttention` itself. However, I think it is still worth noting that the current pruning implementation performs a "pseudo-random" pruning of `out_proj.in_features` and of the output features of `in_proj_weight`.
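For context, here is a sketch (indices and sizes invented for illustration) of the only pruning `nn.MultiheadAttention` tolerates: shrinking the single shared `embed_dim`, with the same indices applied to every q/k/v block of `in_proj_weight` and to both sides of `out_proj` at once.

```python
import torch
import torch.nn as nn

# Hypothetical example: shrink the shared embed_dim from 8 to 6, applying
# the SAME kept indices to the q, k, v blocks and to out_proj's rows AND cols.
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
keep = torch.tensor([0, 1, 2, 4, 5, 7])             # 6 surviving channels
rows = torch.cat([keep + i * 8 for i in range(3)])  # same idxs in q/k/v blocks

mha.in_proj_weight  = nn.Parameter(mha.in_proj_weight[rows][:, keep])
mha.in_proj_bias    = nn.Parameter(mha.in_proj_bias[rows])
mha.out_proj.weight = nn.Parameter(mha.out_proj.weight[keep][:, keep])
mha.out_proj.bias   = nn.Parameter(mha.out_proj.bias[keep])
mha.embed_dim = mha.kdim = mha.vdim = 6
mha.head_dim = 6 // mha.num_heads
mha.out_proj.in_features = mha.out_proj.out_features = 6

x = torch.randn(4, 1, 6)                            # (seq, batch, pruned dim)
out, _ = mha(x, x, x)                               # forward still works
print(out.shape)                                    # torch.Size([4, 1, 6])
```

Any attempt to use different indices for the "inner" and "outer" dimensions breaks the packed `in_proj_weight` chunking and the `embed_dim` assertions, which is exactly why the pruner ends up reusing the same `idxs` everywhere.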
Moreover, there is a typo in line 402, where `k_proj_weight` is assigned to `q_proj_weight`. Also, to keep the implementation consistent with the pruning of `in_proj_weight`, shouldn't both the input and output dimensions be pruned in `q_proj_weight`, `k_proj_weight`, and `v_proj_weight`? Concretely, something like the sketch below.
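A minimal sketch of the suggested symmetric treatment for the separate projection weights (the tensor and indices here are stand-ins, not the pruner's internals):

```python
import torch

keep = torch.tensor([0, 1, 2, 4, 5, 7])   # hypothetical surviving channels
w = torch.randn(8, 8)                     # stand-in for q_proj_weight (likewise k_/v_)
w = w[keep][:, keep]                      # prune output dim (rows) AND input dim (cols)
```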