onnx / onnx-tensorflow

Tensorflow Backend for ONNX

torch->onnx->tf with error: ValueError: Node 'onnx_tf_prefix_Unsqueeze_38': Unknown input node 'onnx_tf_prefix_Not_31' #1057

Closed: yangqinj closed this issue 1 year ago

yangqinj commented 1 year ago

Describe the bug

I converted the FastSpeech2 model from PyTorch to ONNX and then to TensorFlow. After working through some earlier problems, the .pb file is produced, but importing the TensorFlow graph fails with this error:

Traceback (most recent call last):
  File "/mnt/data10/user/install/miniconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Node 'onnx_tf_prefix_Unsqueeze_38': Unknown input node 'onnx_tf_prefix_Not_31'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "convert_pt_to_tf_pb_am.py", line 171, in <module>
    tf.import_graph_def(output_graph_def, name="")
  File "/mnt/data10/user/install/miniconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/mnt/data10/user/install/miniconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/mnt/data10/user/install/miniconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: Node 'onnx_tf_prefix_Unsqueeze_38': Unknown input node 'onnx_tf_prefix_Not_31'

To Reproduce

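import onnx
import tensorflow as tf  # TensorFlow 1.15.4
import torch
from onnx_tf.backend import prepare

# model, dummy_input_1..4, input_names, output_names, onnx_model_path and
# tf_model_path are defined elsewhere (not shown in this issue)
torch.onnx.export(model,
                  args=(dummy_input_1, dummy_input_2, dummy_input_3, dummy_input_4),
                  f=onnx_model_path,
                  opset_version=11,
                  input_names=input_names,
                  output_names=output_names)

model_onnx = onnx.load(onnx_model_path)

# Convert ONNX -> TensorFlow and write the graph to a .pb file
tf_rep = prepare(model_onnx)
tf_rep.export_graph(tf_model_path)

# Re-import the .pb; tf.import_graph_def is the call that raises the ValueError above
with tf.Graph().as_default():
    with open(tf_model_path, 'rb') as f:
        output_graph_def = tf.GraphDef.FromString(f.read())
    tf.import_graph_def(output_graph_def, name="")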

ONNX model file

Sorry, due to network restrictions I cannot upload the ONNX file.

Python, ONNX, ONNX-TF, Tensorflow version

This section can be obtained by running get_version.py from the util folder.

Additional context

After checking the ONNX file, I found that the nodes onnx_tf_prefix_Unsqueeze_38 and onnx_tf_prefix_Not_31 correspond to this PyTorch code:

mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

onnx_tf_prefix_Not_31 corresponds to the >=, which is exported to ONNX as a Not applied to a Less operator. Then I tried to fix this problem by:

Then I got this error:

ValueError: Node 'onnx_tf_prefix_Unsqueeze_42': Unknown input node 'onnx_tf_prefix_Cast_35'

One important constraint: I cannot use TensorFlow 2.x; I must use TensorFlow 1.15.4.
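For reference, the quoted mask line marks the padded time steps of each sequence. ONNX only added a GreaterOrEqual operator in opset 12, so with opset_version=11 the >= is exported as Not(Less(...)), which is where the Not node in the graph comes from. The plain-Python sketch below shows that both forms give the same mask. It assumes ids[b][i] = i (an arange over time steps, which the issue does not show), and both function names are hypothetical.

```python
from typing import List


def padding_mask_direct(lengths: List[int], max_len: int) -> List[List[bool]]:
    """mask[b][i] = ids[b][i] >= lengths[b], assuming ids[b][i] = i."""
    return [[i >= length for i in range(max_len)] for length in lengths]


def padding_mask_via_not_less(lengths: List[int], max_len: int) -> List[List[bool]]:
    """Same mask written as Not(Less(ids, lengths)), the form opset 11 uses for >=."""
    return [[not (i < length) for i in range(max_len)] for length in lengths]


lengths = [3, 1]
print(padding_mask_via_not_less(lengths, 4))
# [[False, False, False, True], [False, True, True, True]]
assert padding_mask_via_not_less(lengths, 4) == padding_mask_direct(lengths, 4)
```

True entries mark positions at or beyond each sequence's length, i.e. the padding.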