mayank64ce closed this issue 3 months ago
No, I don't mean the scaled dot product; I mean the Query, Key, Value weight matrices themselves. From my understanding, a transformer block has W_q, W_k, W_v (the query, key, value weight matrices), W_o (the output projection), and W_f1 and W_f2 (the feed-forward layers).
So which ones do you sparsify? I wanted to know if I can sparsify the W_q, W_k, and W_v matrices.
Please correct me if I am wrong, because I am also new to the field.
@mayank64ce The `in_proj_weight` of the torch MHA module holds the Q, K, V weights that you're asking about. For the torch MHA module, `in_proj_weight` is 3x the model embedding dimension. Other models, such as DeiT from Hugging Face, have three separate weight matrices, one each for Q, K, and V as you note above, each 1x the model embedding dimension. In practice these weights are used in an identical fashion and are equivalent in their operation; `in_proj_weight` is simply the concatenation of the Q, K, V weights.
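To illustrate that equivalence, here is a minimal numpy sketch (not code from this repo; the names `W_q`, `W_k`, `W_v` are just illustrative): projecting with one concatenated `in_proj_weight` gives the same Q, K, V as three separate matrices.

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim = 8
x = rng.standard_normal((4, embed_dim))  # 4 tokens, embed_dim features

# Three separate projection matrices, each (embed_dim, embed_dim),
# as in e.g. DeiT-style attention modules.
W_q = rng.standard_normal((embed_dim, embed_dim))
W_k = rng.standard_normal((embed_dim, embed_dim))
W_v = rng.standard_normal((embed_dim, embed_dim))

# torch-MHA-style fused weight: shape (3 * embed_dim, embed_dim)
in_proj_weight = np.concatenate([W_q, W_k, W_v], axis=0)

# Separate projections
q, k, v = x @ W_q.T, x @ W_k.T, x @ W_v.T

# Single fused projection, then split back into Q, K, V
qkv = x @ in_proj_weight.T
q2, k2, v2 = np.split(qkv, 3, axis=1)

assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)
```

So sparsifying `in_proj_weight` is, operationally, sparsifying the stacked Q, K, V matrices at once.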
So to make a long story short, yes, you can sparsify any weight parameter that is associated with the model, including the Q, K, V weights.
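As a toy illustration of what "sparsifying a weight parameter" means, here is a magnitude-pruning sketch applied to a Q weight matrix. This is for intuition only and is not SRigL's actual update rule (SRigL learns structured masks during training).

```python
import numpy as np

rng = np.random.default_rng(0)
W_q = rng.standard_normal((8, 8))  # a stand-in Q weight matrix

sparsity = 0.75                               # zero out 75% of weights
keep = int(W_q.size * (1 - sparsity))         # number of weights to keep
threshold = np.sort(np.abs(W_q).ravel())[-keep]  # keep-th largest magnitude
mask = np.abs(W_q) >= threshold               # binary mask of survivors

W_q_sparse = W_q * mask                       # pruned weight matrix
```

The same masking idea applies to any weight tensor in the model, including the fused `in_proj_weight`.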
Hi @mayank64ce, by default SRigL will prune the input and output projections of the MHA module. This is controlled by the config parameter `rigl.ignore_mha_layers` (https://github.com/calgaryml/condensed-sparsity/blob/main/configs/rigl/vit.yaml#L12). The torch MHA module is parsed here: https://github.com/calgaryml/condensed-sparsity/blob/main/src/rigl_torch/utils/rigl_utils.py#L67
Since the weights for this module are set as the attribute `in_proj_weight` and a submodule containing `out_proj_weight`, it's a bit convoluted. If you mean to sparsify the scaled dot product attention itself, then no, that is not supported in this work.