princeton-nlp / CoFiPruning

[ACL 2022] Structured Pruning Learns Compact and Accurate Models https://arxiv.org/abs/2204.00408
MIT License

Removing the already-pruned parts in the model may cause some changes in the outputs #36

Closed backspacetg closed 1 year ago

backspacetg commented 1 year ago

Hi! I am trying to apply CoFi pruning to my own model, and I noticed an edge case where removing the already-pruned parts of the model changes its outputs. I think this happens when all the dimensions of a layer's FFN intermediate are pruned away.

I found that when intermediate_zs are all zero, intermediate.dense in the pruned model is set to None https://github.com/princeton-nlp/CoFiPruning/blob/5423094e7b318806462f2a7bdca5384d078e5eed/utils/cofi_utils.py#L229-L231 , and the FFN part is then skipped https://github.com/princeton-nlp/CoFiPruning/blob/5423094e7b318806462f2a7bdca5384d078e5eed/models/modeling_bert.py#L364-L365

But before pruning, intermediate.dense is not None, so the all-zero intermediate outputs still pass through CoFiBertOutput.dense, which adds its bias to the output https://github.com/princeton-nlp/CoFiPruning/blob/5423094e7b318806462f2a7bdca5384d078e5eed/models/modeling_bert.py#L562-L566 . The FFN part is therefore not skipped, and the bias leaks into the residual connection.
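To see the discrepancy concretely, here is a minimal PyTorch sketch (not the repository's code; it assumes a standard BERT-style FFN with a residual connection and LayerNorm). With intermediate_z all zero, the pre-pruning path still emits output.dense's bias, so the result differs from skipping the FFN entirely:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Minimal sketch (not the repo's code) of a BERT-style FFN, to show that
# even when intermediate_z zeroes every intermediate dim, output.dense
# still injects its bias before the residual + LayerNorm.
hidden_size, intermediate_size = 8, 16
hidden = torch.randn(2, 4, hidden_size)

intermediate = nn.Linear(hidden_size, intermediate_size)
output = nn.Linear(intermediate_size, hidden_size)
layer_norm = nn.LayerNorm(hidden_size)

intermediate_z = torch.zeros(intermediate_size)  # every intermediate dim masked out

# Before pruning: the zeroed activations still pass through output.dense,
# so its bias term reaches the residual connection.
ffn_out = output(torch.relu(intermediate(hidden)) * intermediate_z)
before = layer_norm(hidden + ffn_out)

# After pruning: intermediate.dense is None, so the FFN is skipped entirely.
after = layer_norm(hidden)

print(torch.allclose(before, after))  # False: the outputs differ by the bias term
```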

Should I change some part of my code to skip the FFN parts when intermediate_zs are all zero during training?

xiamengzhou commented 1 year ago

Thanks for spotting this! I think it's a bug, though it almost never happens in practice: during training, mlp_z usually reaches 0 before all of the intermediate_z do. And you are correct that the FFN parts should be skipped when intermediate_z are all zero during training.

I am refactoring the code and should have a clear version ready soon!

xiamengzhou commented 1 year ago

Changing this line to `if self.intermediate.dense is None or self.intermediate_z.sum().eq(0).item() or self.mlp_z.sum().eq(0).item():` should be sufficient.
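For anyone adapting this to their own model, a minimal sketch of the guarded FFN path (method and attribute names such as feed_forward, self.output.dense, and self.output.LayerNorm follow Hugging Face BERT conventions and are assumptions here, not necessarily the repo's exact structure; dropout is omitted for brevity):

```python
def feed_forward(self, attention_output):
    # Skip the FFN both when it was physically removed (dense is None)
    # and when the masks zero it out during training, so training-time
    # and pruned-model outputs match.
    if (
        self.intermediate.dense is None
        or self.intermediate_z.sum().eq(0).item()
        or self.mlp_z.sum().eq(0).item()
    ):
        return attention_output

    # Apply the intermediate mask, then the MLP mask, inside the residual.
    intermediate_output = self.intermediate(attention_output) * self.intermediate_z
    ffn_output = self.output.dense(intermediate_output) * self.mlp_z
    return self.output.LayerNorm(ffn_output + attention_output)
```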

backspacetg commented 1 year ago

I see! Thanks for your reply 😃