VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
MIT License

Channel Sorting #335

Open satabios opened 7 months ago

satabios commented 7 months ago

@VainF This is a feature request: would it be possible to apply channel sorting based on channel importance? Any insights you could provide would be really helpful.

import copy

import torch
import torch.nn as nn


# compute the importance (L2 norm) of each input channel of a conv weight;
# the sorting itself happens in apply_channel_sorting
def get_input_channel_importance(weight):
    in_channels = weight.shape[1]
    importances = []
    # compute the importance for each input channel
    for i_c in range(in_channels):
        channel_weight = weight.detach()[:, i_c]
        importance = torch.norm(channel_weight)
        importances.append(importance.view(1))
    return torch.cat(importances)


@torch.no_grad()
def apply_channel_sorting(model):
    model = copy.deepcopy(model)  # do not modify the original model
    # fetch all the conv and bn layers from the backbone
    # (assumes model.backbone is an nn.Sequential in which every conv
    # except the last is followed by a BatchNorm2d)
    all_convs = [m for m in model.backbone if isinstance(m, nn.Conv2d)]
    all_bns = [m for m in model.backbone if isinstance(m, nn.BatchNorm2d)]
    # iterate through conv layers
    for i_conv in range(len(all_convs) - 1):
        # each sorting index must be applied to:
        # - the output dimension of the previous conv
        # - the previous BN layer
        # - the input dimension of the next conv (we compute importance here)
        prev_conv = all_convs[i_conv]
        prev_bn = all_bns[i_conv]
        next_conv = all_convs[i_conv + 1]
        # note that we always compute the importance according to input channels
        importance = get_input_channel_importance(next_conv.weight)
        # sorting from large to small
        sort_idx = torch.argsort(importance, descending=True)

        # apply to previous conv (weight, and bias if present) and its following bn
        prev_conv.weight.copy_(torch.index_select(
            prev_conv.weight.detach(), 0, sort_idx))
        if prev_conv.bias is not None:
            prev_conv.bias.copy_(torch.index_select(
                prev_conv.bias.detach(), 0, sort_idx))
        for tensor_name in ['weight', 'bias', 'running_mean', 'running_var']:
            tensor_to_apply = getattr(prev_bn, tensor_name)
            if tensor_to_apply is not None:  # skip absent affine params / stats
                tensor_to_apply.copy_(
                    torch.index_select(tensor_to_apply.detach(), 0, sort_idx))

        # apply to the next conv input (dim 1 of its weight)
        next_conv.weight.copy_(
            torch.index_select(next_conv.weight.detach(), 1, sort_idx))

    return model
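A quick way to confirm that channel sorting is function-preserving is to sort a toy network and compare outputs. The sketch below is self-contained and makes a few assumptions not in the thread: it works on a plain `nn.Sequential` of `Conv2d`/`BatchNorm2d`/`ReLU` layers instead of `model.backbone`, assumes `groups=1` convolutions, and uses the hypothetical names `input_channel_importance` and `sort_channels`. The importance is a vectorized form of the per-channel `torch.norm` loop above (L2 norm over every dimension except the input-channel dimension).

```python
import copy

import torch
import torch.nn as nn


def input_channel_importance(weight: torch.Tensor) -> torch.Tensor:
    # L2 norm of each input channel: reduce over out-channels and kernel dims.
    return torch.linalg.vector_norm(weight.detach(), dim=(0, 2, 3))


@torch.no_grad()
def sort_channels(model: nn.Sequential) -> nn.Sequential:
    model = copy.deepcopy(model)  # leave the original untouched
    convs = [m for m in model if isinstance(m, nn.Conv2d)]
    bns = [m for m in model if isinstance(m, nn.BatchNorm2d)]
    # Assumes each conv except the last is followed by one BatchNorm2d.
    for prev_conv, prev_bn, next_conv in zip(convs[:-1], bns, convs[1:]):
        idx = torch.argsort(input_channel_importance(next_conv.weight),
                            descending=True)
        # Producer side: output channels of prev_conv (and its bias, if any).
        prev_conv.weight.copy_(prev_conv.weight[idx])
        if prev_conv.bias is not None:
            prev_conv.bias.copy_(prev_conv.bias[idx])
        # Per-channel BN affine parameters and running statistics.
        for name in ("weight", "bias", "running_mean", "running_var"):
            t = getattr(prev_bn, name)
            if t is not None:
                t.copy_(t[idx])
        # Consumer side: input channels of next_conv.
        next_conv.weight.copy_(next_conv.weight[:, idx])
    return model


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 4, 3, padding=1),
)
# Give BN non-trivial parameters and statistics so the check is meaningful.
with torch.no_grad():
    for m in model:
        if isinstance(m, nn.BatchNorm2d):
            m.weight.uniform_(0.5, 1.5)
            m.bias.uniform_(-0.5, 0.5)
            m.running_mean.uniform_(-1.0, 1.0)
            m.running_var.uniform_(0.5, 2.0)
model.eval()

sorted_model = sort_channels(model).eval()
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    same = torch.allclose(model(x), sorted_model(x), atol=1e-5)
print(same)
```

Because every per-channel quantity on both sides of each ReLU is permuted by the same index, the sorted network computes the same function up to floating-point rounding, which is why `allclose` is used rather than exact equality.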
VainF commented 7 months ago

Hi @satabios, may I ask in which cases this feature would be useful? It would be great if there were some publications on it.

satabios commented 7 months ago

Found the paper: Channel sorting has also been shown in various other papers and tested here: