VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

Falling into an endless loop when pruning the ViT model in CLIP #86

Open RongchangLi opened 1 year ago

RongchangLi commented 1 year ago

Here is the test code:

import torch
import torch_pruning as tp
import torch.nn as nn
import clip.clip as clip

clip_model = clip.load('ViT-B/32', device='cpu', jit=False)[0]
model = clip_model.visual.transformer.resblocks.cpu()
print(model)

ori_size = tp.utils.count_params(model)
example_inputs = torch.randn(1, 50, 768)
imp = tp.importance.MagnitudeImportance(p=2) # L2 norm pruning
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)
    if isinstance(m, torch.nn.Conv2d):
        ignored_layers.append(m)
    if isinstance(m, nn.modules.linear.NonDynamicallyQuantizableLinear):
        ignored_layers.append(m)  # this module is used in Self-Attention

total_steps = 1
pruner = tp.pruner.LocalMagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    total_steps=total_steps, # number of iterations
    ch_sparsity=0.25, # channel sparsity
    ignored_layers=ignored_layers, # ignored_layers will not be pruned
)

for i in range(total_steps): # iterative pruning
    pruner.step()
    print(
        "  Params: %.2f M => %.2f M"
        % (ori_size / 1e6, tp.utils.count_params(model) / 1e6)
    )
torch.save(model, 'model.pth')

I find that when calling pruner.step(), it falls into an endless loop executing
_fix_dependency_graph_non_recursive(root_node, pruning_fn, idxs) at line 514 of torch_pruning/dependency.py.

Is there something wrong with my settings? Or am I missing something?
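
For reference, a rough way to narrow this down is to build the dependency graph directly and print the pruning group of a single layer, which lists every other layer coupled to it. This is only a sketch: the model[0].mlp.c_fc path into CLIP's ResidualAttentionBlock and the get_pruning_group / prune_linear_out_channels names assume a recent torch-pruning 1.x API.

# Sketch only: reuse `model` and `example_inputs` from the snippet above.
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# model[0].mlp.c_fc is assumed to be the first MLP Linear of the first
# ResidualAttentionBlock; its group shows all layers that must be pruned together.
target_layer = model[0].mlp.c_fc
group = DG.get_pruning_group(target_layer, tp.prune_linear_out_channels, idxs=[0, 1, 2])
print(group)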

RongchangLi commented 1 year ago

The bug is solved by updating to v1.0. However, there is another bug: when using model = clip_model.visual.cpu() rather than model = clip_model.visual.transformer.resblocks.cpu(), it also falls into an endless loop. Here is my code:

import torch
import torch_pruning as tp
import torch.nn as nn
import clip.clip as clip

clip_model = clip.load('ViT-B/32', device='cpu', jit=False)[0]
model = clip_model.visual.cpu()  # runs fine when using model = clip_model.visual.transformer.resblocks.cpu()
# print(model)

ori_size = tp.utils.count_params(model)
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2) # L2 norm pruning
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)
    # if isinstance(m, torch.nn.Conv2d):
    #     ignored_layers.append(m)
    if isinstance(m, nn.modules.linear.NonDynamicallyQuantizableLinear):
        ignored_layers.append(m)  # this module is used in Self-Attention

total_steps = 1
unwrapped_parameters = [model.class_embedding, model.positional_embedding]  # nn.Parameters that live outside any prunable module
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    global_pruning=False,
    importance=imp,
    iterative_steps=total_steps, # number of iterations
    ch_sparsity=0.75, # channel sparsity
    ignored_layers=ignored_layers, # ignored_layers will not be pruned
    unwrapped_parameters=unwrapped_parameters
)

for i in range(total_steps): # iterative pruning
    pruner.step()
    print(
        "  Params: %.2f M => %.2f M"
        % (ori_size / 1e6, tp.utils.count_params(model) / 1e6)
    )

When I stop the run, the traceback is as follows:

  File "D:\my_programming\2022\video_PETL\pruning_test.py", line 24, in <module>
    pruner = tp.pruner.MagnitudePruner(
  File "d:\my_programming\tools\torch-pruning\torch_pruning\pruner\algorithms\metapruner.py", line 62, in __init__
    self.DG = dependency.DependencyGraph().build_dependency(
  File "d:\my_programming\tools\torch-pruning\torch_pruning\dependency.py", line 246, in build_dependency
    self.module2node = self._trace(
  File "d:\my_programming\tools\torch-pruning\torch_pruning\dependency.py", line 530, in _trace
    n = stack.pop(-1)
KeyboardInterrupt

Could you help solve this problem?
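
As a sanity check (a sketch added here, not from the original report), one can time a plain forward pass through clip_model.visual; if it returns promptly, the hang is confined to DependencyGraph._trace rather than to the model's own forward:

import time
import torch

# Sketch only: if this plain forward pass returns promptly, the hang above is
# inside torch-pruning's graph tracing, not in clip_model.visual itself.
model = clip_model.visual.cpu()
example_inputs = torch.randn(1, 3, 224, 224)

t0 = time.time()
out = model(example_inputs)  # keep autograd enabled so this mirrors what the tracer sees
print('forward finished in %.2fs, output shape %s' % (time.time() - t0, tuple(out.shape)))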

Sarthak-22 commented 6 days ago

Hi, I am facing a similar issue. Were you able to solve it?