calgaryml / condensed-sparsity

[ICLR 2024] Dynamic Sparse Training with Structured Sparsity
https://openreview.net/forum?id=kOBkxFRKTA
MIT License

Is there code for seeing the performance when sparsifying the multi-head attention layers too? #77

Closed: mayank64ce closed this issue 3 months ago

mklasby commented 3 months ago

Hi @mayank64ce, by default SRigL prunes the input and output projections of the MHA module. This is controlled by the config parameter rigl.ignore_mha_layers (https://github.com/calgaryml/condensed-sparsity/blob/main/configs/rigl/vit.yaml#L12).

The torch MHA module is parsed here: https://github.com/calgaryml/condensed-sparsity/blob/main/src/rigl_torch/utils/rigl_utils.py#L67

Since this module stores its weights as a parameter attribute, in_proj_weight, plus a submodule, out_proj (whose weight is out_proj.weight), the parsing is a bit convoluted.
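To make the layout concrete, here is a minimal sketch of walking a model and collecting the weights a sparse trainer could prune, including the MHA weights. The helper collect_prunable_weights and its ignore_mha_layers argument are hypothetical: the argument only mimics the idea of the rigl.ignore_mha_layers config flag, and this is not the parser in rigl_utils.py. Run on a torch nn.TransformerEncoderLayer, it shows where each matrix from the question lives: W_q, W_k, W_v in self_attn.in_proj_weight, W_o in self_attn.out_proj.weight, and the feed-forward layers in linear1.weight and linear2.weight.

```python
import torch.nn as nn


def collect_prunable_weights(model: nn.Module, ignore_mha_layers: bool = False) -> dict:
    """Hypothetical helper: map parameter names to weights a sparse trainer could prune.

    Only mimics the idea of the rigl.ignore_mha_layers config flag; this is not the
    parser in rigl_torch/utils/rigl_utils.py.
    """
    weights = {}
    mha_names = set()
    for name, module in model.named_modules():  # parents are visited before children
        if isinstance(module, nn.MultiheadAttention):
            mha_names.add(name)
            if ignore_mha_layers:
                continue
            if module.in_proj_weight is not None:
                # Q, K, V stacked in one Parameter attribute, shape (3*E, E)
                weights[f"{name}.in_proj_weight"] = module.in_proj_weight
            else:
                # Separate Q, K, V Parameters are used when kdim/vdim differ from embed_dim
                for attr in ("q_proj_weight", "k_proj_weight", "v_proj_weight"):
                    weights[f"{name}.{attr}"] = getattr(module, attr)
            # The output projection W_o is a Linear submodule named out_proj
            weights[f"{name}.out_proj.weight"] = module.out_proj.weight
        elif isinstance(module, nn.Linear):
            parent = name.rpartition(".")[0]
            if parent in mha_names:
                continue  # out_proj of an MHA module: already handled (or ignored) above
            weights[f"{name}.weight"] = module.weight
    return weights


layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32)
print(sorted(collect_prunable_weights(layer)))
# ['linear1.weight', 'linear2.weight', 'self_attn.in_proj_weight', 'self_attn.out_proj.weight']
print(sorted(collect_prunable_weights(layer, ignore_mha_layers=True)))
# ['linear1.weight', 'linear2.weight']
```

The out_proj Linear has to be skipped in the generic branch because it is also an nn.Linear; otherwise it would be collected even when the MHA layers are meant to be ignored.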

If you mean sparsifying the scaled dot-product attention itself, then no, this is not supported in this work.

mayank64ce commented 3 months ago

No, I don't mean the scaled dot product; I mean the query, key, and value weight matrices themselves. From my understanding, a transformer block has W_q, W_k, W_v (the query, key, and value weight matrices), W_o (the output projection), and W_f1 and W_f2 (the feed-forward layers).

So which ones do you sparsify? I want to know whether I can sparsify the W_q, W_k, and W_v matrices.

Please correct me if I am wrong; I am also new to the field.

mklasby commented 3 months ago

@mayank64ce, in_proj_weight in the torch MHA module holds the Q, K, V weights you're asking about. For the torch MHA module, in_proj_weight is 3x the model embedding dimension (its shape is (3 * embed_dim, embed_dim)). Other models, such as DeiT from Hugging Face, instead have 3 weight matrices, one each for Q, K, and V as you note above, each 1x the model embedding dimension. In practice, these weights are used in an identical fashion and are equivalent in operation: in_proj_weight is simply the concatenation of the Q, K, V weights.
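To check the concatenation claim, here is a small sketch (assuming a recent PyTorch with batch_first support; the sizes E, H, L are arbitrary choices) that slices in_proj_weight into three (E, E) blocks, recomputes multi-head self-attention by hand with them, and compares the result against nn.MultiheadAttention's own output.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
E, H, L = 8, 2, 5  # embedding dim, number of heads, sequence length
mha = nn.MultiheadAttention(E, H, batch_first=True).eval()
x = torch.randn(1, L, E)  # (batch, seq, embed)

with torch.no_grad():
    # in_proj_weight has shape (3*E, E): rows [0:E] are W_q, [E:2E] W_k, [2E:3E] W_v
    W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)

    def split_heads(t):
        # (B, L, E) -> (B, H, L, E // H)
        return t.view(1, L, H, E // H).transpose(1, 2)

    q = split_heads(F.linear(x, W_q, b_q))
    k = split_heads(F.linear(x, W_k, b_k))
    v = split_heads(F.linear(x, W_v, b_v))

    # Scaled dot-product attention per head, then concatenate the heads
    attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(E // H), dim=-1)
    ctx = (attn @ v).transpose(1, 2).reshape(1, L, E)
    manual = mha.out_proj(ctx)  # apply W_o

    reference, _ = mha(x, x, x, need_weights=False)

print(mha.in_proj_weight.shape)  # torch.Size([24, 8])
print(torch.allclose(manual, reference, atol=1e-5))
```

Since the three slices behave exactly like separate W_q, W_k, W_v matrices, a mask applied to in_proj_weight sparsifies Q, K, and V just as masks on three separate matrices would.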

So to make a long story short, yes, you can sparsify any weight parameter that is associated with the model, including the Q, K, V weights.
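As a minimal illustration that Q, K, V can be pruned like any other weight, the sketch below uses PyTorch's built-in torch.nn.utils.prune.l1_unstructured on the MHA weights. This is plain static magnitude pruning, swapped in for illustration; it is not SRigL's dynamic structured sparsity. Because the magnitude threshold is computed over the whole stacked in_proj_weight, W_q, W_k, and W_v need not end up equally sparse; enforcing a per-slice sparsity would need a custom mask (for example via prune.custom_from_mask).

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
E, H = 8, 2
mha = nn.MultiheadAttention(E, H)

# Prune 50% of the stacked Q, K, V weights by magnitude. This reparametrizes
# in_proj_weight as in_proj_weight_orig * in_proj_weight_mask.
prune.l1_unstructured(mha, name="in_proj_weight", amount=0.5)
# The output projection W_o is an ordinary Linear submodule and can be pruned the same way
prune.l1_unstructured(mha.out_proj, name="weight", amount=0.5)

# Report the resulting sparsity of W_q, W_k, W_v separately
for label, w in zip(("W_q", "W_k", "W_v"), mha.in_proj_weight.chunk(3, dim=0)):
    print(label, f"{(w == 0).float().mean().item():.2f}")

# The pruned module still runs as usual; inputs are (seq, batch, embed) without batch_first
x = torch.randn(5, 1, E)
out, _ = mha(x, x, x)
print(out.shape)  # torch.Size([5, 1, 8])
```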