I want to convert a `.pth` checkpoint to ONNX format. This is my code:
import torch
from model.faster_rcnn.vgg16 import vgg16
from model.faster_rcnn.resnet import resnet
import numpy as np
from torch.autograd import Variable
def load_model(model, pretrained_path):
    """Load weights from a checkpoint file into ``model`` and return it.

    Args:
        model: Module whose parameters are to be filled in.
        pretrained_path: Path to a file written by ``torch.save``. It may
            hold either a bare state dict or a dict that keeps the state
            dict under the ``'model'`` key.

    Returns:
        The same ``model`` object, with the matching weights loaded.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    # Deserialize onto the CPU so the file loads even without the GPU it
    # was saved from; the caller moves the model to its device afterwards.
    pretrained_dict = torch.load(pretrained_path, map_location='cpu')
    # A wrapped checkpoint ({'model': state_dict, ...}) would match no
    # parameter name, and strict=False would hide that, so unwrap it first.
    # A bare state dict maps names to tensors, never to a nested dict.
    if isinstance(pretrained_dict, dict) and isinstance(pretrained_dict.get('model'), dict):
        pretrained_dict = pretrained_dict['model']
    # strict=False: keys missing from or unknown to the model are skipped.
    model.load_state_dict(pretrained_dict, strict=False)
    return model
# Build the network and restore the trained weights.
net = resnet(pascal_classes, 101, pretrained=False, class_agnostic=False)
net.create_architecture()
checkpoint = torch.load(raw_weights)
# Print the checkpoint's top-level keys to show what it contains.
for k in checkpoint:
    print(k)
# The network's state dict is read from the 'model' entry.
net.load_state_dict(checkpoint['model'])
But when I run it, I get this error. How can I fix it? Thanks!
Traceback (most recent call last):
File "pth2onnx.py", line 67, in <module>
torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 276, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 701, in _export
dynamic_axes=dynamic_axes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 459, in _model_to_graph
use_new_jit_passes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 130, in forward
self._force_outplace,
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 116, in wrapper
outs.append(self.inner(*trace_inputs))
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
TypeError: forward() missing 3 required positional arguments: 'im_info', 'gt_boxes', and 'num_boxes'
I want to convert a `.pth` checkpoint to ONNX format. This is my code:
import torch from model.faster_rcnn.vgg16 import vgg16 from model.faster_rcnn.resnet import resnet import numpy as np from torch.autograd import Variable
def load_model(model, pretrained_path):
    """Load weights from a checkpoint file into ``model`` and return it.

    Args:
        model: Module whose parameters are to be filled in.
        pretrained_path: Path to a file written by ``torch.save``. It may
            hold either a bare state dict or a dict that keeps the state
            dict under the ``'model'`` key.

    Returns:
        The same ``model`` object, with the matching weights loaded.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    # Deserialize onto the CPU so the file loads even without the GPU it
    # was saved from; the caller moves the model to its device afterwards.
    pretrained_dict = torch.load(pretrained_path, map_location='cpu')
    # A wrapped checkpoint ({'model': state_dict, ...}) would match no
    # parameter name, and strict=False would hide that, so unwrap it first.
    # A bare state dict maps names to tensors, never to a nested dict.
    if isinstance(pretrained_dict, dict) and isinstance(pretrained_dict.get('model'), dict):
        pretrained_dict = pretrained_dict['model']
    # strict=False: keys missing from or unknown to the model are skipped.
    model.load_state_dict(pretrained_dict, strict=False)
    return model
output_onnx = './output.onnx'
raw_weights = './faster_rcnn_1_10_2504.pth'
pascal_classes = np.asarray([
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
])

# Load weights.
net = resnet(pascal_classes, 101, pretrained=False, class_agnostic=False)
net.create_architecture()
checkpoint = torch.load(raw_weights)
# Print the checkpoint's top-level keys to show what it contains.
for k in checkpoint:
    print(k)
net.load_state_dict(checkpoint['model'])

net.eval()
print('Finished loading model!')
device = torch.device("cuda")
net = net.to(device)

# Build one dummy tensor per forward() argument, directly on the device.
# The traceback shows forward() also requires im_info, gt_boxes and
# num_boxes after the image tensor, so all four must be supplied.
im_data = torch.randn(1, 3, 300, 300, device=device)
# NOTE(review): assumes im_info is one (height, width, scale) row per image
# -- confirm against the model's forward().
im_info = torch.tensor([[300.0, 300.0, 1.0]], device=device)
# NOTE(review): assumes gt_boxes is (batch, max_boxes, 5) and num_boxes is
# (batch,), both unused placeholders at inference time -- confirm.
gt_boxes = torch.zeros(1, 1, 5, device=device)
num_boxes = torch.zeros(1, dtype=torch.long, device=device)

# torch.onnx.export unpacks a tuple of args into forward(); passing only the
# image tensor is what raised "missing 3 required positional arguments".
inputs = (im_data, im_info, gt_boxes, num_boxes)
input_names = ["input0", "im_info", "gt_boxes", "num_boxes"]
output_names = ["output0"]

# Output model. no_grad() replaces the removed Variable(volatile=True) flag.
with torch.no_grad():
    torch_out = torch.onnx.export(
        net,
        inputs,
        output_onnx,
        export_params=True,
        verbose=False,
        keep_initializers_as_inputs=True,
        input_names=input_names,
        output_names=output_names,
    )
But when I run it, I get this error. How can I fix it? Thanks!
Traceback (most recent call last):
File "pth2onnx.py", line 67, in <module>
torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 276, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 701, in _export
dynamic_axes=dynamic_axes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 459, in _model_to_graph
use_new_jit_passes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 130, in forward
self._force_outplace,
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 116, in wrapper
outs.append(self.inner(*trace_inputs))
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
TypeError: forward() missing 3 required positional arguments: 'im_info', 'gt_boxes', and 'num_boxes'