midasklr / yolov5prune

553 stars 113 forks source link

用yolov5m或者yolov5l完成稀疏训练后剪枝报错 #62

Open LiyuyangSWJTU opened 2 years ago

LiyuyangSWJTU commented 2 years ago

大佬您好,谢谢你的开源,请问我用yolov5m完成稀疏训练过后,在用prune.py剪枝的过程中出现了以下报错,希望大佬解答一下在m,l模型剪枝的时候需要对prune.py做哪些修改,还是我这里是其他错误的原因,十分感谢感谢感谢 from n params module arguments 0 -1 1 5280 models.common.Focus [3, 48, 3] 1 -1 1 41664 models.common.Conv [48, 96, 3, 2] 2 -1 1 42048 models.pruned_common.C3Pruned [96, 48, 48, 96, [[48, 48, 48]], 1, 128] 3 -1 1 165406 models.common.Conv [96, 191, 3, 2] 4 -1 1 351936 models.pruned_common.C3Pruned [191, 96, 96, 192, [[96, 96, 96], [96, 96, 96], [96, 96, 96]], 3, 256] 5 -1 1 557060 models.common.Conv [192, 322, 3, 2] 6 -1 1 1251772 models.pruned_common.C3Pruned [322, 192, 42, 287, [[192, 192, 192], [192, 192, 192], [192, 192, 192]], 3, 512] 7 -1 1 1804330 models.common.Conv [287, 698, 3, 2] 8 -1 1 293770 models.pruned_common.SPPPruned [698, 193, 205, [5, 9, 13]] 9 -1 1 11340 models.pruned_common.C3Pruned [205, 7, 31, 48, [[7, 16, 9]], 1, False] 10 -1 1 3350 models.common.Conv [48, 67, 1, 1] 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 12 [-1, 6] 1 0 models.common.Concat [1] 13 -1 1 362558 models.pruned_common.C3Pruned [354, 116, 133, 244, [[116, 137, 152]], 1, False] 14 -1 1 44772 models.common.Conv [244, 182, 1, 1] 15 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 16 [-1, 4] 1 0 models.common.Concat [1] 17 -1 1 199299 models.pruned_common.C3Pruned [374, 96, 95, 191, [[96, 94, 96]], 1, False] 18 -1 1 213404 models.common.Conv [191, 124, 3, 2] 19 [-1, 14] 1 0 models.common.Concat [1] 20 -1 1 426578 models.pruned_common.C3Pruned [306, 115, 150, 377, [[115, 148, 158]], 1, False] 21 -1 1 682395 models.common.Conv [377, 201, 3, 2] 22 [-1, 10] 1 0 models.common.Concat [1] 23 -1 1 230556 models.pruned_common.C3Pruned [268, 77, 189, 482, [[77, 45, 71]], 1, False] detect input : ['model.0.conv.bn', 'model.1.bn', 'model.2.cv3.bn', 'model.3.bn', 'model.4.cv3.bn', 'model.5.bn', 'model.6.cv3.bn', 'model.7.bn', 'model.8.cv2.bn', 'model.9.cv3.bn', 'model.10.bn', 'model.10.bn', ['model.10.bn', 'model.6.cv3.bn'], 'model.13.cv3.bn', 'model.14.bn', 'model.14.bn', ['model.14.bn', 'model.4.cv3.bn'], 'model.17.cv3.bn', 'model.18.bn', ['model.18.bn', 'model.14.bn'], 'model.20.cv3.bn', 'model.21.bn', ['model.21.bn', 'model.10.bn'], 'model.23.cv3.bn'] 24 [17, 20, 23] 24 [17, 20, 23] 1 53703 models.yolo.Detect [12, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [191, 377, 482]] Model Summary: 283 layers, 6741221 parameters, 6741221 gradients, 26.3 GFLOPS

Traceback (most recent call last): File "prune.py", line 807, in opt=opt File "prune.py", line 462, in test_prune assert pruned_model_state.keys() == modelstate.keys() AssertionError

midasklr commented 2 years ago

剪枝前后模型的state_dict()的键值不一样,可以把assert改为if看看哪里key没对应上。

lc790 commented 2 years ago

我也遇到了同样的问题,不知道怎么解决

