jwyang / faster-rcnn.pytorch

A faster pytorch implementation of faster r-cnn
MIT License
7.7k stars 2.33k forks source link

Try to convert pth to onnx. But error #898

Open caizhaoxin opened 2 years ago

caizhaoxin commented 2 years ago

I wanna convert pth to onnx format. This is my code:

import torch from model.faster_rcnn.vgg16 import vgg16 from model.faster_rcnn.resnet import resnet import numpy as np from torch.autograd import Variable

def load_model(model, pretrained_path): print('Loading pretrained model from {}'.format(pretrained_path)) pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
model.load_state_dict(pretrained_dict, strict=False) return model

output_onnx = './output.onnx' raw_weights = './faster_rcnn_1_10_2504.pth' pascal_classes = np.asarray(['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'])

load weight

net = resnet(pascal_classes, 101, pretrained=False, class_agnostic=False) net.create_architecture() checkpoint = torch.load(raw_weights) for k in checkpoint.keys(): print(k)
net.load_state_dict(checkpoint['model'])

initilize the tensor holder here.

im_data = torch.FloatTensor(1) im_info = torch.FloatTensor(1) num_boxes = torch.LongTensor(1) gt_boxes = torch.FloatTensor(1)

ship to cuda

im_data = im_data.cuda() im_info = im_info.cuda() num_boxes = num_boxes.cuda() gt_boxes = gt_boxes.cuda()

make variable

im_data = Variable(im_data, volatile=True) im_info = Variable(im_info, volatile=True) num_boxes = Variable(num_boxes, volatile=True) gt_boxes = Variable(gt_boxes, volatile=True)

net.eval() print('Finished loading model!') device = torch.device("cuda") net = net.to(device)

input_names = ["input0"] output_names = ["output0"] inputs = torch.randn(1, 3, 300, 300).to(device)

output model

torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)

but when I run it, I got this error. How can I fix it?Thanks! Traceback (most recent call last): File "pth2onnx.py", line 67, in torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names) File "/usr/local/lib/python3.6/dist-packages/torch/onnx/init.py", line 276, in export custom_opsets, enable_onnx_checker, use_external_data_format) File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 94, in export use_external_data_format=use_external_data_format) File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 701, in _export dynamic_axes=dynamic_axes) File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 459, in _model_to_graph use_new_jit_passes) File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 420, in _create_jit_graph graph, torch_out = _trace_and_get_graph_from_model(model, args) File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True) File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, kwargs) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, kwargs) File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 130, in forward self._force_outplace, File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 116, in wrapper outs.append(self.inner(trace_inputs)) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 887, in _call_impl result = self._slow_forward(input, kwargs) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 860, in _slow_forward result = self.forward(*input, kwargs) TypeError: forward() missing 3 required positional arguments: 'im_info', 'gt_boxes', and 'num_boxes'

sisinote commented 1 year ago

Did you finish the onnx conversion ?