VainF / Torch-Pruning

[CVPR 2023] DepGraph: Towards Any Structural Pruning
https://arxiv.org/abs/2301.12900
MIT License

IndexError: index 1536 is out of bounds for dimension 0 with size 1024 #412

Open broken-dream opened 3 months ago

broken-dream commented 3 months ago

I'm trying to prune a ViT model implemented in vit_pytorch, but I get the following error:

Traceback (most recent call last):
  File "/home/wh/generative_action/SynHSI/test_prune.py", line 30, in <module>
    for g in pruner.step(interactive=True):
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 411, in _prune
    imp = self.estimate_importance(group) # raw importance score
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 279, in estimate_importance
    return self.importance(group)
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_pruning/pruner/importance.py", line 221, in __call__
    local_imp = local_imp[idxs]
IndexError: index 1536 is out of bounds for dimension 0 with size 1024

This is a minimal example to reproduce the error:

from vit_pytorch import ViT
import torch_pruning as tp
import torch

inputs = torch.randn(1, 64, 32, 32, dtype=torch.float32)

model = ViT(
    image_size=32,
    patch_size=8,
    channels=64,
    num_classes=512,
    dim=512,
    depth=6,
    heads=16,
    mlp_dim=1024,
    dropout=0.1,
    emb_dropout=0.1
)

importance = tp.importance.MagnitudeImportance(p=1)
pruner = tp.pruner.MetaPruner(
    model,
    inputs,
    importance=importance,
    global_pruning=False,
    pruning_ratio=0.5,
    prune_head_dims=True,
    prune_num_heads=False
)

for g in pruner.step(interactive=True):
    g.prune()

Based on other similar issues, this problem seems to be caused by a torch.split() operation, but I couldn't find any torch.split() call in vit_pytorch.
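A possible lead, offered as an assumption rather than a diagnosis: vit_pytorch's Attention module does not call torch.split(), but it does split its fused projection with self.to_qkv(x).chunk(3, dim=-1). With the config above and vit_pytorch's default dim_head=64 (the repro does not set it), to_qkv outputs 3 * 16 * 64 = 3072 channels, which chunk() cuts into three 1024-wide q, k and v tensors. The arithmetic below only checks that the failing index 1536 is a valid position on that fused axis but not on a 1024-wide tensor. Note that mlp_dim is also 1024, so the numbers alone do not identify the layer. locate_in_chunk is a hypothetical helper.

```python
# Shapes from the repro above; dim_head=64 is ASSUMED to be vit_pytorch's
# default, since the repro does not pass it explicitly.
heads = 16
dim_head = 64

inner_dim = heads * dim_head   # width of each of q, k, v after chunk(3): 1024
qkv_width = 3 * inner_dim      # out_features of the fused to_qkv Linear: 3072


def locate_in_chunk(flat_idx: int, chunk_size: int = inner_dim) -> tuple:
    """Map an index on the fused qkv axis to (chunk number, index inside chunk).

    Chunks 0, 1, 2 correspond to q, k, v after .chunk(3, dim=-1).
    """
    if not 0 <= flat_idx < 3 * chunk_size:
        raise ValueError(f"index {flat_idx} is outside the fused width {3 * chunk_size}")
    return divmod(flat_idx, chunk_size)


failing = 1536  # the index reported in the IndexError above
chunk, local = locate_in_chunk(failing)
print(f"fused index {failing} -> {'qkv'[chunk]} chunk, local index {local}")
print(f"usable on a {inner_dim}-wide tensor without remapping? {failing < inner_dim}")
```

Here 1536 falls in the k chunk at local index 512; applied without remapping to a 1024-wide tensor, it is out of bounds, which is the same shape of failure as the error message.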

I also tried to locate the problem using RandomImportance, as mentioned in #147, but pruning works fine with RandomImportance.

BTW, when I pruned the model with RandomImportance, the parameter count decreased but the inference time increased. I have no experience with model pruning, so I don't know whether this is normal. Intuitively, shouldn't fewer parameters mean less inference time?
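Latency comparisons are sensitive to how they are measured (warm-up, run-to-run noise, asynchronous GPU execution), so it can help to time the original and pruned models with one identical, repeated procedure before drawing conclusions. A minimal sketch, assuming a zero-argument callable wraps one forward pass (for example lambda: model(inputs) run under torch.no_grad() with model.eval()); median_latency is a hypothetical helper, not part of Torch-Pruning.

```python
import statistics
import time
from typing import Callable


def median_latency(fn: Callable[[], object], warmup: int = 10, repeats: int = 50) -> float:
    """Return the median wall-clock time of fn() in seconds (hypothetical helper).

    Warm-up calls run first and are discarded, so one-off costs such as lazy
    initialisation or cache population are not counted. If fn launches CUDA
    work, it should call torch.cuda.synchronize() before returning, because
    GPU kernels run asynchronously and the timer would otherwise stop early.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Stand-in workload so the sketch runs on its own; swap in the model call.
print(f"{median_latency(lambda: sum(range(10_000))):.6f} s")
```

Calling it on both models with the same inputs, device and settings gives a like-for-like comparison; the median is used rather than the mean so that occasional slow outliers do not dominate.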

VainF commented 3 months ago

Hi! A num_heads parameter is required for transformer pruning. Please check this example: https://github.com/VainF/Torch-Pruning/blob/d7e23ed28dded2b6208074977f18b6302bb8a46e/examples/transformers/prune_hf_vit.py#L102
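An illustration of why the head count matters (a sketch, not Torch-Pruning's implementation): in vit_pytorch the fused to_qkv output is chunked into q, k and v and each is reshaped as (h d) with h = self.heads, so a pruning step must leave every head with the same width, and q and k of a head must lose the same positions for their dot products to line up. Without knowing the number of heads, a pruner cannot see that structure. The pure-Python sketch below shows one simple head-aware selection rule: score each within-head position by summing its q, k and v importances, then keep the top positions per head in all three. The function name and the scoring rule are assumptions for illustration. In Torch-Pruning itself, the head count is passed through the num_heads argument, which the linked example builds as a dict from each attention projection Linear to its head count; for vit_pytorch this would presumably be an entry like {attn.to_qkv: attn.heads} per Attention module, which is an untested assumption.

```python
from typing import List, Sequence


def head_aware_keep_indices(
    importance: Sequence[float], heads: int, dim_head: int, keep_per_head: int
) -> List[int]:
    """Hypothetical helper: choose which fused q|k|v output channels to keep.

    Assumed layout: [q | k | v], each heads * dim_head wide, with channels
    ordered as (h d) inside each part, matching rearrange('b n (h d) -> ...').
    Every head keeps exactly keep_per_head positions, and the same positions
    are kept in q, k and v of that head so q.k products stay aligned.
    """
    inner = heads * dim_head
    if len(importance) != 3 * inner:
        raise ValueError("expected one importance score per fused qkv channel")
    if not 0 < keep_per_head <= dim_head:
        raise ValueError("keep_per_head must be in 1..dim_head")

    keep: List[int] = []
    for h in range(heads):
        base = h * dim_head
        # Score each within-head position by summing its q, k and v channel scores.
        scores = [
            (sum(importance[part * inner + base + d] for part in range(3)), d)
            for d in range(dim_head)
        ]
        # Highest-scoring positions (ties go to the larger position index).
        top = sorted(d for _, d in sorted(scores, reverse=True)[:keep_per_head])
        for part in range(3):
            keep.extend(part * inner + base + d for d in top)
    return sorted(keep)


# Reporter's config (dim_head=64 assumed) at a 0.5 head-dim pruning ratio.
demo_scores = [float((i * 37) % 101) for i in range(3 * 16 * 64)]
kept = head_aware_keep_indices(demo_scores, heads=16, dim_head=64, keep_per_head=32)
print(len(kept))  # 1536 of 3072 channels kept: 32 per head in each of q, k, v
```

Because every head keeps the same number of positions, the pruned q, k and v can still be reshaped with the original h, and only the per-head width shrinks.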