onnx / onnx-tensorflow

Tensorflow Backend for ONNX

Reshape is not converting in some cases. #756

Open andrew-yang0722 opened 4 years ago

andrew-yang0722 commented 4 years ago

The conversion does not work for onnx::Reshape() nodes produced from PyTorch's .view(). When I run inference with the converted TFLite model using TF Lite on Android, the output has 0 dimensions; that is, tflite.getOutputTensor().numDimensions() returns 0.

I think the conversion failed.

I made a small sample that reproduces this problem.

PyTorch code

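import torch
import torch.nn as nn

class ReshapeModel(nn.Module):
    def __init__(self):
        super(ReshapeModel, self).__init__()

    def forward(self, x):
        out = x.view(1, -1, 4).contiguous()
        return out

dummy_input = torch.randn(1, 56, 56, 2)
# Export step (assumed; the file name is illustrative). Input/output names match the ONNX log below.
torch.onnx.export(ReshapeModel(), dummy_input, "reshape.onnx", verbose=True,
                  input_names=["input"], output_names=["output"])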

ONNX log

graph(%input : Float(1, 56, 56, 2)):
  %1 : Tensor = onnx::Constant[value= 1 -1  4 [ CPULongType{3} ]]()
  %output : Float(1, 1568, 4) = onnx::Reshape(%input, %1) # simple_onnx.py:28:0
  return (%output)
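For reference, the shape the converted graph should produce can be worked out from the ONNX Reshape rules: a `-1` entry in the target shape is inferred from the remaining element count, and a `0` entry copies the input dimension at the same index. The sketch below uses a hypothetical helper, `onnx_reshape_shape`, written only for illustration; it is not part of onnx or onnx-tf. Applied to the logged constant `[1, -1, 4]`, it should agree with the `Float(1, 1568, 4)` that the exporter reports.

```python
import math
from typing import List, Sequence


def onnx_reshape_shape(in_shape: Sequence[int], target: Sequence[int]) -> List[int]:
    """Hypothetical helper: output shape of ONNX Reshape(data of in_shape, target)."""
    # A 0 in the target copies the input dimension at the same index.
    out = [in_shape[i] if d == 0 else d for i, d in enumerate(target)]
    if out.count(-1) > 1:
        raise ValueError("at most one -1 is allowed in the target shape")
    total = math.prod(in_shape)
    if -1 in out:
        # Infer the -1 entry from the product of the known entries.
        known = math.prod(d for d in out if d != -1)
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot infer -1 for {list(target)} from {list(in_shape)}")
        out[out.index(-1)] = total // known
    if math.prod(out) != total:
        raise ValueError(f"element count mismatch: {list(in_shape)} -> {out}")
    return out


# The case from the ONNX log: Reshape(Float(1, 56, 56, 2), [1, -1, 4])
print(onnx_reshape_shape([1, 56, 56, 2], [1, -1, 4]))  # [1, 1568, 4]
```

If the TFLite output reports 0 dimensions while this calculation gives a rank-3 shape, the mismatch points at the conversion rather than at the exported ONNX graph.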

Reshape converts correctly in some cases and fails in others:
- Fail case: torch.randn(1, 56, 56, 2) -> x.view(1, 56, -1, 4)
- Fail case: torch.randn(1, 56, 56, 2) -> x.view(1, -1, 4)
- OK case: torch.randn(1, 56, 2) -> x.view(1, -1)
- OK case: torch.randn(1, 56, 56, 2) -> x.view(1, -1)
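Working out the expected output shape for each case above shows one pattern that may be worth checking: the two failing targets produce rank-4 and rank-3 outputs, while both passing targets produce rank-2 outputs. This is only an observation from these four cases, not a confirmed cause. The sketch below infers the `-1` dimension the same way `Tensor.view` does (total element count divided by the product of the other entries); the helper `infer_view_shape` and the `CASES` table are hypothetical, written for illustration.

```python
import math

# (input shape, view() target, reported result) for the four cases above
CASES = [
    ((1, 56, 56, 2), (1, 56, -1, 4), "fail"),
    ((1, 56, 56, 2), (1, -1, 4), "fail"),
    ((1, 56, 2), (1, -1), "ok"),
    ((1, 56, 56, 2), (1, -1), "ok"),
]


def infer_view_shape(in_shape, target):
    """Replace the single -1 in target with the size implied by the element count."""
    known = math.prod(d for d in target if d != -1)
    missing = math.prod(in_shape) // known
    return tuple(missing if d == -1 else d for d in target)


for in_shape, target, result in CASES:
    out = infer_view_shape(in_shape, target)
    print(f"{result:4} {in_shape} -> {out} (rank {len(out)})")
```

Running it lists each case with its inferred shape and rank, which makes it easy to add further cases (for example a rank-3 target on a smaller input) to test whether output rank is what separates the failures.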

Versions: tensorflow 1.15.0, onnx 1.7.0, onnx-tf 1.6.0 (tf-1.x branch), torch 1.4.0

chinhuang007 commented 4 years ago

Can you provide the ONNX file generated from the PyTorch code?

andrew-yang0722 commented 4 years ago

Here is the file. onnx_file