baidingyuan commented 2 years ago

大佬您好,谢谢你的开源,请问我用yolov5m完成稀疏训练过后,在用prune.py剪枝的过程中出现了以下报错,希望大佬解答一下在m,l模型剪枝的时候需要对prune.py做哪些修改,还是我这里是其他错误的原因,十分感谢感谢感谢 from n params module arguments 0 -1 1 5280 models.common.Focus [3, 48, 3] 1 -1 1 41664 models.common.Conv [48, 96, 3, 2] 2 -1 1 42048 models.pruned_common.C3Pruned [96, 48, 48, 96, [[48, 48, 48]], 1, 128] 3 -1 1 165406 models.common.Conv [96, 191, 3, 2] 4 -1 1 351936 models.pruned_common.C3Pruned [191, 96, 96, 192, [[96, 96, 96], [96, 96, 96], [96, 96, 96]], 3, 256] 5 -1 1 557060 models.common.Conv [192, 322, 3, 2] 6 -1 1 1251772 models.pruned_common.C3Pruned [322, 192, 42, 287, [[192, 192, 192], [192, 192, 192], [192, 192, 192]], 3, 512] 7 -1 1 1804330 models.common.Conv [287, 698, 3, 2] 8 -1 1 293770 models.pruned_common.SPPPruned [698, 193, 205, [5, 9, 13]] 9 -1 1 11340 models.pruned_common.C3Pruned [205, 7, 31, 48, [[7, 16, 9]], 1, False] 10 -1 1 3350 models.common.Conv [48, 67, 1, 1] 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 12 [-1, 6] 1 0 models.common.Concat [1] 13 -1 1 362558 models.pruned_common.C3Pruned [354, 116, 133, 244, [[116, 137, 152]], 1, False] 14 -1 1 44772 models.common.Conv [244, 182, 1, 1] 15 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 16 [-1, 4] 1 0 models.common.Concat [1] 17 -1 1 199299 models.pruned_common.C3Pruned [374, 96, 95, 191, [[96, 94, 96]], 1, False] 18 -1 1 213404 models.common.Conv [191, 124, 3, 2] 19 [-1, 14] 1 0 models.common.Concat [1] 20 -1 1 426578 models.pruned_common.C3Pruned [306, 115, 150, 377, [[115, 148, 158]], 1, False] 21 -1 1 682395 models.common.Conv [377, 201, 3, 2] 22 [-1, 10] 1 0 models.common.Concat [1] 23 -1 1 230556 models.pruned_common.C3Pruned [268, 77, 189, 482, [[77, 45, 71]], 1, False] detect input : ['model.0.conv.bn', 'model.1.bn', 'model.2.cv3.bn', 'model.3.bn', 'model.4.cv3.bn', 'model.5.bn', 'model.6.cv3.bn', 'model.7.bn', 'model.8.cv2.bn', 'model.9.cv3.bn', 'model.10.bn', 'model.10.bn', ['model.10.bn', 'model.6.cv3.bn'], 'model.13.cv3.bn', 'model.14.bn', 'model.14.bn', ['model.14.bn', 'model.4.cv3.bn'], 'model.17.cv3.bn', 'model.18.bn', ['model.18.bn', 'model.14.bn'], 'model.20.cv3.bn', 'model.21.bn', ['model.21.bn', 'model.10.bn'], 'model.23.cv3.bn'] 24 [17, 20, 23] 24 [17, 20, 23] 1 53703 models.yolo.Detect [12, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [191, 377, 482]] Model Summary: 283 layers, 6741221 parameters, 6741221 gradients, 26.3 GFLOPS

Traceback (most recent call last): File "prune.py", line 807, in opt=opt File "prune.py", line 462, in test_prune assert pruned_model_state.keys() == modelstate.keys() AssertionError

请问你这个问题解决了吗?我使用X版本也遇到了这个问题

yezechen commented 2 years ago

