Open huaxiangwangman opened 2 years ago
请问你转换成功了吗？可以看一下转 ONNX 的代码吗？
这是转换为 ONNX 的代码，需要注意输入形状为 1×3×H×W（即 N×C×H×W）。 import os from tkinter.messagebox import NO import torch import torch.onnx import torch.nn as nn import onnxruntime as ort import numpy as np import torch.nn.functional as F from tqdm import tqdm import onnx from onnxsim import simplify from model.locator import Crowd_locator from collections import OrderedDict
# Comma-separated list of GPU indices exposed to this process; onnx_export
# also inspects it to decide whether checkpoint keys need their prefix stripped.
GPU_ID = '0'

# Restrict CUDA to the selected device(s). This has to happen before CUDA is
# first initialised in this process for it to take effect.
os.environ.update({"CUDA_VISIBLE_DEVICES": GPU_ID})

# Let cuDNN autotune convolution algorithms for the fixed input shape.
torch.backends.cudnn.benchmark = True
def onnx_export(model_path, net_name=None,
                onnx_name='HRnet_Crowd_count_512_1024_opset12.onnx',
                input_shape=(1, 3, 512, 1024), opset_version=12,
                verbose=True):
    """Load a Crowd_locator checkpoint and export it to an ONNX file.

    Args:
        model_path: path to a state_dict saved with ``torch.save``.
        net_name: backbone name passed to ``Crowd_locator`` (the script uses
            ``'HR_Net'`` or ``'VGG16_FPN'``). When omitted, falls back to the
            module-level ``netName`` that the ``__main__`` block sets, so
            existing ``onnx_export(model_path)`` calls keep working.
        onnx_name: path of the ONNX file to write.
        input_shape: shape of the dummy tensor used for tracing, ordered
            ``(N, C, H, W)``. No dynamic axes are exported, so inference
            with the resulting model must use exactly this shape.
        opset_version: ONNX opset to target.
        verbose: forwarded to ``torch.onnx.export``.

    Returns:
        ``onnx_name``, the path the model was written to.

    Raises:
        ValueError: if ``net_name`` is omitted and no module-level
            ``netName`` has been set.
    """
    if net_name is None:
        # Backward compatibility: the original version read this global.
        net_name = globals().get('netName')
        if net_name is None:
            raise ValueError(
                "net_name must be given (e.g. 'HR_Net') when the "
                "module-level netName is not set")

    net = Crowd_locator(net_name, GPU_ID, pretrained=False)
    net.cuda()

    # Load onto the CPU first so the checkpoint can be read regardless of the
    # device it was saved from; load_state_dict copies into the CUDA params.
    state_dict = torch.load(model_path, map_location='cpu')
    if len(GPU_ID.split(',')) <= 1:
        # Single-GPU run: drop the 'module.' key prefix (presumably left by a
        # DataParallel wrapper at training time) so the keys match `net`.
        state_dict = OrderedDict(
            (key.replace('module.', ''), value)
            for key, value in state_dict.items()
        )
    net.load_state_dict(state_dict)
    net.eval()

    dummy_input = torch.zeros(tuple(input_shape)).cuda()
    torch.onnx.export(
        net, dummy_input, onnx_name,
        verbose=verbose,
        input_names=['image'],
        output_names=['predict'],
        opset_version=opset_version,
        dynamic_axes=None,  # static graph: shape fixed to input_shape
    )
    return onnx_name
def onnx_sim(onnx_path, save_path='HRnet_Crowd_count_512_1024_opset12-sim.onnx'):
    """Simplify an exported ONNX model with onnx-simplifier and save it.

    Args:
        onnx_path: path of the ONNX model to simplify.
        save_path: where to write the simplified model. Defaults to the file
            name the original script always wrote.

    Returns:
        ``save_path``.

    Raises:
        RuntimeError: if onnxsim reports that the simplified model did not
            pass its check.
    """
    model_onnx = onnx.load_model(onnx_path)
    model_simp, check = simplify(model_onnx)
    # simplify() returns a flag telling whether the simplified model passed
    # onnxsim's check; ignoring it could silently save a broken model.
    if not check:
        raise RuntimeError(
            f'Simplified ONNX model from {onnx_path!r} could not be validated')
    onnx.save(model_simp, save_path)
    print('模型静态图简化完成')  # "static-graph simplification finished"
    return save_path
if __name__ == '__main__':
    # Backbone name; onnx_export reads this module-level global when it is
    # called without an explicit network name.
    netName = 'HR_Net'  # VGG16_FPN HR_Net
    model_path = '/IIM/Preweights/NWPU-HR-ep_241_F1_0.802_Pre_0.841_Rec_0.766_mae_55.6_mse_330.9.pth'

    # NOTE(review): onnx_export writes 'HRnet_Crowd_count_512_1024_opset12.onnx'
    # into the working directory, not to this path. Make the two agree before
    # enabling the onnx_sim call below.
    onnx_path = '/IIM/Preweights/1024_HRnet_Crowd_count_512_1024_opset12.onnx'

    onnx_export(model_path)
    # onnx_sim(onnx_path)
    print('Done')
我们提供的是 PyTorch 保存的模型；ONNX 格式的模型可以基于我们开源的模型参数自行转换。