CaffreyR closed this issue 1 year ago
Hi, thanks for checking out our repo! I think T5 would be a bit tricky as it's a sequence-to-sequence model. So you would need to add masks to components on both the source and target side of the model, but the logic to add the masks to both sides is the same.
I suggest you check out `models/modeling_bert.py` and `models/l0_module.py`; the first file defines the model structure and the masking strategy, and the second file defines the masks and the underlying distribution. I am happy to help answer questions if you encounter any issues!
Hi @xiamengzhou, thanks for your kind reply. In your code, do you use `prune_head` to cut the whole Query directly, and `prune_layer` to cut elements in the Query matrix? BTW, are all the pruning rules defined before running the code? Many thanks again!
Sorry to bother you again. May I ask how you prune the model during training? I see that in this line you obtain the `zs`:
https://github.com/princeton-nlp/CoFiPruning/blob/main/trainer/trainer.py#L282
But then how do you use the `zs` to prune the model, as you did in `cofi_utils.py`?
https://github.com/princeton-nlp/CoFiPruning/blob/main/utils/cofi_utils.py#L107
Many thanks! @xiamengzhou
Hi,
Sorry for the late reply!
For the first question: I pruned the heads and layers at the same time, with a set of masks `head_loga` to control the pruning decision for individual heads and another set of masks `headlayer_loga` to control the pruning decision for the whole multi-head attention.
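To make the two mask sets concrete, here is a minimal sketch and not the repository's code. It assumes the masks are sampled from a hard concrete distribution with the commonly used stretch limits and temperature, and that a head's effective mask is the product of its own mask and its layer's mask. Only the names `head_loga` and `headlayer_loga` come from the reply above; `sample_hard_concrete`, `sample_attention_masks`, and the constants are hypothetical.

```python
import math
import random

# Stretch limits and temperature of the hard concrete distribution.
# These are commonly used values, assumed here rather than read from the repo.
LEFT, RIGHT, TEMPERATURE = -0.1, 1.1, 2.0 / 3.0


def sample_hard_concrete(log_alpha, rng):
    """Draw one mask value in [0, 1] from a hard concrete distribution."""
    u = rng.uniform(1e-6, 1.0 - 1e-6)
    logit = (math.log(u) - math.log(1.0 - u) + log_alpha) / TEMPERATURE
    s = 1.0 / (1.0 + math.exp(-logit))
    stretched = s * (RIGHT - LEFT) + LEFT
    # Clamping puts probability mass on exactly 0 and exactly 1.
    return min(1.0, max(0.0, stretched))


def sample_attention_masks(head_loga, headlayer_loga, rng):
    """Combine per-head masks with one mask per whole attention block.

    head_loga: one list of log-alpha values per layer (one value per head).
    headlayer_loga: one log-alpha value per layer.
    """
    combined = []
    for layer_heads, layer_loga in zip(head_loga, headlayer_loga):
        layer_z = sample_hard_concrete(layer_loga, rng)
        head_z = [sample_hard_concrete(a, rng) for a in layer_heads]
        # Assumption: a head survives only if both its own mask and
        # the mask of its whole attention block are non-zero.
        combined.append([layer_z * z for z in head_z])
    return combined
```

With this layout, driving a single `headlayer_loga` entry strongly negative switches off every head in that layer at once, while `head_loga` entries switch off heads one at a time.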
For the second question: in each training forward pass, we get the sampled masks `zs` from `l0_module`, pass them to the model to get the loss, and backpropagate the loss to update both the model parameters and the parameters of `l0_module`. You can use `load_model` in `cofi_utils.py` to prune a model with `zs`!
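As a rough illustration of what pruning with `zs` means once the masks are final, the sketch below drops the rows of a projection matrix that belong to heads whose mask is zero. It is not the `load_model` implementation from `cofi_utils.py`. The function `prune_heads`, the row-per-output-unit layout, and the toy matrix are all assumptions made for the example.

```python
def prune_heads(weight_rows, head_z, head_dim):
    """Keep only the rows of a projection matrix that belong to unmasked heads.

    weight_rows: list of rows, grouped head by head (head_dim rows per head).
    head_z: one final mask value per head; 0 means the head is pruned.
    """
    if len(weight_rows) != len(head_z) * head_dim:
        raise ValueError("weight_rows must have len(head_z) * head_dim rows")
    kept = []
    for head_index, z in enumerate(head_z):
        if z == 0:
            continue  # the whole head is removed
        start = head_index * head_dim
        kept.extend(weight_rows[start:start + head_dim])
    return kept


# Toy Query projection: 3 heads, head_dim 2, hidden size 4.
query_weight = [[float(r)] * 4 for r in range(6)]
pruned = prune_heads(query_weight, head_z=[1.0, 0.0, 1.0], head_dim=2)
```

Here the middle head is removed, so `pruned` keeps four of the six rows. In a real model the Key, Value, and output projections of the same head would have to be sliced consistently.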
Hi @xiamengzhou, thanks for your contribution. In your code, you use `Model.from_pretrained` to load the model architecture and the files you have already provided. But if I want to prune my own original model, for instance a T5 model, using the method in your paper, which code should I check? Many thanks :)