大佬您好,谢谢你的开源,请问我用yolov5m完成稀疏训练过后,在用prune.py剪枝的过程中出现了以下报错,希望大佬解答一下在m,l模型剪枝的时候需要对prune.py做哪些修改,还是我这里是其他错误的原因,十分感谢感谢感谢 from n params module arguments 0 -1 1 5280 models.common.Focus [3, 48, 3] 1 -1 1 41664 models.common.Conv [48, 96, 3, 2] 2 -1 1 42048 models.pruned_common.C3Pruned [96, 48, 48, 96, [[48, 48, 48]], 1, 128] 3 -1 1 165406 models.common.Conv [96, 191, 3, 2] 4 -1 1 351936 models.pruned_common.C3Pruned [191, 96, 96, 192, [[96, 96, 96], [96, 96, 96], [96, 96, 96]], 3, 256] 5 -1 1 557060 models.common.Conv [192, 322, 3, 2] 6 -1 1 1251772 models.pruned_common.C3Pruned [322, 192, 42, 287, [[192, 192, 192], [192, 192, 192], [192, 192, 192]], 3, 512] 7 -1 1 1804330 models.common.Conv [287, 698, 3, 2] 8 -1 1 293770 models.pruned_common.SPPPruned [698, 193, 205, [5, 9, 13]] 9 -1 1 11340 models.pruned_common.C3Pruned [205, 7, 31, 48, [[7, 16, 9]], 1, False] 10 -1 1 3350 models.common.Conv [48, 67, 1, 1] 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 12 [-1, 6] 1 0 models.common.Concat [1] 13 -1 1 362558 models.pruned_common.C3Pruned [354, 116, 133, 244, [[116, 137, 152]], 1, False] 14 -1 1 44772 models.common.Conv [244, 182, 1, 1] 15 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 16 [-1, 4] 1 0 models.common.Concat [1] 17 -1 1 199299 models.pruned_common.C3Pruned [374, 96, 95, 191, [[96, 94, 96]], 1, False] 18 -1 1 213404 models.common.Conv [191, 124, 3, 2] 19 [-1, 14] 1 0 models.common.Concat [1] 20 -1 1 426578 models.pruned_common.C3Pruned [306, 115, 150, 377, [[115, 148, 158]], 1, False] 21 -1 1 682395 models.common.Conv [377, 201, 3, 2] 22 [-1, 10] 1 0 models.common.Concat [1] 23 -1 1 230556 models.pruned_common.C3Pruned [268, 77, 189, 482, [[77, 45, 71]], 1, False] detect input : ['model.0.conv.bn', 'model.1.bn', 'model.2.cv3.bn', 'model.3.bn', 'model.4.cv3.bn', 'model.5.bn', 'model.6.cv3.bn', 'model.7.bn', 'model.8.cv2.bn', 'model.9.cv3.bn', 'model.10.bn', 'model.10.bn', ['model.10.bn', 'model.6.cv3.bn'], 'model.13.cv3.bn', 'model.14.bn', 'model.14.bn', ['model.14.bn', 'model.4.cv3.bn'], 'model.17.cv3.bn', 'model.18.bn', ['model.18.bn', 'model.14.bn'], 'model.20.cv3.bn', 'model.21.bn', ['model.21.bn', 'model.10.bn'], 'model.23.cv3.bn'] 24 [17, 20, 23] 24 [17, 20, 23] 1 53703 models.yolo.Detect [12, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [191, 377, 482]] Model Summary: 283 layers, 6741221 parameters, 6741221 gradients, 26.3 GFLOPS

Traceback (most recent call last): File "prune.py", line 807, in opt=opt File "prune.py", line 462, in test_prune assert pruned_model_state.keys() == modelstate.keys() AssertionError

你看一下prune.py文件里,390行那里有一个配置文件信息,默认的是yolov5s,你手动把它改成yolov5m或者l就可以了

gravityzmk commented 2 years ago

老哥解决了吗,我也是一模一样的问题,Traceback (most recent call last): File "prune.py", line 807, in opt=opt File "prune.py", line 462, in test_prune assert pruned_model_state.keys() == modelstate.keys() AssertionError

LiyuyangSWJTU commented 2 years ago

