VainF / Torch-Pruning

[CVPR 2023] DepGraph: Towards Any Structural Pruning
https://arxiv.org/abs/2301.12900
MIT License
2.69k stars 331 forks source link

Inconsistency in `group.prune()`? #272

Open nathanhubens opened 1 year ago

nathanhubens commented 1 year ago

There seems to be an inconsistency in the behavior of `group.prune()`. Passing `pruning_idxs` directly as an argument (`group.prune(pruning_idxs)`) behaves differently from first rebuilding the group with `group = self.DG.get_pruning_group(module, pruning_fn, pruning_idxs.tolist())` and then calling `group.prune()` without arguments.

To replicate the issue, please find below simplified but working versions of your code:

Calling group.prune(pruning_idxs)

def step(self, interactive=False) -> None:
    """Run one pruning step, passing explicit indices to ``group.prune``.

    For every group produced by ``self.DG.get_all_groups`` (filtered by
    ``self.ignored_layers`` and ``self.root_module_types``), the indices
    chosen by ``self.prune_local(group)`` are handed directly to
    ``group.prune(pruning_idxs)``. Increments ``self.current_step`` first.
    """
    self.current_step += 1
    # NOTE(review): ``pruning_method`` is computed but never used below, and
    # ``interactive`` is ignored; ``prune_local`` is always called, even when
    # ``self.global_pruning`` is set.
    pruning_method = self.prune_global if self.global_pruning else self.prune_local

    # Each group is pruned inside the loop. If get_all_groups is lazy
    # (assumed — not shown here), later groups are computed after earlier
    # groups have already modified the model.
    for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
        pruning_idxs = self.prune_local(group)
        group.prune(pruning_idxs)

def prune_local(self, group) -> torch.Tensor:
    """Select the least-important output channels of ``group``'s root module.

    The number of channels to remove is the gap between the root module's
    current output-channel count and the count it should keep,
    ``int(self.layer_init_out_ch[module] * (1 - target_sparsity))``.

    Args:
        group: pruning group; ``group[0][0].target.module`` is treated as the
            root module whose output channels are counted.

    Returns:
        1-D tensor of indices into the importance vector returned by
        ``self.estimate_importance(group)``: the ``n_pruned`` lowest-importance
        entries, in ascending importance order. Empty when the module already
        has no more channels than it should keep.
    """
    module = group[0][0].target.module
    imp = self.estimate_importance(group)
    current_channels = self.DG.get_out_channels(module)
    target_sparsity = self.get_target_sparsity(module)
    n_keep = int(self.layer_init_out_ch[module] * (1 - target_sparsity))
    # Clamp at zero: a negative count would turn ``[:n_pruned]`` into
    # "everything except the last |n_pruned| channels" and prune most of them.
    n_pruned = max(current_channels - n_keep, 0)

    imp_argsort = torch.argsort(imp)
    return imp_argsort[:n_pruned]

Calling group.prune()


def step(self, interactive=False) -> None:
    """Run one pruning step, pruning pre-built groups with no extra indices.

    Iterates the method selected by ``self.global_pruning``
    (``self.prune_global`` or ``self.prune_local``) and calls
    ``group.prune()`` with no arguments on each group it yields, so each
    group is expected to already carry the indices to remove.
    Increments ``self.current_step`` first.
    """
    self.current_step += 1
    pruning_method = self.prune_global if self.global_pruning else self.prune_local

    # NOTE(review): ``pruning_method`` is called with no arguments, so whichever
    # method is selected must accept being called without a ``group`` argument.
    # ``interactive`` is not used in this version.
    # When the method is a generator, each group is pruned before the next one
    # is built.
    for group in pruning_method():
        group.prune()

