onnx / onnx-tensorflow

Tensorflow Backend for ONNX

ValueError: Dimensions must be equal, but are 1 and 7 for ... = MatMul ... with input shapes: [7,1], [7,100] #883

Open Ark-kun opened 3 years ago

Ark-kun commented 3 years ago

Describe the bug

I cannot convert a simple dense network from PyTorch to TensorFlow through ONNX.

ValueError: Dimensions must be equal, but are 1 and 7 for '{{node onnx_tf_prefix_If_4/MatMul}} = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false](onnx_tf_prefix_If_4/flatten/ExpandDims, onnx_tf_prefix_If_4/transpose)' with input shapes: [7,1], [7,100].
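The error comes from the basic shape rule for a 2-D matrix product: an (m, k) matrix can only be multiplied by a (k, n) matrix, and here the inner dimensions are 1 and 7. Below is a minimal pure-Python sketch of that check. The function name `matmul_output_shape` is hypothetical, and the message text is copied from the log only for comparison; this is not TensorFlow's implementation.

```python
def matmul_output_shape(a_shape, b_shape):
    """Return the output shape of a 2-D matrix product a @ b (hypothetical helper).

    Raises ValueError when the inner dimensions differ, using the same
    wording as the TensorFlow log above for easy comparison.
    """
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise ValueError("this sketch only handles 2-D operands")
    (m, k_a), (k_b, n) = a_shape, b_shape
    # The columns of a must equal the rows of b
    if k_a != k_b:
        raise ValueError(f"Dimensions must be equal, but are {k_a} and {k_b}")
    return (m, n)
```

For example, `matmul_output_shape((1, 7), (7, 100))` returns `(1, 100)`, while the shapes from the log, `(7, 1)` and `(7, 100)`, raise the same "but are 1 and 7" error.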

To Reproduce

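import onnx
import onnx_tf.backend  # explicitly load the onnx_tf.backend submodule

# Load the ONNX model exported from PyTorch
onnx_model = onnx.load("model.onnx")
# Convert the ONNX graph into a TensorFlow representation
tf_rep = onnx_tf.backend.prepare(onnx_model)
# Write the converted model to model.tf
tf_rep.export_graph("model.tf")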

ONNX model file

model.onnx.as.zip

Python, ONNX, ONNX-TF, Tensorflow version

This section can be obtained by running get_version.py from util folder.

Additional context

Here is how I converted the PyTorch model to ONNX (using the 'pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime' image):

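def convert_to_onnx_from_pytorch_script_module(
    model_path,
    converted_model_path,
    list_of_input_shapes: list,
):
    """Export a saved TorchScript module to an ONNX file."""
    import torch

    # Load the TorchScript (torch.jit) module from disk
    model = torch.jit.load(model_path)
    # Build one all-ones example tensor per input shape
    example_inputs = [
        torch.ones(*input_shape)
        for input_shape in list_of_input_shapes
    ]
    # PyTorch 1.7 expects example outputs when exporting a ScriptModule
    example_outputs = model.forward(*example_inputs)
    torch.onnx.export(
        model=model,
        args=tuple(example_inputs),  # positional inputs, documented as a tuple
        f=converted_model_path,
        verbose=True,
        training=torch.onnx.TrainingMode.EVAL,
        example_outputs=example_outputs,  # deprecated in later PyTorch releases
    )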

chinhuang007 commented 3 years ago

Can't open/unzip the onnx model file. Please double check and share again. Thanks.

Ark-kun commented 3 years ago

@chinhuang007 You do not need to unzip it. Just rename it to model.onnx.
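Following that advice, the attachment can be restored with a rename instead of an extraction. A minimal sketch, assuming the downloaded file keeps the name `model.onnx.as.zip`; the helper `restore_onnx_attachment` is hypothetical, and it refuses to rename a file that really is a zip archive (checked with `zipfile.is_zipfile`).

```python
import zipfile
from pathlib import Path


def restore_onnx_attachment(downloaded="model.onnx.as.zip", target="model.onnx"):
    """Rename a downloaded attachment back to its .onnx name (hypothetical helper).

    Assumes the attachment is the raw ONNX file with an extra .zip suffix,
    so it only needs renaming, not extracting.
    """
    src = Path(downloaded)
    if zipfile.is_zipfile(src):
        # A genuine zip archive should be extracted, not renamed
        raise ValueError(f"{src} is a real zip archive; extract it instead")
    dst = Path(target)
    src.replace(dst)  # rename, overwriting an existing target if present
    return dst
```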

chinhuang007 commented 3 years ago

The ONNX model doesn't seem to be valid. The sub-graph of node "If_4" contains a Gemm node that fails ONNX Runtime's shape-inference validation. You can run this simple code to see the error:

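import onnx
import onnxruntime.backend as ort

onnx_path = 'model.onnx'
model = onnx.load(onnx_path)
# Preparing the model with ONNX Runtime runs its shape inference,
# which reports the invalid Gemm node in the "If_4" sub-graph
rt_rep = ort.prepare(model)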
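For reference, the ONNX Gemm operator specification expects A to have shape (M, K), or (K, M) when transA is set, B to have shape (K, N), or (N, K) when transB is set, and the optional C to be unidirectionally broadcastable to the (M, N) output. The sketch below is a hypothetical pure-Python version of those rules (the name `gemm_output_shape` is illustrative, and this is not ONNX Runtime's shape-inference code). It can help when reasoning about which input shapes a Gemm node accepts.

```python
def gemm_output_shape(a_shape, b_shape, c_shape=None, trans_a=0, trans_b=0):
    """Infer the (M, N) output shape of an ONNX-style Gemm, or raise ValueError.

    Hypothetical helper following the Gemm operator spec:
    Y = alpha * A' * B' + beta * C, where A' and B' are A and B,
    optionally transposed.
    """
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise ValueError(
            f"Gemm expects 2-D A and B, got ranks {len(a_shape)} and {len(b_shape)}"
        )
    # A' has shape (M, K); B' has shape (K, N)
    m, k_a = (a_shape[1], a_shape[0]) if trans_a else (a_shape[0], a_shape[1])
    k_b, n = (b_shape[1], b_shape[0]) if trans_b else (b_shape[0], b_shape[1])
    if k_a != k_b:
        raise ValueError(f"inner dimensions differ: {k_a} vs {k_b}")
    if c_shape is not None:
        # C must be unidirectionally broadcastable to (M, N): compare trailing
        # dimensions, each of which must be 1 or equal to the output dimension
        if len(c_shape) > 2:
            raise ValueError(f"C has rank {len(c_shape)}, expected at most 2")
        for c_dim, out_dim in zip(reversed(c_shape), (n, m)):
            if c_dim not in (1, out_dim):
                raise ValueError(
                    f"C shape {tuple(c_shape)} cannot broadcast to {(m, n)}"
                )
    return (m, n)
```

For instance, a (1, 7) input against a (100, 7) weight with `trans_b=1` and a (100,) bias gives `(1, 100)`, while a rank-1 input of shape (7,) is rejected because Gemm requires 2-D operands.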