taohan10200 / IIM

PyTorch implementations of the paper: "Learning Independent Instance Maps for Crowd Localization"
MIT License
163 stars 39 forks source link

怎样转化为onnx模型 #26

Open huaxiangwangman opened 2 years ago

taohan10200 commented 2 years ago

我们提供的是PyTorch保存的模型，ONNX格式的模型可通过我们开源的模型参数自行转换。

csz-006 commented 1 year ago

请问你转换成功了吗,可以看下转onnx的代码嘛

csz-006 commented 10 months ago

# Conversion script for exporting the IIM crowd locator to ONNX.
# Note: the network input must be a 1*3*H*W tensor (N=1, C=3, height, width).
import os
# NOTE(review): unused import, apparently an accidental IDE auto-import;
# importing tkinter can fail on minimal/headless Python builds — consider
# deleting it.
from tkinter.messagebox import NO
import torch
import torch.onnx
import torch.nn as nn
import onnxruntime as ort
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import onnx
from onnxsim import simplify
from model.locator import Crowd_locator
from collections import OrderedDict

# Expose exactly one GPU to this process; set before any CUDA work happens.
# (The original also assigned '1' just before this, but that value was
# immediately overwritten here, so it never took effect and was removed.)
GPU_ID = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID
torch.backends.cudnn.benchmark = True

def onnx_export(model_path, net_name=None,
                onnx_name='HRnet_Crowd_count_512_1024_opset12.onnx',
                input_shape=(1, 3, 512, 1024), opset_version=12):
    """Load a Crowd_locator checkpoint and export the network to ONNX.

    Args:
        model_path: Path of the PyTorch checkpoint (a state dict) to load.
        net_name: Backbone name passed to ``Crowd_locator`` (e.g. 'HR_Net' or
            'VGG16_FPN'). Defaults to the module-level ``netName`` assigned
            in the ``__main__`` block, which is what the original relied on.
        onnx_name: File the exported graph is written to (relative paths are
            resolved against the current working directory).
        input_shape: Shape of the all-zeros dummy input used for tracing.
        opset_version: ONNX opset to target.

    Returns:
        ``onnx_name``, the path of the written ONNX file.
    """
    if net_name is None:
        # Backward compatibility: fall back to the global set under __main__.
        net_name = netName
    net = Crowd_locator(net_name, GPU_ID, pretrained=False)
    net.cuda()
    # Load onto the CPU so a checkpoint saved on another GPU index still loads
    # when CUDA_VISIBLE_DEVICES exposes a single device; load_state_dict then
    # copies the weights into the model's own (CUDA) parameters.
    state_dict = torch.load(model_path, map_location='cpu')
    if len(GPU_ID.split(',')) > 1:
        # More than one GPU id configured: use the checkpoint keys unchanged.
        net.load_state_dict(state_dict)
    else:
        # Single GPU: remove every 'module.' segment from the keys (presumably
        # added by DataParallel wrappers during training — confirm).
        # str.replace strips it anywhere in the key, not only as a prefix;
        # that original behavior is intentionally kept.
        net.load_state_dict(OrderedDict(
            (key.replace('module.', ''), value)
            for key, value in state_dict.items()
        ))
    net.eval()

    # Print the name of every layer in the model.
    for name, _ in net.named_modules():
        print(name, 'nnnnn')

    # The graph is traced with this fixed-size input; with dynamic_axes=None
    # the exported model is static and expects exactly this shape.
    dummy_input = torch.zeros(input_shape).cuda()
    # NOTE(review): export calls net(dummy_input); if Crowd_locator.forward
    # requires further positional arguments this fails — confirm its
    # signature. If forward returns several tensors, only the first is named
    # 'predict'; the rest get auto-generated names.
    torch.onnx.export(
        net, dummy_input, onnx_name,
        verbose=True,
        input_names=['image'],
        output_names=['predict'],
        opset_version=opset_version,
        dynamic_axes=None,
    )
    return onnx_name

def onnx_sim(onnx_path, save_path='HRnet_Crowd_count_512_1024_opset12-sim.onnx'):
    """Simplify an ONNX model with onnx-simplifier and save the result.

    Args:
        onnx_path: Path of the ONNX model to simplify.
        save_path: Where to write the simplified model. Defaults to the file
            name the original script hard-coded.

    Returns:
        ``save_path``.

    Raises:
        RuntimeError: If onnxsim reports that the simplified model failed
            its validation check.
    """
    model_onnx = onnx.load_model(onnx_path)
    model_simplified, check = simplify(model_onnx)
    # Per the onnxsim docs, `check` reports whether the simplified graph
    # passed validation; the original ignored it and could save a bad model.
    if not check:
        raise RuntimeError(f'Simplified ONNX model failed validation: {onnx_path}')
    onnx.save(model_simplified, save_path)
    print('模型静态图简化完成')
    return save_path

if __name__ == '__main__':
    netName = 'HR_Net'  # Backbone: 'VGG16_FPN' or 'HR_Net'
    model_path = '/IIM/Preweights/NWPU-HR-ep_241_F1_0.802_Pre_0.841_Rec_0.766_mae_55.6_mse_330.9.pth'

    # onnx_export() writes 'HRnet_Crowd_count_512_1024_opset12.onnx' to the
    # current working directory; the optional simplification step must read
    # that same file (the previous '/IIM/Preweights/1024_...' path was never
    # written by this script).
    onnx_path = 'HRnet_Crowd_count_512_1024_opset12.onnx'
    onnx_export(model_path)
    # Optional: simplify the exported static graph (requires onnxsim).
    # onnx_sim(onnx_path)
    print('Done')