apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Failed to convert pytorch networks with "torch.view()" to mxnet with ONNX #14916

Open nicklhy opened 5 years ago

nicklhy commented 5 years ago

Hi there! It seems that the current MXNet cannot convert basic CNN models such as AlexNet and ResNet from PyTorch, simply because the shape in ONNX is defined as a tensor rather than an attribute (which is mentioned here).

Are there any specific plans to solve this problem?

Note that the error can occur even when the reshape is not dynamic. A simple script to reproduce it is below:

import torch
import torch.nn as nn
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet

class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        pass

    def forward(self, x):
        return x.view(-1, 2)

if __name__ == "__main__":
    net = A()
    x = torch.randn(1, 3, 224, 224)
    torch.onnx.export(net, x, 'test_view.onnx', verbose=True)
    sym, arg_params, aux_params = onnx_mxnet.import_model('test_view.onnx')
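As a sanity check that the reshape in this script is static, the target shape (-1, 2) can be resolved by hand from the fixed input shape. The following is a minimal sketch in plain Python; `resolve_view_shape` is a hypothetical helper written for illustration and is not part of PyTorch, ONNX, or MXNet.

```python
from math import prod


def resolve_view_shape(input_shape, target_shape):
    """Replace a single -1 in target_shape with the size implied by input_shape."""
    if target_shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    total = prod(input_shape)
    # Product of the explicitly given dimensions (everything except -1).
    known = prod(d for d in target_shape if d != -1)
    if known == 0 or total % known != 0:
        raise ValueError("target shape is incompatible with the input size")
    if -1 not in target_shape and total != known:
        raise ValueError("target shape is incompatible with the input size")
    return tuple(total // known if d == -1 else d for d in target_shape)


# Input and target shape taken from the script above.
print(resolve_view_shape((1, 3, 224, 224), (-1, 2)))  # (75264, 2)
```

The result matches the Float(75264, 2) output that the exporter prints for the Reshape node, so nothing about the shape depends on run-time data.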

The error message is:

$ test_view_bug.py
graph(%0 : Float(1, 3, 224, 224)) {
  %1 : Tensor = onnx::Constant[value=-1  2 [ CPULongType{2} ]](), scope: A
  %2 : Float(75264, 2) = onnx::Reshape(%0, %1), scope: A
  return (%2);
}

Traceback (most recent call last):
  File "test_view_bug.py", line 23, in <module>
    sym, arg_params, aux_params = onnx_mxnet.import_model('test_view.onnx')
  File "/home/lhy/Documents/Lib/incubator-mxnet/python/mxnet/contrib/onnx/onnx2mx/import_model.py", line 59, in import_model
    sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)
  File "/home/lhy/Documents/Lib/incubator-mxnet/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py", line 116, in from_onnx
    mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)
  File "/home/lhy/Documents/Lib/incubator-mxnet/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py", line 62, in _convert_operator
    op_name, new_attrs, inputs = convert_map[op_name](attrs, inputs, self)
  File "/home/lhy/Documents/Lib/incubator-mxnet/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py", line 462, in reshape
    reshape_shape = list(proto_obj._params[inputs[1].name].asnumpy())
KeyError: 'identity0'

mxnet-label-bot commented 5 years ago

Hey, this is the MXNet Label Bot. Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it. Here are my recommended labels: ONNX, Bug

vrakesh commented 5 years ago

@mxnet-label-bot add [ONNX]

@nicklhy Thanks for reporting this. We will look into the issue.

nicklhy commented 5 years ago

After a deeper look at the source code, I noticed that the conversion function for "reshape" tries to read the "shape" values from the graph params dict at python/mxnet/contrib/onnx/onnx2mx/_op_translations.py#L462:

reshape_shape = list(proto_obj._params[inputs[1].name].asnumpy())

However, the "shape" tensor in the ONNX graph is never read and saved into that params dict. I managed to solve this problem by adding the two lines below to python/mxnet/contrib/onnx/onnx2mx/import_onnx.py, at the point where the operator conversion is performed:

if op_name == 'Constant':
    self._params[mxnet_sym.name] = self._parse_array(node.attribute[0].t)
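The mechanism can be illustrated without MXNet or ONNX installed. The sketch below is a simplified model, not the importer's actual code: `Node`, `convert_graph`, and the `record_constants` flag are hypothetical names, and the `params` dict only plays the role of the importer's `_params`.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for an ONNX node (not the real onnx.NodeProto)."""
    op_type: str
    inputs: list
    output: str
    value: list = None  # only set for Constant nodes


def convert_graph(nodes, record_constants):
    """Walk the nodes in order, mimicking the params lookup done for Reshape."""
    params = {}    # plays the role of the importer's `_params` dict
    resolved = {}  # Reshape output name -> target shape found in params
    for node in nodes:
        if node.op_type == "Constant" and record_constants:
            # The idea of the fix: store the constant under its output name.
            params[node.output] = node.value
        elif node.op_type == "Reshape":
            shape_name = node.inputs[1]
            # Raises KeyError if the shape tensor was never stored.
            resolved[node.output] = list(params[shape_name])
    return resolved


graph = [
    Node("Constant", [], "identity0", value=[-1, 2]),
    Node("Reshape", ["data", "identity0"], "reshape0"),
]

try:
    convert_graph(graph, record_constants=False)
except KeyError as err:
    print("without the fix:", repr(err))

print("with the fix:", convert_graph(graph, record_constants=True))
```

Without recording the constant, the lookup fails with the same kind of KeyError as in the traceback; with it, the Reshape step finds [-1, 2] under the constant's output name.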

BTW, this may not work in the dynamic reshape case.
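To make the caveat concrete: in a dynamic reshape, the shape input of Reshape is produced by other operators at run time instead of by a Constant node, so recording constants alone does not cover it. The sketch below models this with plain tuples; the graphs, node names, and both helper functions are hypothetical and only illustrate the idea.

```python
def constant_outputs(nodes):
    """Collect output names of Constant nodes; each node is (op_type, inputs, output)."""
    return {output for op_type, _inputs, output in nodes if op_type == "Constant"}


def unresolved_reshape_shapes(nodes):
    """Return shape-input names of Reshape nodes that no Constant node produces."""
    known = constant_outputs(nodes)
    return [inputs[1] for op_type, inputs, _output in nodes
            if op_type == "Reshape" and inputs[1] not in known]


# Static case, as in this issue: the shape comes straight from a Constant node.
static_graph = [
    ("Constant", [], "shape_const"),
    ("Reshape", ["data", "shape_const"], "out"),
]

# Hypothetical dynamic case: the target shape is assembled from the input's own shape.
dynamic_graph = [
    ("Shape", ["data"], "data_shape"),
    ("Constant", [], "index0"),
    ("Gather", ["data_shape", "index0"], "batch"),
    ("Constant", [], "minus_one"),
    ("Concat", ["batch", "minus_one"], "target_shape"),
    ("Reshape", ["data", "target_shape"], "out"),
]

print(unresolved_reshape_shapes(static_graph))   # []
print(unresolved_reshape_shapes(dynamic_graph))  # ['target_shape']
```

In the dynamic graph the Reshape shape input is the output of Concat, so a lookup keyed on constant outputs would still miss it.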

ahmed-shariff commented 5 years ago

@nicklhy is this fix being pushed to the repo?

nicklhy commented 5 years ago

@ahmed-shariff No, the above fix has not been pushed to the official repo yet.