VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

Cannot calculate importance score for a single group #371

Open MartinZakhaev opened 2 months ago

MartinZakhaev commented 2 months ago

Hello,

I followed the example and came up with this:

import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1, 3, 224, 224)

# Build the dependency graph and get the pruning group for output channels [2, 6, 9] of conv1
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])

# Score the group with magnitude importance
scorer = tp.importance.MagnitudeImportance()
imp_score = scorer(group)  # expected: a 1-D tensor with one score per selected channel
min_score = imp_score.min()

But I always get this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[36], line 4
      2 group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
      3 scorer = tp.importance.MagnitudeImportance()
----> 4 imp_score = scorer(group)
      5 #imp_score is a 1-D tensor with length 3 for channels [2, 6, 9]
      6 min_score = imp_score.min()

File c:\Users\ANDIKA WAHYU\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File c:\Users\ANDIKA WAHYU\AppData\Local\Programs\Python\Python310\lib\site-packages\torch_pruning\pruner\importance.py:293, in GroupNormImportance.__call__(self, group)
    290 if len(group_imp) == 0: # skip groups without parameterized layers
    291     return None
--> 293 group_imp = self._reduce(group_imp, group_idxs)
    294 group_imp = self._normalize(group_imp, self.normalizer)
    295 return group_imp

File c:\Users\ANDIKA WAHYU\AppData\Local\Programs\Python\Python310\lib\site-packages\torch_pruning\pruner\importance.py:180, in GroupNormImportance._reduce(self, group_imp, group_idxs)
    178 if self.group_reduction == "sum" or self.group_reduction == "mean":
    179     debug_file.write("Adding importance using scatter_add_\n")
--> 180     reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp)  # accumulated importance
    181 elif self.group_reduction == "max":
    182     # keep the max importance
    183     selected_imp = torch.index_select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device))

RuntimeError: index 6 is out of bounds for dimension 0 with size 3

What is wrong? I have already done some tracing through the torch-pruning code, but to no avail. Please help me.
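
For reference, the variant below does seem to run without the error on my side: I build the group over all of conv1's output channels and only select the scores for channels [2, 6, 9] afterwards. The select-after-scoring step is my own workaround rather than something from the docs, and I am not sure it is equivalent to scoring a group that contains only those three indices, so I would still like to understand the error above.

import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1, 3, 224, 224)

DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# Build the group over ALL output channels of conv1 instead of just [2, 6, 9]
all_idxs = list(range(model.conv1.out_channels))
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=all_idxs)

scorer = tp.importance.MagnitudeImportance()
full_imp = scorer(group)         # one score per output channel of conv1
imp_score = full_imp[[2, 6, 9]]  # select the three channels of interest afterwards
min_score = imp_score.min()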

MartinZakhaev commented 2 months ago

Any kind of help is appreciated @VainF