VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

Pruning MViT #287

Open · grigorn opened this issue 10 months ago

grigorn commented 10 months ago

Hi. I am trying to prune MViTv2 with random pruning using the following code:

import torch
import timm
from timm.models.mvitv2 import MultiScaleAttention, MultiScaleVit as MViT
import torch_pruning as tp

model = timm.create_model('mvitv2_base.fb_in1k', pretrained=False)
example_inputs = torch.randn(1, 3, 224, 224)

imp = tp.importance.RandomImportance()
# record the head count of each fused qkv projection and the group count of each
# pooling conv, so the pruner keeps them consistent while pruning
ch_groups, num_heads = dict(), dict()
for m in model.modules():
    if isinstance(m, MultiScaleAttention):
        num_heads[m.qkv] = m.num_heads
        ch_groups[m.pool_q] = m.pool_q.groups
        ch_groups[m.pool_k] = m.pool_k.groups
        ch_groups[m.pool_v] = m.pool_v.groups

ignored_layers = [model.head]

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    global_pruning=False,
    importance=imp,
    iterative_steps=1,
    pruning_ratio=0.5,
    ignored_layers=ignored_layers,
    out_channel_groups=ch_groups,
    num_heads=num_heads,
    prune_head_dims=False,   # keep each head's dimension fixed
    prune_num_heads=True,    # prune whole heads instead
    head_pruning_ratio=0.5,
)
pruner.step()

# write the updated head counts back into the attention modules after pruning
for m in model.modules():
    if isinstance(m, MultiScaleAttention):
        if num_heads[m.qkv] == 0:
            m.num_heads = 1
        else:
            m.num_heads = num_heads[m.qkv]
        # head_dim is left untouched since prune_head_dims=False
        m.dim_out = m.head_dim * m.num_heads
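
Continuing from the script above, a minimal check of the pruned model (this forward call is not part of the original snippet):

# forward pass on the pruned model; this is the call that fails
with torch.no_grad():
    out = model(example_inputs)
print(out.shape)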

Forward pass after pruning does not work. I think num_heads and the pruned dimensions are not being calculated and updated correctly. Initially, the num_heads mapping is:

 Linear(in_features=96, out_features=288, bias=True): 1,
 Linear(in_features=96, out_features=288, bias=True): 1,
 Linear(in_features=96, out_features=576, bias=True): 2,
 Linear(in_features=192, out_features=576, bias=True): 2,
 Linear(in_features=192, out_features=576, bias=True): 2,
 Linear(in_features=192, out_features=1152, bias=True): 4,
 Linear(in_features=384, out_features=1152, bias=True): 4,
 Linear(in_features=384, out_features=1152, bias=True): 4,

After pruning it becomes the following. Layers with 1 head have become layers with 0 heads, and only a small number of channels has been subtracted from out_features:

 Linear(in_features=48, out_features=285, bias=True): 0,
 Linear(in_features=48, out_features=285, bias=True): 0,
 Linear(in_features=48, out_features=573, bias=True): 1,
 Linear(in_features=96, out_features=573, bias=True): 1,
 Linear(in_features=96, out_features=573, bias=True): 1,
 Linear(in_features=96, out_features=1146, bias=True): 2,
 Linear(in_features=192, out_features=1146, bias=True): 2,
 Linear(in_features=192, out_features=1146, bias=True): 2,
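
For reference, the mappings above can be printed with a small sketch like this, reusing the model and the num_heads dict from the script above:

for m in model.modules():
    if isinstance(m, MultiScaleAttention):
        # each key is a fused qkv Linear, each value the head count recorded for it
        print(f"{m.qkv}: {num_heads[m.qkv]}")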

torch_pruning 1.3.2

VainF commented 10 months ago

Hi @grigorn. Thanks for the issue and example. I will take a look at this bug after the CVPR deadline.

grigorn commented 9 months ago

Hi @VainF, did you have a chance to look at this issue?

patelashutosh commented 4 months ago

Hi @VainF, we are also facing a similar issue with models whose attention_head_dim is greater than 1. The pruned dimensions do not work in the forward pass for the Attention blocks.
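
A quick diagnostic sketch for spotting the inconsistency after pruning. It only assumes that the attention blocks expose num_heads and a fused qkv nn.Linear (as timm's MultiScaleAttention does), so the fused projection width should stay a positive multiple of 3 * num_heads; other attention layouts may need a different check.

import torch.nn as nn

def report_inconsistent_attention(model):
    # flag blocks where qkv.out_features is no longer a positive multiple of 3 * num_heads
    for name, m in model.named_modules():
        if hasattr(m, 'num_heads') and isinstance(getattr(m, 'qkv', None), nn.Linear):
            if m.num_heads <= 0 or m.qkv.out_features % (3 * m.num_heads) != 0:
                print(f"{name}: num_heads={m.num_heads}, qkv.out_features={m.qkv.out_features}")

Running this on the pruned model from the first comment lists the attention blocks whose head counts and qkv widths no longer agree (e.g. out_features=285 with num_heads=0 in the listing above).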