Open Wq-dd opened 1 year ago
你好,我在使用resnet18为主干网的retinanet时,自己使用稀疏训练后的模型剪枝会报错,我的做法是:
下面是我的部分代码。
import torchpruner # 创建ONNXGraph对象,绑定需要被剪枝的模型 self.model.eval() graph = torchpruner.ONNXGraph(self.model.cpu()) ##build ONNX静态图结构,需要指定输入的张量 graph.build_graph(inputs=(torch.zeros(1, 3, 640, 640),)) for i, (k, v) in enumerate(mask_dict_for_pruner.items()): # 获取conv1模块对应的module conv1_module = graph.modules[k] # 对前四个通道进行剪枝分析,指定对weight权重进行剪枝,剪枝前四个通道 # weight权重out_channels对应的通道维度为0 result = conv1_module.cut_analysis(attribute_name="weight", index=v, dim=0) # 剪枝执行模块执行剪枝操作,对模型完成剪枝过程.context变量提供了用于剪枝恢复的上下文 self.model, context = torchpruner.set_cut(self.model, result) # 新的model即为剪枝后的模型 print(self.model)```
请问是我的用法不对吗?还是说这种先计算剪枝的索引再调用torchpruner的方法不对呢?
每次剪枝后,model 对象变了,就都要重建 graph、重新执行 build_graph
你好,我在使用resnet18为主干网的retinanet时,自己使用稀疏训练后的模型剪枝会报错,我的做法是:
下面是我的部分代码。
请问是我的用法不对吗?还是说这种先计算剪枝的索引再调用torchpruner的方法不对呢?