microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression, and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

yolov5prune error #5301

Open Turing77 opened 1 year ago

Turing77 commented 1 year ago

Describe the issue: I'm trying to prune the pre-trained model yolov5n-0.5 from Yolov5-face. Here is the code I used:

import torch, torchvision
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner, L2NormPruner, FPGMPruner, ActivationAPoZRankPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from rich import print
from utils.general import check_img_size
from models.common import Conv
from models.experimental import attempt_load
from models.yolo import Detect
from utils.activations import SiLU
import torch.nn as nn
from nni.compression.pytorch.utils.counter import count_flops_params

class SiLU(nn.Module):  # export-friendly version of nn.SiLU()
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)
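# (Editor's note: this definition intentionally shadows the SiLU imported from
# utils.activations above; writing the activation out as x * sigmoid(x) keeps
# the graph traceable/exportable on torch versions whose tracer or ONNX export
# cannot handle nn.SiLU directly.)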

device = torch.device("cuda:1")
model = attempt_load('/data03/hezhenhui/project/helmet/yolov5-6.0/runs/train/helmet6/weights/best.pt', map_location=device, inplace=True, fuse=True) # load FP32 model
model.eval()

for k, m in model.named_modules():
    if isinstance(m, Conv):  # assign export-friendly activations
        if isinstance(m.act, nn.SiLU):
            m.act = SiLU()
    elif isinstance(m, Detect):  # must be a sibling of the Conv check, not nested inside it
        m.inplace = False
        m.onnx_dynamic = False
    if hasattr(m, 'forward_export'):
        m.forward = m.forward_export  # assign custom forward (optional)

imgsz = (640, 640)
imgsz *= 2 if len(imgsz) == 1 else 1 # expand

gs = int(max(model.stride)) # grid size (max stride)
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
im = torch.zeros(1, 3, *imgsz).to(device)  # dummy input, shape (1, 3, 640, 640) BCHW
dummy_input = im

cfg_list = [{
    'sparsity': 0.3,
    'op_types': ['Conv2d'],
    'op_names': [
    'model.0.conv',
    'model.1.conv',
    'model.2.cv1.conv',
    'model.2.cv2.conv',
    'model.2.cv3.conv',
    'model.2.m.0.cv1.conv',
    'model.2.m.0.cv2.conv',
    'model.2.m.1.cv1.conv',
    'model.2.m.1.cv2.conv',
    'model.2.m.2.cv1.conv',
    'model.2.m.2.cv2.conv',
    'model.2.m.3.cv1.conv',
    'model.2.m.3.cv2.conv',
    'model.3.conv',
    'model.4.cv1.conv',
    'model.4.cv2.conv',
    'model.4.cv3.conv',
    'model.4.m.0.cv1.conv',
    'model.4.m.0.cv2.conv',
    'model.4.m.1.cv1.conv',
    'model.4.m.1.cv2.conv',
    'model.4.m.2.cv1.conv',
    'model.4.m.2.cv2.conv',
    'model.4.m.3.cv1.conv',
    'model.4.m.3.cv2.conv',
    'model.4.m.4.cv1.conv',
    'model.4.m.4.cv2.conv',
    'model.4.m.5.cv1.conv',
    'model.4.m.5.cv2.conv',
    'model.4.m.6.cv1.conv',
    'model.4.m.6.cv2.conv',
    'model.4.m.7.cv1.conv',
    'model.4.m.7.cv2.conv',
    'model.5.conv',
    'model.6.cv1.conv',
    'model.6.cv2.conv',
    'model.6.cv3.conv',
    'model.6.m.0.cv1.conv',
    'model.6.m.0.cv2.conv',
    'model.6.m.1.cv1.conv',
    'model.6.m.1.cv2.conv',
    'model.6.m.2.cv1.conv',
    'model.6.m.2.cv2.conv',
    'model.6.m.3.cv1.conv',
    'model.6.m.3.cv2.conv',
    'model.6.m.4.cv1.conv',
    'model.6.m.4.cv2.conv',
    'model.6.m.5.cv1.conv',
    'model.6.m.5.cv2.conv',
    'model.6.m.6.cv1.conv',
    'model.6.m.6.cv2.conv',
    'model.6.m.7.cv1.conv',
    'model.6.m.7.cv2.conv',
    'model.6.m.8.cv1.conv',
    'model.6.m.8.cv2.conv',
    'model.6.m.9.cv1.conv',
    'model.6.m.9.cv2.conv',
    'model.6.m.10.cv1.conv',
    'model.6.m.10.cv2.conv',
    'model.6.m.11.cv1.conv',
    'model.6.m.11.cv2.conv',
    'model.7.conv',
    'model.8.cv1.conv',
    'model.8.cv2.conv',
    'model.8.cv3.conv',
    'model.8.m.0.cv1.conv',
    'model.8.m.0.cv2.conv',
    'model.8.m.1.cv1.conv',
    'model.8.m.1.cv2.conv',
    'model.8.m.2.cv1.conv',
    'model.8.m.2.cv2.conv',
    'model.8.m.3.cv1.conv',
    'model.8.m.3.cv2.conv',
    'model.9.cv1.conv',
    'model.9.cv2.conv',
    'model.10.conv',
    'model.13.cv1.conv',
    'model.13.cv2.conv',
    'model.13.cv3.conv',
    'model.13.m.0.cv1.conv',
    'model.13.m.0.cv2.conv',
    'model.13.m.1.cv1.conv',
    'model.13.m.1.cv2.conv',
    'model.13.m.2.cv1.conv',
    'model.13.m.2.cv2.conv',
    'model.13.m.3.cv1.conv',
    'model.13.m.3.cv2.conv',
    'model.14.conv',
    'model.17.cv1.conv',
    'model.17.cv2.conv',
    'model.17.cv3.conv',
    'model.17.m.0.cv1.conv',
    'model.17.m.0.cv2.conv',
    'model.17.m.1.cv1.conv',
    'model.17.m.1.cv2.conv',
    'model.17.m.2.cv1.conv',
    'model.17.m.2.cv2.conv',
    'model.17.m.3.cv1.conv',
    'model.17.m.3.cv2.conv',
    'model.18.conv',
    'model.20.cv1.conv',
    'model.20.cv2.conv',
    'model.20.cv3.conv',
    'model.20.m.0.cv1.conv',
    'model.20.m.0.cv2.conv',
    'model.20.m.1.cv1.conv',
    'model.20.m.1.cv2.conv',
    'model.20.m.2.cv1.conv',
    'model.20.m.2.cv2.conv',
    'model.20.m.3.cv1.conv',
    'model.20.m.3.cv2.conv',
    'model.21.conv',
    'model.23.cv1.conv',
    'model.23.cv2.conv',
    'model.23.cv3.conv',
    'model.23.m.0.cv1.conv',
    'model.23.m.0.cv2.conv',
    'model.23.m.1.cv1.conv',
    'model.23.m.1.cv2.conv',
    'model.23.m.2.cv1.conv',
    'model.23.m.2.cv2.conv',
    'model.23.m.3.cv1.conv',
    'model.23.m.3.cv2.conv'
    ]
}, {
    'op_names': ['model.24.m.0', 'model.24.m.1', 'model.24.m.2'],
    'exclude': True
}
]
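# (Editor's hedged note, not part of the original script: the long op_names
# list above could be generated roughly equivalently by walking
# named_modules() and keeping every Conv2d outside the Detect head,
# 'model.24' in this network.)
#
# conv_names = [name for name, mod in model.named_modules()
#               if isinstance(mod, nn.Conv2d) and not name.startswith('model.24')]
# cfg_list = [{'sparsity': 0.3, 'op_types': ['Conv2d'], 'op_names': conv_names},
#             {'op_names': ['model.24.m.0', 'model.24.m.1', 'model.24.m.2'],
#              'exclude': True}]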

pruner = L1NormPruner(model, cfg_list)
_, masks = pruner.compress()
# print(masks)
pruner.export_model(model_path='helmet_yolov5s.pt', mask_path='helmet_mask.pt')
pruner.show_pruned_weights()
pruner._unwrap_model()

print("im.shape:",dummy_input.shape)
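# (Hedged reconstruction: the pasted snippet ends here, but the traceback
# below shows nni's ModelSpeedup being constructed from prune_nni.py, so the
# script presumably continues along these lines; not shown in the issue.)
# m_speedup = ModelSpeedup(model, dummy_input, masks, map_location=device)
# m_speedup.speedup_model()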

But it always throws this error:

ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
        Node:
                %864 : Tensor = prim::Constant[value={2}](), scope: __module.model.24 # /data03/hezhenhui/project/helmet/yolov5-6.0/models/yolo.py:66:0
        Source Location:
                /data03/hezhenhui/project/helmet/yolov5-6.0/models/yolo.py(66): forward
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
                /data03/hezhenhui/project/helmet/yolov5-6.0/models/yolo.py(149): _forward_once
                /data03/hezhenhui/project/helmet/yolov5-6.0/models/yolo.py(126): forward
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/torch/jit/_trace.py(934): trace_module
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/torch/jit/_trace.py(733): trace
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/nni/common/graph_utils.py(91): _trace
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/nni/common/graph_utils.py(67): __init__
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/nni/common/graph_utils.py(265): __init__
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/nni/common/graph_utils.py(25): build_module_graph
                /data03/hezhenhui/.conda/envs/tdn/lib/python3.8/site-packages/nni/compression/pytorch/speedup/compressor.py(73): __init__
                prune_nni.py(242): <module>
        Comparison exception:   expand(torch.cuda.FloatTensor{[1, 3, 40, 40, 2]}, size=[]): the number of sizes provided (0) must be greater or equal to the number of dimensions in the tensor (5)

I can't find a solution to this problem. Could you give me some advice?
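One possible direction, offered as a hedged sketch rather than a confirmed fix: the node the tracer complains about points at the grid construction in Detect.forward (yolo.py:66), which only executes in inference mode and builds tensor "constants" whose values depend on the input, so torch.jit.trace sees them differ across invocations. Putting the Detect head into training mode while NNI traces the model skips that code path:

for m in model.modules():
    if isinstance(m, Detect):
        m.training = True  # forward() now returns the raw per-level maps, no grid decode

# ... run ModelSpeedup / tracing here, then restore inference behaviour with model.eval()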

Environment:

J-shang commented 1 year ago

@super-dainiu please help to fix this issue

Lijiaoa commented 1 year ago

If there is a PR that fixes this, please link it to this issue. Thanks~ @super-dainiu