VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

Does Torch-Pruning support sharing weights between different pruning structures? #257

Closed ChengHan111 closed 1 year ago

ChengHan111 commented 1 year ago

Hi, first of all, thanks for your great work! I really learned a lot from your paper. I am wondering whether Torch-Pruning supports optimizing a single model at multiple sparsities, meaning that several architectures are sparsified from a single backbone while all sparsity levels share the same weights (at a higher sparsity, only part of the original model is updated accordingly). The question is a little tricky; I hope you can give me some hints.

Thank you!

VainF commented 1 year ago

Hi @ChengHan111. Could you please share additional information regarding 'different sparsities'? Are you referring to structured and unstructured sparsity?

Or perhaps you're interested in further sparsifying an already sparse model.

ChengHan111 commented 1 year ago

Thanks for your reply. I am doing structural pruning. What I mean by different sparsities is that I want to jointly prune a model at several sparsity levels. Without instantiating multiple models, I want to share the weights across the different sparsities. Any thoughts? Thank you!

ChengHan111 commented 1 year ago

Also, I am interested in creating rankings for different sparsity levels in a single run. For example, with sparsities ranging from 0.1 to 0.9, instead of running tp.pruner.MagnitudePruner several times (I suppose the results might differ across runs? Or not, I am not quite sure), can I use a single run of the pruner to generate multiple sparsity maps?

Many thanks.

VainF commented 1 year ago

If I understand your question correctly, you are trying to prune different layers with different sparsities. This can be achieved by crafting a pruner with ch_sparsity_dict = {model.layer1: 0.3, model.block4: 0.2}, which allows us to prune blocks/layers with customized sparsity, as in the sketch below.
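A minimal sketch of such a pruner, assuming a torchvision ResNet-18 as a stand-in model (so the layer names, importance criterion, and the extra arguments here are illustrative, not taken from this thread):

```python
import torch
import torchvision.models as models
import torch_pruning as tp

# Illustrative only: a torchvision ResNet-18 stands in for the user's model.
model = models.resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

imp = tp.importance.MagnitudeImportance(p=2)
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    ch_sparsity=0.5,  # default sparsity for layers without an override
    ch_sparsity_dict={model.layer1: 0.3, model.layer4: 0.2},  # per-block overrides
    ignored_layers=[model.fc],  # keep the classifier untouched
)
pruner.step()  # apply one round of structured pruning
```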

ChengHan111 commented 1 year ago

Hi, I guess my question is this: when I save several pruning histories in a state_dict (instead of using different sparsities on different layers, I save a separate history per sparsity for a single model), every call to load_pruning_history prunes the model once more, resulting in negative conv channel counts. Is there any way to load a pruning history starting from the original model I built the graph with, without repeatedly pruning the same model? For example, the following code produces negative conv channels. I assumed load_pruning_history operates on the original (unpruned) model, but I guess the logic here is to continue pruning the already-pruned model? Is there any way to always prune from the original model?

state_dict = torch.load('pruned_model.pth')
DG = tp.DependencyGraph().build_dependency(model, example_inputs)
for epoch in range(args.num_train_epochs):
    for step, batch in enumerate(train_loader):
        for sparsity in sparsities:
            sparsity = float(sparsity)
            DG.load_pruning_history(state_dict[f'pruning[{sparsity}]'])

VainF commented 1 year ago

Hi @ChengHan111. It seems you're exploring the idea of dynamically switching between different sparsity levels, which is an interesting challenge. However, in torch-pruning, it's worth noting that the removal of parameters is irreversible, making this task a bit complex. Currently, we can only gradually repeat the pruning history on the same model, as follows:

https://github.com/VainF/Torch-Pruning/blob/37dd835bf2e5c361caae297deda8ce8aca66d091/torch_pruning/dependency.py#L292

For example, it's feasible to implement a new loader:

def load_pruning_history(self, pruning_history):
    """Redo the pruning history"""
    self._pruning_history = pruning_history
    for module_name, is_out_channel_pruning, pruning_idx in self._pruning_history:
        module = self.model
        for n in module_name.split('.'):
            module = getattr(module, n)
        pruner = self.get_pruner_of_module(module)
        if is_out_channel_pruning:
            pruning_fn = pruner.prune_out_channels
        else:
            pruning_fn = pruner.prune_in_channels
        group = self.get_pruning_group(module, pruning_fn, pruning_idx)
        group.prune(record_history=False)
        yield self.model # <= Add this line to yield intermediate models

Then we can do something like:

for step_i, model_i in enumerate(DG.load_pruning_history(history)):
    print(model_i)

But we still have to initialize a fresh model to start a new loop.
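For instance, with the original (unmodified) load_pruning_history, each sparsity level can be replayed on a fresh copy of the unpruned model, roughly like this (state_dict, sparsities, and example_inputs are assumed to be defined as in your snippet; original_model stands for a hypothetical unpruned copy kept around for rebuilding):

```python
import copy
import torch
import torch_pruning as tp

state_dict = torch.load('pruned_model.pth')  # one saved pruning history per sparsity level

pruned_models = {}
for sparsity in sparsities:
    sparsity = float(sparsity)
    # Replay each history on a fresh copy, so pruning is never applied cumulatively.
    model_copy = copy.deepcopy(original_model)
    DG = tp.DependencyGraph().build_dependency(model_copy, example_inputs)
    DG.load_pruning_history(state_dict[f'pruning[{sparsity}]'])
    pruned_models[sparsity] = model_copy
```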

ChengHan111 commented 1 year ago

Got it. That is exactly what I am asking for! Many thanks!