def prune_local(self, group=None) -> typing.Generator:
    """Yield, per graph group, a pruning group rebuilt with its least-important channels.

    For every group from ``self.DG.get_all_groups`` (filtered by
    ``self.ignored_layers`` and ``self.root_module_types``), the root module
    ``candidate[0][0].target.module`` is checked. The number of channels to
    remove is its current output-channel count minus
    ``int(self.layer_init_out_ch[module] * (1 - target_sparsity))``.

    Args:
        group: unused. It is optional so the method can be called with no
            arguments, as ``step`` does via ``pruning_method()``; the loop
            below walks all groups itself.

    Yields:
        ``self.DG.get_pruning_group(module, pruning_fn, idxs)`` for each root
        module that still needs pruning. Groups with nothing to remove are
        skipped.
    """
    for candidate in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
        module = candidate[0][0].target.module
        pruning_fn = candidate[0][0].handler
        imp = self.estimate_importance(candidate)
        current_channels = self.DG.get_out_channels(module)
        target_sparsity = self.get_target_sparsity(module)
        n_pruned = current_channels - int(
            self.layer_init_out_ch[module] *
            (1 - target_sparsity)
        )
        # A non-positive count means nothing to remove. Slicing with a
        # negative count would instead select almost every channel.
        if n_pruned <= 0:
            continue

        imp_argsort = torch.argsort(imp)
        pruning_idxs = imp_argsort[:n_pruned]

        # Rebuild the group around the root module with the chosen indices so
        # the caller can prune it with a plain ``group.prune()``.
        yield self.DG.get_pruning_group(module, pruning_fn, pruning_idxs.tolist())

Both were tried on a vgg16_bn model:

# Reproduction setup: a randomly initialised VGG16-BN, a single random
# 1x3x224x224 input, and a MetaPruner with global_pruning=False,
# ch_sparsity=0.5 and Conv2d root modules, stepped once.
model = vgg16_bn(pretrained=False)
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance()

pruner_kwargs = {
    "importance": imp,
    "ch_sparsity": 0.5,
    "global_pruning": False,
    "root_module_types": [nn.Conv2d],
}
pruner = MetaPruner(model, example_inputs, **pruner_kwargs)

pruner.step()

Both versions were run with torch-pruning==1.2.5.

VainF commented 1 year ago

Hi @nathanhubens, thanks for reaching out. Could you provide more details about the inconsistency? The following code works well on my side:

import torch
from torchvision.models import resnet18
import torch_pruning as tp

def _collect_affected_params(group):
    """Return the detached ``weight`` and ``bias`` tensors of every module in ``group``.

    For each module, ``weight`` is appended before ``bias``. An attribute that
    is missing or ``None`` (e.g. ``nn.Conv2d(bias=False)``, or
    ``nn.BatchNorm2d(affine=False)``) is skipped instead of raising.
    """
    params = []
    for dep, _ in group:
        module = dep.target.module
        for attr in ("weight", "bias"):
            tensor = getattr(module, attr, None)
            if tensor is not None:
                params.append(tensor.detach())
    return params


def test_depgraph():
    """Check that ``group.prune(idxs)`` matches building the group with ``idxs``.

    Channels [2, 6, 9] of ``conv1`` are pruned in two freshly loaded
    pretrained ResNet-18s: once by building the group with those indices and
    calling ``prune()``, and once by building it with placeholder indices
    [1, 2, 3, 4] and overriding them via ``prune([2, 6, 9])``. Every affected
    weight/bias must match between the two runs.
    """
    target_idxs = [2, 6, 9]

    # 1. Reference: build the dependency graph and a group with the target
    #    indices, then prune.
    model = resnet18(pretrained=True).eval()
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))
    pruning_group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=list(target_idxs))
    pruning_group.prune()
    affected_weights1 = _collect_affected_params(pruning_group)

    # 2. Build the group with placeholder indices and replace them at prune time.
    model = resnet18(pretrained=True).eval()
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))
    pruning_group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[1, 2, 3, 4])
    pruning_group.prune(list(target_idxs))
    affected_weights2 = _collect_affected_params(pruning_group)

    # zip() silently stops at the shorter list, so compare the counts first;
    # otherwise groups of different sizes could pass unnoticed.
    assert len(affected_weights1) == len(affected_weights2)
    for w1, w2 in zip(affected_weights1, affected_weights2):
        assert torch.allclose(w1, w2)

if __name__=='__main__':
    test_depgraph()
nathanhubens commented 1 year ago

I have created a Colab that probably better illustrates what I am trying to explain.

Maybe there is something that I am missing, but the examples I show there seem to behave differently, especially for the last layer.

VainF commented 1 year ago

Thank you for the Colab. I also noticed the same inconsistency but haven't resolved it yet. It's a little confusing. I will go back to this issue after the CVPR deadline.