大佬您好,谢谢你的开源,请问我用yolov5m完成稀疏训练过后,在用prune.py剪枝的过程中出现了以下报错,希望大佬解答一下在m,l模型剪枝的时候需要对prune.py做哪些修改,还是我这里是其他错误的原因,十分感谢感谢感谢 from n params module arguments 0 -1 1 5280 models.common.Focus [3, 48, 3] 1 -1 1 41664 models.common.Conv [48, 96, 3, 2] 2 -1 1 42048 models.pruned_common.C3Pruned [96, 48, 48, 96, [[48, 48, 48]], 1, 128] 3 -1 1 165406 models.common.Conv [96, 191, 3, 2] 4 -1 1 351936 models.pruned_common.C3Pruned [191, 96, 96, 192, [[96, 96, 96], [96, 96, 96], [96, 96, 96]], 3, 256] 5 -1 1 557060 models.common.Conv [192, 322, 3, 2] 6 -1 1 1251772 models.pruned_common.C3Pruned [322, 192, 42, 287, [[192, 192, 192], [192, 192, 192], [192, 192, 192]], 3, 512] 7 -1 1 1804330 models.common.Conv [287, 698, 3, 2] 8 -1 1 293770 models.pruned_common.SPPPruned [698, 193, 205, [5, 9, 13]] 9 -1 1 11340 models.pruned_common.C3Pruned [205, 7, 31, 48, [[7, 16, 9]], 1, False] 10 -1 1 3350 models.common.Conv [48, 67, 1, 1] 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 12 [-1, 6] 1 0 models.common.Concat [1] 13 -1 1 362558 models.pruned_common.C3Pruned [354, 116, 133, 244, [[116, 137, 152]], 1, False] 14 -1 1 44772 models.common.Conv [244, 182, 1, 1] 15 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 16 [-1, 4] 1 0 models.common.Concat [1] 17 -1 1 199299 models.pruned_common.C3Pruned [374, 96, 95, 191, [[96, 94, 96]], 1, False] 18 -1 1 213404 models.common.Conv [191, 124, 3, 2] 19 [-1, 14] 1 0 models.common.Concat [1] 20 -1 1 426578 models.pruned_common.C3Pruned [306, 115, 150, 377, [[115, 148, 158]], 1, False] 21 -1 1 682395 models.common.Conv [377, 201, 3, 2] 22 [-1, 10] 1 0 models.common.Concat [1] 23 -1 1 230556 models.pruned_common.C3Pruned [268, 77, 189, 482, [[77, 45, 71]], 1, False] detect input : ['model.0.conv.bn', 'model.1.bn', 'model.2.cv3.bn', 'model.3.bn', 'model.4.cv3.bn', 'model.5.bn', 'model.6.cv3.bn', 'model.7.bn', 'model.8.cv2.bn', 'model.9.cv3.bn', 'model.10.bn', 'model.10.bn', ['model.10.bn', 'model.6.cv3.bn'], 'model.13.cv3.bn', 'model.14.bn', 'model.14.bn', ['model.14.bn', 'model.4.cv3.bn'], 'model.17.cv3.bn', 'model.18.bn', ['model.18.bn', 'model.14.bn'], 'model.20.cv3.bn', 'model.21.bn', ['model.21.bn', 'model.10.bn'], 'model.23.cv3.bn'] 24 [17, 20, 23] 24 [17, 20, 23] 1 53703 models.yolo.Detect [12, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [191, 377, 482]] Model Summary: 283 layers, 6741221 parameters, 6741221 gradients, 26.3 GFLOPS Traceback (most recent call last): File "prune.py", line 807, in opt=opt File "prune.py", line 462, in test_prune assert pruned_model_state.keys() == modelstate.keys() AssertionError

你看一下prune.py文件里,390行那里有一个配置文件信息,默认的是yolov5s,你手动把它改成yolov5m或者l就可以了

感谢大佬感谢大佬,问题解决了

LiyuyangSWJTU commented 2 years ago

老哥解决了吗,我也是一模一样的问题,Traceback (most recent call last): File "prune.py", line 807, in opt=opt File "prune.py", line 462, in test_prune assert pruned_model_state.keys() == modelstate.keys() AssertionError

yezechen大佬的答案把问题解决了

midasklr commented 2 years ago

最新代码已经支持s/m/l/x模型, 或者手动更改: https://github.com/midasklr/yolov5prune/blob/33f120cddbf0fc040c6f13b318728931aa3768dc/prune.py#L428-L429 这里网络的宽度和深度为v5m或者l模型