Open GeneralJing opened 3 years ago
有解决的消息了吗?是跟分组卷积有关吗?
呃这两天有别的事情,还没看多久,晚上我再自己测测
嗯 好的 辛苦大佬,有消息及时回复下,自己手动剪,感觉比较步骤比较繁琐
能贴你的 prune_by_class_bisenetv2.py 吗?我创建 lib/models/bisenetv2.py#BiSeNetV2(19)
后 graph.build_graph
没跑过去。bisnet 是刚从 https://github.com/CoinCheung/BiSeNet 找的。
import sys
sys.path.append("..")
import torch
import torchpruner
import torchvision
import numpy as np
from bisenetv2 import BiSeNetV2
#以下代码示例了对每一个BN层去除其weight系数绝对值前20%小的层
#加载模型
model = torchvision.models.vgg11_bn()
print('model-origin:', model)
#jzy
#model = BiSeNetV2(n_classes=9, aux_mode='pred')
model = BiSeNetV2(n_classes=9)
model.load_state_dict(torch.load('/home/zxz/torch-model-compression/examples/torchpruner/model_final.pth', map_location='cuda'), strict=False)
print('model-origin:', model)
# 创建ONNXGraph对象,绑定需要被剪枝的模型
graph = torchpruner.ONNXGraph(model)
##build ONNX静态图结构,需要指定输入的张量
graph.build_graph(inputs=(torch.zeros(8, 3, 320, 640),))
# 遍历所有的Module
for key in graph.modules:
module = graph.modules[key]
# 如果该module对应了BN层
if isinstance(module.nn_object, torch.nn.BatchNorm2d):
# 获取该对象
nn_object = module.nn_object
# 排序,取前20%小的权重值对应的index
weight = nn_object.weight.detach().cpu().numpy()
index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
print('index:', index)
result = module.cut_analysis("weight", index=index, dim=0)
model, context = torchpruner.set_cut(model, result)
# 新的model即为剪枝后的模型
print('model-pruned:', model)
这个代码。
comment框是markdown格式的,我的注释在这很奇怪
markdown的多行代码块语法是 三个 `
连着表示开头和结尾,比如
# // 开头
# ``` [ + 空格 + 语言名]
# 具体内容
# // 结尾
# ```
嗯嗯 有空看看mk语法,大佬看问题 嘿嘿
你用的 bisenetv2.py 是哪个文件?BiSeNet-master/lib/models/bisenetv2.py 还是 BiSeNet-master/old/bisenetv2/bisenetv2.py ?
另外我用官方仓库的模型(old/README.md的百度网盘文件 model_final.pth)好像加载不了,大概是参数不一致……厚颜来要你的原始模型了,能发的话,网盘或者发到 gdh1995@qq.com 都行;不方便的话,我就再看看。我现在因为懒得自己在coco上训 练,缺少实际可用的参数张量,卡在前边某一个分组卷积的 cut_analysis
步骤了Orz
用的是BiSeNet-master/lib/models/bisenetv2.py这个,模型的话在公司,可以给你发一个训练几个epoch的版本,因为是其他人训练的,也不方便给最新的。得明天给你了。
嗯谢谢!
Traceback (most recent call last):
File "G:/py/torch-model-compression-main/examples/torchpruner/prune_and_recovery.py", line 16, in
好几个都是 加载模型 时出现上述错误。我是刚入门的小白,请问这个怎么处理,程序跑不通,不知道怎么改这个。 model = torchvision.models.resnet50()
@jzy-hxf 抱歉之前没注意消息。具体到你这个问题,是torch版本比较新(1.11还是多少以上)造成的。你把这个项目里出现的 _retain_param_name
都去掉就行了,不影响结果。
@GeneralJing 抱歉我好久没注意这个。我确认了几遍代码,应该是 examples/torchpruner/prune_by_class.py
写的有问题,每次执行 torchpruner.set_cut
后 graph
的部分信息会过时,所以需要重新创建 graph
。graph.modules
是稳定的,可以预先算好 keys
for key in list(graph.modules):
# ...
model, context = torchpruner.set_cut(model, result)
graph = torchpruner.ONNXGraph(model) # 本行可以省略
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))
我一会去改一下示例代码。
“下个周末看一下”,重新定义了下个周末哈哈哈哈,好的,多谢。
参考你的例子,对BiSeNetv2模型剪枝,报如下错误,这个怎么解决,能看一下吗
index: [34 53 31 46 12 49 36 60 18 68 39 57 67 13 14]
key: self.segment.S3.0.dwconv1.1
Traceback (most recent call last):
File "model_pruning.py", line 33, in <module>
result = module.cut_analysis("weight", index=index, dim=0)
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 333, in cut_analysis
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 246, in cut_analysis
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 269, in cut_analysis_with_mask
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/operator/onnx_operator.py", line 489, in analysis
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/mask_utils.py", line 276, in indexs
RuntimeError: All the data is masked
我的代码如下
import imp
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, '..'))
import torch
import torchpruner
import numpy as np
from models.bisenetv2 import BiSeNetV2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载模型
model = BiSeNetV2(n_classes=4)
checkpoint = torch.load('../results/05-05_01-59/checkpoint_best.pkl', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
graph = torchpruner.ONNXGraph(model)
graph.build_graph(inputs=(torch.zeros(1, 3, 640, 512),))
for key in list(graph.modules):
module = graph.modules[key]
if isinstance(module.nn_object, torch.nn.BatchNorm2d):
nn_object = module.nn_object
weight = nn_object.weight.detach().cpu().numpy()
index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
print('index:', index)
print('key:', key)
result = module.cut_analysis("weight", index=index, dim=0)
model, context = torchpruner.set_cut(model, result)
graph = torchpruner.ONNXGraph(model)
graph.build_graph(inputs=(torch.zeros(1, 3, 640, 512),))
print('model-pruned:', model)
torch.save(model, '../results/05-05_01-59/checkpoint_best_pruning.pkl')