apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

MXNet model conversion to ONNX: No conversion function registered for op type LinearRegressionOutput yet. #15092

Open eeqmcc opened 5 years ago

eeqmcc commented 5 years ago

INFO:root:Converting json and weight file to sym and params
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    converted_model_path = onnx_mxnet.export_model(sym, params, [(8,3,416,416),(8,14196),(8,84,26,26),(8,84,26,26),(8,100,5),(8,2),(8,3)], np.float32, onnx_file)
  File "C:\Users\sunbo.conda\envs\mxnet-new\lib\site-packages\mxnet\contrib\onnx\mx2onnx\export_model.py", line 83, in export_model
    verbose=verbose)
  File "C:\Users\sunbo.conda\envs\mxnet-new\lib\site-packages\mxnet\contrib\onnx\mx2onnx\export_onnx.py", line 253, in create_onnx_graph_proto
    idx=idx
  File "C:\Users\sunbo.conda\envs\mxnet-new\lib\site-packages\mxnet\contrib\onnx\mx2onnx\export_onnx.py", line 90, in convert_layer
    raise AttributeError("No conversion function registered for op type %s yet." % op)
AttributeError: No conversion function registered for op type LinearRegressionOutput yet.

No matter which loss function I use, I get a similar error.

The loss function: rpn_bbox_loss = mx.sym.LinearRegressionOutput(name='rpn_bbox_loss', data=rpn_bbox_pred2, label=rpn_bbox_target2)

This is my conversion code:

import mxnet as mx
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet
import logging
logging.basicConfig(level=logging.INFO)

sym = './SNIPER-symbol.json'
params = './SNIPER-0002.params'
input_shape = (8, 3, 416, 416)
onnx_file = './out.onnx'
# the model has several inputs, so one shape is passed for each of them
converted_model_path = onnx_mxnet.export_model(sym, params, [(8,3,416,416), (8,14196), (8,84,26,26), (8,84,26,26), (8,100,5), (8,2), (8,3)], np.float32, onnx_file)

And these are my symbol and weights files:
SNIPER-symbol.json: https://drive.google.com/open?id=1oonndx3snHrJhpMbI8A1PhsHXp4EaJ-m
SNIPER-0002.params: https://drive.google.com/open?id=1fOyIxydTjlZHvXSTDkSzhuBAx3Ln1CB4
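
For what it is worth, the distinct op types used in the symbol file can be listed with a few lines of Python (a minimal sketch using only the standard json module), which makes it easy to spot which ops the exporter cannot handle:

import json

# list the distinct op types used in the symbol file, so unsupported ones
# (such as LinearRegressionOutput) are easy to spot
with open('./SNIPER-symbol.json') as f:
    graph = json.load(f)
op_types = sorted({node['op'] for node in graph['nodes'] if node['op'] != 'null'})
print(op_types)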

mxnet version: 1.4.0, python: 3.6.8, cuda: 9.0, cudnn: 7.5, system: Windows 10, 64-bit

frankfliu commented 5 years ago

@mxnet-label-bot add [question, onnx]

QueensGambit commented 4 years ago

Hello @eeqmcc , I found a solution to this problem. You get this error because your model was exported from the MXNet symbol API, which includes the loss functions inside the network. When you define a network with the Gluon API, the loss function is not part of the network definition, similar to other deep learning frameworks. The loss layers do not affect your outputs; they are only used to compute the gradients in the symbol API. Therefore, you can load your network into Gluon and export it again without the loss layers (a small sketch of the difference follows below).
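
To illustrate the difference, here is a minimal Gluon-style sketch (hypothetical layer sizes, not your SNIPER network) where the loss lives outside of the network and is only used in the training loop, so it never becomes part of the exported graph:

import mxnet as mx
from mxnet import gluon, autograd

# minimal sketch: the network itself contains no loss layer
net = gluon.nn.Dense(4)
net.initialize()
l2_loss = gluon.loss.L2Loss()   # Gluon counterpart of LinearRegressionOutput
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})

data = mx.nd.random.uniform(shape=(8, 16))
label = mx.nd.random.uniform(shape=(8, 4))

with autograd.record():
    out = net(data)
    loss = l2_loss(out, label)  # the loss is computed outside of `net`
loss.backward()
trainer.step(8)

# `net` can now be hybridized and exported without any loss layer:
# net.hybridize(); net(data); net.export("model")

Applied to your model, loading the symbol into Gluon, keeping only the real outputs and re-exporting looks like this: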

import mxnet as mx
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet
from onnx import checker
import onnx

sym = './SNIPER-symbol.json'
params = './SNIPER-0002.params'
input_shape = (8, 3, 416, 416)
onnx_file = './out.onnx'

symbol = mx.sym.load(sym)

# check your last network output layer names with
# mx.viz.print_summary(symbol, shape={'data': input_shape})
# or by reading the 'SNIPER-symbol.json' file directly;
# note that get_internals() appends an '_output' suffix to each layer name
prob_out_name = "cls_score_output"
boxes_out_name = "bbox_pred_output"

inputs = mx.sym.var('data', dtype='float32')
prob_out = symbol.get_internals()[prob_out_name]
boxes_out = symbol.get_internals()[boxes_out_name]
sym = mx.symbol.Group([prob_out, boxes_out])
net = mx.gluon.SymbolBlock(sym, inputs)
net.collect_params().load(params, ctx=mx.cpu())
net.export("model")
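
If the exact layer names are unclear, you can also print every internal output of the loaded symbol (building on the snippet above; note the '_output' suffix):

# print all internal output names of the loaded symbol
print(symbol.get_internals().list_outputs())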

Next, you should be able to export the network to ONNX:

onnx_model_path = onnx_mxnet.export_model("model-symbol.json", "model-0000.params",
                                          [input_shape], np.float32, onnx_file)

Now you can optionally validate the ONNX model:

# validate model
# load the ONNX-model
model_proto = onnx.load_model(onnx_model_path)

# check if the converted ONNX-protobuf is valid
checker.check_graph(model_proto.graph)
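
As an additional sanity check (this uses the separate onnxruntime package, which is an assumption on my side and not required), you can run one random batch through the exported model:

import numpy as np
import onnxruntime as ort

# run one random batch through the exported ONNX model and print the output shapes
sess = ort.InferenceSession(onnx_model_path)
input_name = sess.get_inputs()[0].name
dummy = np.random.uniform(size=input_shape).astype(np.float32)
outputs = sess.run(None, {input_name: dummy})
print([o.shape for o in outputs])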

You also need to make sure that your installed ONNX version is <= 1.3.0 (e.g. pip install onnx==1.3.0).