VainF / Torch-Pruning

[CVPR 2023] DepGraph: Towards Any Structural Pruning
https://arxiv.org/abs/2301.12900
MIT License

Pruning AttentionPool2D #229

Open justlike-prog opened 1 year ago

justlike-prog commented 1 year ago

Hi, I would like to prune the AttentionPool2d layers in the ModifiedResNet of CLIP. For some reason, I am running into this error:

ace_computational_graph(self, module2node, grad_fn_root, gradfn2module, reused, visited)
    835     visited.add(grad_fn)
    838 for (param, dim) in self.unwrapped_parameters:
--> 839     module2node[param].pruning_dim = dim
    840 return module2node

KeyError: Parameter containing:
tensor([[-0.0744, -0.0718, -0.0343,  ..., -0.0477, -0.0331, -0.0632],
        [-0.0628, -0.0317, -0.0034,  ..., -0.0384,  0.0434, -0.0983],
        [-0.0248, -0.0110,  0.0150,  ..., -0.0453,  0.0003, -0.0498],
        ...,
        [-0.0109, -0.0399, -0.0177,  ..., -0.0263, -0.0504, -0.0535],
        [-0.0262, -0.0142, -0.0151,  ..., -0.0419, -0.0293, -0.0435],
        [-0.0088, -0.0109, -0.0126,  ..., -0.0090, -0.0046, -0.0178]],
       requires_grad=True)
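
The failing line in the traceback assumes that every entry of `self.unwrapped_parameters` was reached while the computational graph was traced. Because `module2node` is a dict keyed by the parameter object itself, any parameter that never shows up in the traced graph raises `KeyError` on the first lookup, and the error message only prints the tensor rather than its name. The pure-Python sketch below is not Torch-Pruning code. `FakeParam`, `Node` and `assign_pruning_dims` are hypothetical stand-ins that mimic that lookup and show a defensive variant that reports the names of untraced parameters instead. With a real model, those names would indicate which parameter in `attnpool` is the culprit.

```python
class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter: hashed by identity, so it can key a dict."""
    def __init__(self, name):
        self.name = name


class Node:
    """Hypothetical graph node; only the pruning_dim field seen in the traceback is modeled."""
    def __init__(self, param):
        self.param = param
        self.pruning_dim = None


def assign_pruning_dims(module2node, unwrapped_parameters, param_to_name):
    """Like the failing loop, but collects the names of parameters that were
    never traced instead of raising KeyError on the first one."""
    missing = []
    for param, dim in unwrapped_parameters:
        node = module2node.get(param)
        if node is None:
            missing.append(param_to_name.get(param, "<unnamed>"))
            continue
        node.pruning_dim = dim
    return missing


# Hypothetical example: tracing reached `traced` but never reached `untraced`.
untraced = FakeParam("hypothetical.untraced")
traced = FakeParam("hypothetical.traced")
module2node = {traced: Node(traced)}
unwrapped = [(untraced, 1), (traced, 0)]

# With a real model, a name lookup could be built as
# {p: n for n, p in model.named_parameters()}
param_to_name = {p: p.name for p, _ in unwrapped}

print(assign_pruning_dims(module2node, unwrapped, param_to_name))
# prints ['hypothetical.untraced']
```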

My code looks like this:

import torch
import torch.nn as nn
import torch_pruning as tp
from torchvision.models import resnet18
from mmedit.models.backbones.sr_backbones.coopclipiqa import load_clip_to_cpu
from mmedit.models.components.clip.model import AttentionPool2d

backbone_name='RN50' # 'ViT-B/32'
clip_model = load_clip_to_cpu(backbone_name)
clip_model.eval()

model = clip_model.visual

example_inputs = torch.randn(1, 3, 224, 224)  # dummy input used for tracing and for counting MACs/params
imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
channel_groups = {}

for m in model.modules():
    if isinstance(m, AttentionPool2d):
        # tell the pruner how many heads each projection layer is split into
        channel_groups[m.q_proj] = m.num_heads
        channel_groups[m.k_proj] = m.num_heads
        channel_groups[m.v_proj] = m.num_heads
        channel_groups[m.c_proj] = m.num_heads

pruner = tp.pruner.MagnitudePruner(
    model, 
    example_inputs, 
    global_pruning=False, # If False, a uniform sparsity will be assigned to different layers.
    importance=imp, # importance criterion for parameter selection
    iterative_steps=1, # the number of iterations to achieve target sparsity
    ch_sparsity=0.5,
    channel_groups=channel_groups,
    # output_transform=lambda out: out.logits.sum(),
    # ignored_layers=[model.classifier],
)
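
For intuition about what `channel_groups` and `group_reduction="mean"` ask for, here is a pure-Python sketch. It is not Torch-Pruning's implementation, and the function names are hypothetical. It scores each output channel by the L2 norm of its weight row, averages those scores over the layers assumed to share one pruning group, and then keeps the same number of channels inside every head so that the head dimension stays uniform. Assuming the standard CLIP RN50 attention pool (2048 channels split across 32 heads, so 64 channels per head), `ch_sparsity=0.5` would leave 32 channels per head, or 1024 in total.

```python
import math


def l2_row_norms(weight):
    """L2 norm of each output channel (row) of a 2-D weight given as nested lists."""
    return [math.sqrt(sum(w * w for w in row)) for row in weight]


def mean_group_scores(*weights):
    """Average the per-channel L2 norms over layers assumed to be pruned together."""
    norms = [l2_row_norms(w) for w in weights]
    return [sum(vals) / len(vals) for vals in zip(*norms)]


def head_aware_keep_indices(scores, num_heads, sparsity):
    """Pick channels to keep so that every head keeps the same number.

    Assumes channels are laid out head by head: channels
    [h * head_dim, (h + 1) * head_dim) belong to head h.
    """
    dim = len(scores)
    if dim % num_heads:
        raise ValueError("channel count must be divisible by num_heads")
    head_dim = dim // num_heads
    keep_per_head = head_dim - int(head_dim * sparsity)
    kept = []
    for h in range(num_heads):
        start = h * head_dim
        head = range(start, start + head_dim)
        # Highest-scoring channels within this head, then restore the original order.
        best = sorted(head, key=lambda i: scores[i], reverse=True)[:keep_per_head]
        kept.extend(sorted(best))
    return kept


# Toy example: two layers with 4 output channels each, split into 2 heads.
q = [[3.0, 4.0], [0.0, 1.0], [1.0, 0.0], [6.0, 8.0]]
k = [[3.0, 4.0], [0.0, 1.0], [1.0, 0.0], [6.0, 8.0]]
scores = mean_group_scores(q, k)  # [5.0, 1.0, 1.0, 10.0]
keep = head_aware_keep_indices(scores, num_heads=2, sparsity=0.5)
print(keep)  # prints [0, 3]
```

Selecting channels per head, rather than globally across all 2048 channels, is what prevents the pruner from emptying one head while leaving another untouched. An uneven split would break the reshape into `num_heads` equal slices inside the attention computation.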

Zaragoto commented 3 months ago

Hi, I also ran into a problem when trying to prune the ModifiedResNet50 in the CLIP model, and the issue also comes from the attnpool module. Have you solved your problem? If so, could you please share your solution? Thanks a lot!