VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

Fails Pruning on Mobilenet v2 and v3 #328

Open satabios opened 5 months ago

satabios commented 5 months ago

I'm trying to apply channel-wise pruning, but a few issues persist:

  1. A few layers could not be pruned; they are printed by the try/except block in the code below.
  2. After loading the pruned model, it does not seem to hold the same weights as the original model [code not shown here]. What is the best practice to make sure the unpruned weights get carried over?
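
One way to test the second point is to compare each pruned layer against the original layer with the pruned channels removed. The following is a minimal sketch in plain Python, assuming the pruner keeps the surviving channels in their original order; `kept_indices` and `rows_carried_over` are hypothetical helper names, and the nested lists stand in for weight tensors.

```python
def kept_indices(num_channels, pruned_idxs):
    """Return the original channel indices that survive pruning, in order."""
    pruned = set(int(i) for i in pruned_idxs)
    return [i for i in range(num_channels) if i not in pruned]


def rows_carried_over(original_rows, pruned_rows, pruned_idxs):
    """True if pruned_rows equals original_rows with the pruned rows removed."""
    keep = kept_indices(len(original_rows), pruned_idxs)
    return [original_rows[i] for i in keep] == list(pruned_rows)


# Toy "weight" with 4 output channels, one value per channel for brevity.
original = [[0.1], [0.2], [0.3], [0.4]]
pruned = [[0.1], [0.3]]  # channels 1 and 3 removed
print(rows_carried_over(original, pruned, pruned_idxs=[1, 3]))
```

With real layers, the same comparison could be written as `torch.equal(pruned_layer.weight, original_layer.weight[keep])` for output-channel pruning, run right after pruning and before any fine-tuning.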

    import copy

    import torch
    import torch.nn as nn
    import torch_pruning as tp

    # test_input, reformat_layer_name, generate_random_array, made_it and
    # anomalous_layers are defined elsewhere (not shown).
    mobilenet_v3 = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v3_small', pretrained=True)
    pruned_model = copy.deepcopy(mobilenet_v3)
    DG = tp.DependencyGraph().build_dependency(pruned_model, example_inputs=test_input)

    for group in DG.get_all_groups(ignored_layers=[], root_module_types=[nn.Conv2d, nn.Linear]):
        dep = group[0][0]
        layer_name = dep.target._name
        layer = dep.target.module
        new_layer_name = reformat_layer_name(layer_name)

        if isinstance(layer, torch.nn.Conv2d):
            out_channels = layer.out_channels
            in_channels = layer.in_channels
            groups = layer.groups

            # Depthwise convs (groups == in_channels) use a dedicated pruning function.
            if groups == in_channels:
                type_conv_prune = tp.prune_depthwise_conv_in_channels
                channels_to_prune = in_channels
            else:
                type_conv_prune = tp.prune_conv_out_channels
                channels_to_prune = out_channels

            result_array = generate_random_array(channels_to_prune)
            prune_layer = eval('pruned_model.' + str(new_layer_name))

            internal_group = DG.get_pruning_group(prune_layer, type_conv_prune, idxs=result_array)

            if DG.check_pruning_group(internal_group):
                internal_group.prune()

            # Run a forward pass to verify the pruned model, then rebuild the graph.
            try:
                _ = pruned_model(test_input)
                DG = tp.DependencyGraph().build_dependency(pruned_model, example_inputs=test_input)
                made_it += 1
            except Exception:
                anomalous_layers.append("model." + new_layer_name)
                print("Ignore:", new_layer_name, type_conv_prune)
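
The `generate_random_array` helper is not shown in the post. As a point of comparison, here is a hypothetical stand-in that returns unique, in-range channel indices as a plain list of ints; the `ratio` and `seed` parameters are assumptions made for this sketch, not part of the original code.

```python
import random


def generate_random_array(num_channels, ratio=0.25, seed=None):
    """Hypothetical stand-in: pick unique channel indices to prune."""
    rng = random.Random(seed)
    # Leave at least one channel so a layer is never pruned down to zero.
    count = min(int(num_channels * ratio), num_channels - 1)
    count = max(count, 0)
    # random.sample draws without replacement, so indices never repeat.
    return sorted(rng.sample(range(num_channels), count))


idxs = generate_random_array(32, ratio=0.25, seed=0)
print(idxs)
```

Sampling without replacement keeps the number of pruned channels predictable, which makes failures easier to reproduce. Whether duplicate or out-of-range indices play any role in the failures reported here is not established.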

The main problem is with pruning grouped (depthwise) convolutions, as in this MobileNetV2 example:

import copy

import torch
import torch_pruning as tp

# mobilenet_v2 and conv_list are defined elsewhere (not shown).
mbv2_model = copy.deepcopy(mobilenet_v2)

for layer in conv_list:
    # 1. Rebuild the dependency graph for the current (possibly already pruned) model
    DG = tp.DependencyGraph().build_dependency(mbv2_model, example_inputs=torch.randn(1, 3, 32, 32))

    # 2. Group the layers coupled with this conv layer
    layer_for_pruning = eval(layer.replace('model', 'mbv2_model'))

    # Randomly pick channel indices to prune (torch.randint may return duplicates)
    group = DG.get_pruning_group(layer_for_pruning, tp.prune_conv_out_channels, idxs=torch.randint(low=0, high=10, size=(5,)))

    # 3. Prune grouped layers altogether
    if DG.check_pruning_group(group):  # avoid full pruning, i.e., channels=0.
        group.prune()
    _ = mbv2_model(torch.randn(1, 3, 32, 32))
    print(layer)

# 4. Save & Load
mbv2_model.zero_grad()  # clear gradients
torch.save(mbv2_model, 'model.pth')  # We cannot use .state_dict as the model structure is changed.
mbv2_model = torch.load('model.pth')

RuntimeError: Given groups=27, expected weight to be divisible by 27 at dimension 0, but got weight of size [[32, 1, 3, 3]] instead
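
For reference, the constraint behind this message can be sketched in plain Python. A grouped convolution needs both its output and input channel counts to be divisible by `groups`; the numbers below are read off the error message, and `check_grouped_conv` is a hypothetical helper, not a Torch-Pruning or PyTorch function.

```python
def check_grouped_conv(out_channels, in_channels, groups):
    """Return a list of violated constraints for a grouped convolution."""
    problems = []
    if out_channels % groups != 0:
        problems.append(
            f"out_channels={out_channels} is not divisible by groups={groups}"
        )
    if in_channels % groups != 0:
        problems.append(
            f"in_channels={in_channels} is not divisible by groups={groups}"
        )
    return problems


# From the error: weight of size [32, 1, 3, 3] with groups=27.
# The weight's dim 1 is in_channels // groups = 1, so in_channels is
# taken to equal groups here (an assumption based on that shape).
print(check_grouped_conv(out_channels=32, in_channels=27, groups=27))

# A consistent depthwise layer: all three values agree.
print(check_grouped_conv(out_channels=27, in_channels=27, groups=27))
```

One reading of the numbers is that the layer's `groups` attribute was reduced to 27 while its weight still holds 32 filters, so the two went out of sync during pruning. The sketch only restates the constraint; it does not identify which step caused the mismatch.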

satabios commented 5 months ago

@VainF, can you comment on why pruning fails on MobileNets?