Gumpest / YOLOv5-Multibackbone-Compression

YOLOv5 Series Multi-backbone (TPH-YOLOv5, Ghostnet, ShuffleNetv2, Mobilenetv3Small, EfficientNetLite, PP-LCNet, SwinTransformer YOLO), Module (CBAM, DCN), Pruning (EagleEye, Network Slimming), Quantization (MQBench) and Deployment (TensorRT, ncnn) Compression Tool Box.
989 stars, 201 forks

Use ConvTranspose2d instead of Upsample #100

Open: dengxiongshi opened this issue 4 months ago

dengxiongshi commented 4 months ago

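@Gumpest Hi, is pruning supported after replacing nn.Upsample with nn.ConvTranspose2d in yolov5s-pruning.yaml? After making the replacement and training the model once by following the provided docs, running pruneEagleEye.py fails with this error: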

  File "/root/.pycharm_helpers/pydev/pydevd.py", line 1491, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/root/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/data/yolov5/pruneEagleEye.py", line 153, in <module>
    rand_prune_and_eval(model, ignore_idx, opt)
  File "/data/yolov5/pruneEagleEye.py", line 65, in rand_prune_and_eval
    compact_model = Model(pruned_yaml, pruning=False).to(device)
  File "/data/yolov5/models/yolo_prune.py", line 325, in __init__
    m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
  File "/data/yolov5/models/yolo_prune.py", line 324, in <lambda>
    forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
  File "/data/yolov5/models/yolo_prune.py", line 340, in forward
    return self._forward_once(x, profile, visualize)  # single-scale inference, train
  File "/data/yolov5/models/yolo_prune.py", line 250, in _forward_once
    x = m(x)  # run
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/yolov5/models/common.py", line 805, in forward
    return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/yolov5/models/common.py", line 90, in forward
    return self.act(self.bn(self.conv(x)))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 439, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [64, 56, 1, 1], expected input[1, 128, 32, 32] to have 56 channels, but got 128 channels instead
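The error message by itself pins down the failing convolution. The sketch below parses the shapes out of the message in plain Python (no PyTorch needed). It assumes the stride probe in `Model.__init__` feeds a 256x256 input (`s = 256`, the YOLOv5 default); the traceback does not show that value.

```python
import re

# Error text copied from the traceback above.
MSG = ("Given groups=1, weight of size [64, 56, 1, 1], expected input[1, 128, 32, 32] "
       "to have 56 channels, but got 128 channels instead")

# Assumption: the stride probe uses s = 256, the YOLOv5 default in Model.__init__.
PROBE_SIZE = 256


def decode_conv_error(msg: str, probe_size: int = PROBE_SIZE) -> dict:
    """Extract the weight and input shapes from a Conv2d channel-mismatch message."""
    weight = [int(v) for v in re.search(r"weight of size \[([\d, ]+)\]", msg).group(1).split(",")]
    inp = [int(v) for v in re.search(r"expected input\[([\d, ]+)\]", msg).group(1).split(",")]
    out_ch, in_ch, kh, kw = weight   # Conv2d weight layout: [out, in / groups, kH, kW]
    _, got_ch, h, _ = inp            # input layout: [N, C, H, W]
    return {
        "conv_out": out_ch,                 # output width of the failing conv
        "expected_in": in_ch,               # channels the conv was built for
        "got_in": got_ch,                   # channels it actually received
        "kernel": (kh, kw),
        "feature_stride": probe_size // h,  # 256 / 32 = 8, i.e. the stride-8 (P3) level
    }


info = decode_conv_error(MSG)
print(info)
```

In the searched subnet below, the only layer that builds a 1x1 conv from 56 to 64 channels at stride 8 is layer 4 (`C3_prune` with `c1 = 56`, whose `cv1`/`cv2` have weight `[64, 56, 1, 1]`). Yet the tensor reaching it has 128 channels, which is that same layer's output width.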

This is the architecture used for training:

                 from  n    params  module                                  arguments                     
  0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2, 1, True]     
  1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2, None, 1, True] 
  2                -1  1     18816  models.common.C3_prune                  [64, 64, 64, 1, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
  3                -1  1     73984  models.common.Conv                      [64, 128, 3, 2, None, 1, True]
  4                -1  2    231424  models.common.C3_prune                  [128, 128, 128, 2, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
  5                -1  1    295424  models.common.Conv                      [128, 256, 3, 2, None, 1, True]
  6                -1  3   1875456  models.common.C3_prune                  [256, 256, 256, 3, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
  7                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2, None, 1, True]
  8                -1  1   1182720  models.common.C3_prune                  [512, 512, 512, 1, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
  9                -1  1    656896  models.common.SPPF_prune                [512, 512, 5, 0.5]            
 10                -1  1    131584  models.common.Conv                      [512, 256, 1, 1, None, 1, True]
 11                -1  1    262400  torch.nn.modules.conv.ConvTranspose2d   [256, 256, 2, 2, 0]           
 12           [-1, 6]  1         0  models.common.Concat                    [1]                           
 13                -1  1    361984  models.common.C3_prune                  [512, 256, 256, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
 14                -1  1     33024  models.common.Conv                      [256, 128, 1, 1, None, 1, True]
 15                -1  1     65664  torch.nn.modules.conv.ConvTranspose2d   [128, 128, 2, 2, 0]           
 16           [-1, 4]  1         0  models.common.Concat                    [1]                           
 17                -1  1     90880  models.common.C3_prune                  [256, 128, 128, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
 18                -1  1    147712  models.common.Conv                      [128, 128, 3, 2, None, 1, True]
 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
 20                -1  1    296448  models.common.C3_prune                  [256, 256, 256, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
 21                -1  1    590336  models.common.Conv                      [256, 256, 3, 2, None, 1, True]
 22          [-1, 10]  1         0  models.common.Concat                    [1]                           
 23                -1  1   1182720  models.common.C3_prune                  [512, 512, 512, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
 24      [17, 20, 23]  1    229245  models.yolo_prune.Detect                [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
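Swapping an `nn.Upsample` with scale factor 2 for `nn.ConvTranspose2d(c, c, 2, 2, 0)` keeps the shapes compatible. With `kernel_size=2`, `stride=2` and `padding=0`, the transposed convolution doubles height and width just as 2x nearest-neighbor upsampling does, but it adds learnable weights. The sketch below applies the output-size formula from the PyTorch `ConvTranspose2d` documentation and recomputes the parameter counts of layers 11 and 15 in the table above. The 20x20 input size is only an example.

```python
def convtranspose2d_out(size: int, kernel: int, stride: int, padding: int = 0,
                        output_padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def convtranspose2d_params(c_in: int, c_out: int, kernel: int,
                           groups: int = 1, bias: bool = True) -> int:
    """Weight has shape [c_in, c_out // groups, k, k]; the bias adds c_out more."""
    return c_in * (c_out // groups) * kernel * kernel + (c_out if bias else 0)


# Layers 11 and 15 of the training model: ConvTranspose2d(c, c, 2, 2, 0).
for layer, c in ((11, 256), (15, 128)):
    params = convtranspose2d_params(c, c, 2)
    size = convtranspose2d_out(20, kernel=2, stride=2, padding=0)
    print(f"layer {layer}: {params} params, 20x20 -> {size}x{size}")
```

The same formula matches the searched subnet: there, layers 11 and 15 keep their output widths (256 and 128) and only their input widths shrink (184 and 88), giving 188672 and 45184 parameters.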

Below is the optimal subnet structure found by the search:

                 from  n    params  module                                  arguments                     
  0                -1  1      2640  models.common.Conv                      [3, 24, 6, 2, 2, 1, True]     
  1                -1  1      6976  models.common.Conv                      [24, 32, 3, 2, None, 1, True] 
  2                -1  1      9440  models.common.C3_prune                  [32, 40, 64, 1, True, 1, [0.5, 0.375], [0.5, 1.0, 1.0]]
  3                -1  1     20272  models.common.Conv                      [40, 56, 3, 2, None, 1, True] 
  4                -1  2    212992  models.common.C3_prune                  [56, 128, 128, 2, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
  5                -1  1    230800  models.common.Conv                      [128, 200, 3, 2, None, 1, True]
  6                -1  3   1832448  models.common.C3_prune                  [200, 256, 256, 3, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
  7                -1  1    571888  models.common.Conv                      [256, 248, 3, 2, None, 1, True]
  8                -1  1    488832  models.common.C3_prune                  [248, 392, 512, 1, True, 1, [0.5, 0.484375], [0.25, 1.0, 1.0]]
  9                -1  1    164638  models.common.SPPF_prune                [392, 512, 5, 0.171875]       
 10                -1  1     94576  models.common.Conv                      [512, 184, 1, 1, None, 1, True]
 11                -1  1    188672  torch.nn.modules.conv.ConvTranspose2d   [184, 256, 2, 2, 0]           
 12           [-1, 6]  1         0  models.common.Concat                    [1]                           
 13                -1  1    298048  models.common.C3_prune                  [512, 104, 256, 1, False, 1, [0.5, 0.34375], [1.0, 1.0, 1.0]]
 14                -1  1      9328  models.common.Conv                      [104, 88, 1, 1, None, 1, True]
 15                -1  1     45184  torch.nn.modules.conv.ConvTranspose2d   [88, 128, 2, 2, 0]            
 16           [-1, 4]  1         0  models.common.Concat                    [1]                           
 17                -1  1     72240  models.common.C3_prune                  [256, 88, 128, 1, False, 1, [0.5, 0.3125], [0.875, 1.0, 1.0]]
 18                -1  1     50816  models.common.Conv                      [88, 64, 3, 2, None, 1, True] 
 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
 20                -1  1    111984  models.common.C3_prune                  [152, 96, 256, 1, False, 1, [0.5, 0.28125], [0.375, 1.0, 1.0]]
 21                -1  1    214768  models.common.Conv                      [96, 248, 3, 2, None, 1, True]
 22          [-1, 10]  1         0  models.common.Concat                    [1]                           
 23                -1  1    610000  models.common.C3_prune                  [432, 304, 512, 1, False, 1, [0.5, 0.40625], [0.40625, 1.0, 1.0]]
 24      [17, 20, 23]  1    125205  models.yolo_prune.Detect                [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [88, 96, 304]]
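The logged parameter counts offer a further clue. Assume `C3_prune` keeps the layout of YOLOv5's `C3`: 1x1 `cv1` and `cv2` (c1 to hidden width c_), a 1x1 `cv3` (2·c_ to c2), and `n` bottlenecks, each a 1x1 and a 3x3 `Conv` with BatchNorm. Under that assumption, layers 4 and 6 in both tables carry exactly n times the parameters of a single block. This is what happens if the repeat count is passed into the block's arguments *and* still used to wrap the block in an `nn.Sequential` of n copies. The traceback is consistent with this reading: between `_forward_once` (yolo_prune.py line 250) and the `C3` forward (common.py line 805) it passes through `nn.Sequential.forward` in container.py. This is an inference from the logs, not something confirmed in the `yolo_prune.py` source. In the training model the extra copies run silently because `c1 == c2`. In the pruned model, layer 4 maps 56 to 128 channels, so the second copy receives 128 channels while its `cv1` weight is `[64, 56, 1, 1]`.

```python
def conv_bn_params(c1: int, c2: int, k: int = 1) -> int:
    """YOLOv5 Conv: Conv2d(bias=False) plus BatchNorm2d (trainable weight and bias)."""
    return c1 * c2 * k * k + 2 * c2


def c3_params(c1: int, c_: int, c2: int, n: int) -> int:
    """One C3 block with hidden width c_ and n Bottleneck(c_, c_, e=1.0).

    Assumption: C3_prune has the same layout as C3, and every bottleneck ratio
    is 1.0 (true for layers 2, 4, 6 and 8 in the training table).
    """
    cv1 = conv_bn_params(c1, c_)
    cv2 = conv_bn_params(c1, c_)
    cv3 = conv_bn_params(2 * c_, c2)
    bottlenecks = n * (conv_bn_params(c_, c_, 1) + conv_bn_params(c_, c_, 3))
    return cv1 + cv2 + cv3 + bottlenecks


def sequential_inputs(c1: int, c2: int, n: int) -> list:
    """Input channels seen by each of n identical (c1 -> c2) blocks chained in sequence."""
    return [c1] + [c2] * (n - 1)


# (name, c1, hidden c_, c2, n, params reported in the log)
LAYERS = [
    ("train 4", 128, 64, 128, 2, 231424),
    ("train 6", 256, 128, 256, 3, 1875456),
    ("pruned 4", 56, 64, 128, 2, 212992),
    ("pruned 6", 200, 128, 256, 3, 1832448),
]

for name, c1, c_, c2, n, logged in LAYERS:
    single = c3_params(c1, c_, c2, n)
    print(f"{name}: one block = {single}, logged = {logged}, ratio = {logged / single:.1f}")

# Pruned layer 4 built as 2 chained copies: both were built for 56 input channels.
print(sequential_inputs(56, 128, 2))
```

If this reading is correct, the fix would go in `parse_model` of `models/yolo_prune.py`: once the repeat count is placed into the `C3_prune` arguments, `n` should be reset to 1, as upstream YOLOv5 does for `C3` (`args.insert(2, n)` followed by `n = 1`). The log alone does not show whether the duplication came from the `ConvTranspose2d` edit or was already present. A checkpoint trained with the duplicated blocks would probably need to be retrained after the fix.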