VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

How to get the indexes of the pruned modules in the initial network #175

Closed: jiangz20 closed this issue 1 year ago

jiangz20 commented 1 year ago

Firstly, thank you for your contribution to structural pruning! It really helps me a lot. I'd like to ask a question: given a network and a pruning ratio, I can prune the network using your framework. However, I may later want to restore the subnetwork to its original size, so I need to record the indices of the pruned (or kept) channels in the original network. How can I get these indices in your framework?
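Recording either the pruned or the kept indices is enough, because within one layer each set is the complement of the other. A minimal sketch, where `kept_indices` is a hypothetical helper (not part of Torch-Pruning) and the layer's channel count before pruning is assumed to be known:

```python
def kept_indices(num_channels, pruned_idxs):
    """Return the original channel indices that survive pruning, in ascending order.

    Hypothetical helper: num_channels is the layer's channel count before pruning,
    pruned_idxs are the removed indices in the original numbering.
    """
    pruned = set(pruned_idxs)
    return [i for i in range(num_channels) if i not in pruned]


# Channel j of the pruned layer corresponds to original channel kept[j],
# assuming the surviving channels keep their relative order.
kept = kept_indices(6, [1, 4])
print(kept)  # [0, 2, 3, 5]
```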

VainF commented 1 year ago

Something like this:

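```python
import torch_pruning as tp

# `pruner` is assumed to be an already-constructed Torch-Pruning pruner,
# e.g. tp.pruner.MagnitudePruner(model, example_inputs, ...).
pruning_record = []
for group in pruner.step(interactive=True):  # yields one dependency group at a time
    print(group)
    dep, idxs = group[0]               # root dependency of the group and its channel indices
    target_module = dep.target.module  # layer that triggers this group
    pruning_fn = dep.handler           # pruning function, e.g. tp.prune_conv_out_channels
    pruning_record.append((target_module, pruning_fn, idxs))
    group.prune()                      # actually remove the channels of this group
```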

Now you can retrieve these groups from the original model with:

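```python
# DG is a tp.DependencyGraph, built e.g. with
# DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
for (target_module, pruning_fn, idxs) in pruning_record:
    group = DG.get_pruning_group(target_module, pruning_fn, idxs=idxs)
```
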
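One detail worth noting: the stored `target_module` objects belong to the model that was pruned. If the unpruned model is a separate copy (for example one made with `copy.deepcopy` before pruning), those objects are not part of the copy, so a lookup key that survives copying is needed. A minimal sketch, assuming pruning modifies layers in place so their qualified names do not change; `record_by_name` and `resolve_record` are hypothetical helpers, and `named_modules` is any iterable of `(name, module)` pairs such as `model.named_modules()`:

```python
def record_by_name(named_modules, pruning_record):
    """Replace module objects in the record with their qualified names."""
    # Identity-based lookup: the same module object always maps to its name.
    name_of = {id(module): name for name, module in named_modules}
    return [(name_of[id(m)], fn, list(idxs)) for m, fn, idxs in pruning_record]


def resolve_record(named_modules, named_record):
    """Map each recorded name to the module with that name in another model."""
    by_name = dict(named_modules)
    return [(by_name[name], fn, idxs) for name, fn, idxs in named_record]
```

Usage under these assumptions: build the name-based record from the pruned model with `record_by_name(model.named_modules(), pruning_record)`, then resolve it against the unpruned copy with `resolve_record(original.named_modules(), ...)` and pass each resolved module to `get_pruning_group` on a dependency graph built for that copy.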
jiangz20 commented 1 year ago

Thank you! I'll try that!

jiangz20 commented 1 year ago

Oh, sorry~ I have to trouble you again. Following your instructions, I can now fetch the pruned groups. My goal is to transfer the weights of the submodel (which I may train for several rounds) back into the original model, while leaving the pruned weights in the original model unchanged. How can I achieve that?
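One possible approach is to scatter each trained sub-model tensor back into a copy of the full-size tensor at the surviving indices, so positions belonging to pruned channels keep their original values. A minimal sketch, assuming you know the kept output and input channel indices of each layer in the original numbering and that the pruned layer keeps surviving channels in their original order; `merge_subnet_weight` is a hypothetical helper and NumPy arrays stand in for tensors:

```python
import numpy as np


def merge_subnet_weight(orig_weight, sub_weight, kept_out, kept_in=None):
    """Write a pruned layer's trained weights back into its full-size weight.

    orig_weight: full-size array, shape (C_out, C_in, ...) or (C_out,) for a bias.
    sub_weight:  the trained sub-model array with pruned channels removed.
    kept_out:    original output-channel indices that survived pruning (sorted).
    kept_in:     original input-channel indices that survived, or None if the
                 input dimension was not pruned (or the tensor is 1-D).
    Entries of pruned channels keep their values from orig_weight.
    """
    merged = np.array(orig_weight, copy=True)  # never modify the original in place
    kept_out = np.asarray(kept_out, dtype=np.intp)
    if kept_in is None:
        # Only the leading (output) dimension is indexed.
        merged[kept_out] = sub_weight
    else:
        # np.ix_ selects kept rows x kept columns; trailing dims (e.g. kernel) are taken whole.
        merged[np.ix_(kept_out, np.asarray(kept_in, dtype=np.intp))] = sub_weight
    return merged
```

The same step would be applied to every pruned tensor of a layer (weights, biases, and BatchNorm parameters and running statistics). With PyTorch parameters, the analogous advanced-indexing assignment, e.g. `w[out_idx[:, None], in_idx] = sub` with integer index tensors inside `torch.no_grad()`, should work the same way.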