KhronosGroup / NNEF-Tools

The NNEF Tools repository contains tools to generate and consume NNEF documents
https://www.khronos.org/nnef

Conversion of tf.nn.max_pool_with_argmax to NNEF fails #130

Closed: dvorotnev closed this issue 3 years ago

dvorotnev commented 3 years ago

Hello! According to nnef_tools/operation_mapping.md, tf.nn.max_pool_with_argmax should be converted to the max_pool_with_index NNEF operation. But when I try to save a simple TensorFlow graph in .pb format:

import tensorflow as tf
import nnef_tools.io.tf.graphdef as graphdef

def testnet_max_pool_with_index():
    x = tf.placeholder(tf.float32, shape=[6, 32, 32, 3], name='input')
    return tf.nn.max_pool_with_argmax(x, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='VALID')

tf.reset_default_graph()
with tf.Session() as sess:
    result = testnet_max_pool_with_index()
    sess.run(tf.global_variables_initializer())
    graphdef.save_default_graph("model.pb", session=sess, outputs={result: "output"})

a Python exception is raised:

Traceback (most recent call last):
  File "test.py", line 14, in <module>
    graphdef.save_default_graph("model.pb", session=sess, outputs={result: "output"})
  File "/nix/store/z5wm5h1mdja6mzhr38vv87rm76z4f9yw-python3.7-nnef-tools-python-1.0/lib/python3.7/site-packages/nnef_tools/io/tf/graphdef/__init__.py", line 31, in save_default_graph
    output_names = [tensor.name[:-2] if tensor.name.endswith(':0') else tensor.name for tensor in outputs]
  File "/nix/store/z5wm5h1mdja6mzhr38vv87rm76z4f9yw-python3.7-nnef-tools-python-1.0/lib/python3.7/site-packages/nnef_tools/io/tf/graphdef/__init__.py", line 31, in <listcomp>
    output_names = [tensor.name[:-2] if tensor.name.endswith(':0') else tensor.name for tensor in outputs]
AttributeError: 'MaxPoolWithArgmax' object has no attribute 'name'

because the layer has two output tensors.

If I save only the first result:

def testnet_max_pool_with_index():
    x = tf.placeholder(tf.float32, shape=[6, 32, 32, 3], name='input')
    return tf.nn.max_pool_with_argmax(x, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='VALID')[0]

then the graph is saved successfully, but when I try to convert it:

python -m nnef_tools.convert --input-format=tf --output-format=nnef --input-model=./model.pb --output-model=model.nnef

the converter says:

Conversion for operation type 'MaxPoolWithArgmax' is not implemented

gyenesvi commented 3 years ago

Hi, indeed, the conversion of MaxPoolWithArgmax was accidentally left out in the latest refactor. I have added it, and it's working on my side.

However, there is a bug in your original TF code. As you note, tf.nn.max_pool_with_argmax has two results, but you assign them to a single Python variable, which in this case holds a tuple (a namedtuple of type MaxPoolWithArgmax, hence the error message), and then pass that to the saver, which fails because it expects tensors. Instead, you have to unpack the result into two tensors and pass them to the saver separately:

maximum, index = testnet_max_pool_with_index()
graphdef.save_default_graph("model.pb", session=sess, outputs={maximum: "max", index: "idx"})
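The failure in the first traceback can be reproduced without TensorFlow. The sketch below is a minimal pure-Python mock: FakeTensor and output_names are hypothetical stand-ins, and MaxPoolWithArgmax is a namedtuple assumed to mirror the (output, argmax) pair returned by tf.nn.max_pool_with_argmax. output_names copies the list comprehension shown in the traceback. Using the whole tuple as one dict key makes the comprehension see the tuple instead of a tensor, while unpacking it first yields two objects that each have a .name.

```python
from collections import namedtuple


class FakeTensor:
    # Hypothetical stand-in for a graph tensor: only the .name attribute matters here.
    def __init__(self, name):
        self.name = name


# Assumed to mirror the (output, argmax) namedtuple returned by tf.nn.max_pool_with_argmax.
MaxPoolWithArgmax = namedtuple("MaxPoolWithArgmax", ["output", "argmax"])


def output_names(outputs):
    # Same expression as in the save_default_graph traceback; iterating a dict yields its keys.
    return [tensor.name[:-2] if tensor.name.endswith(':0') else tensor.name
            for tensor in outputs]


result = MaxPoolWithArgmax(FakeTensor("MaxPoolWithArgmax:0"),
                           FakeTensor("MaxPoolWithArgmax:1"))

# Whole tuple as a single key: the comprehension receives the namedtuple, which has no .name.
try:
    output_names({result: "output"})
    error = None
except AttributeError as exc:
    error = str(exc)
print(error)

# Unpacking first gives two tensor-like keys, each with its own .name.
maximum, index = result
names = output_names({maximum: "max", index: "idx"})
print(names)  # ['MaxPoolWithArgmax', 'MaxPoolWithArgmax:1']
```

Note that the unpacked case already keeps the ':1' suffix on the second name, since the expression only strips ':0'.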

Let me know if it's working for you!

dvorotnev commented 3 years ago

Thank you for the quick response! Now it works fine when I save only the first tensor, but when I try to save both it still fails. Fixed TF example:

import tensorflow as tf
import nnef_tools.io.tf.graphdef as graphdef

def testnet_max_pool_with_index():
    x = tf.placeholder(tf.float32, shape=[6, 32, 32, 3], name='input')
    return tf.nn.max_pool_with_argmax(x, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='VALID')

tf.reset_default_graph()
with tf.Session() as sess:
    result1, result2 = testnet_max_pool_with_index()
    sess.run(tf.global_variables_initializer())
    graphdef.save_default_graph("model.pb", session=sess, outputs={result1: "output1", result2: "output2"})

Python output:

Traceback (most recent call last):
  File "./test.py", line 14, in <module>
    graphdef.save_default_graph("model.pb", session=sess, outputs={result1: "output1", result2: "output2"})
  File "/nix/store/7w23q55cvx8vahx07jdf7cjjz75w1k43-python3.7-nnef-tools-python-1.0/lib/python3.7/site-packages/nnef_tools/io/tf/graphdef/__init__.py", line 34, in save_default_graph
    graph_def = graph_util.convert_variables_to_constants(session, graph_def, output_names)
  File "/nix/store/6b76c6ri67hi7z13iipilr7lnx4l8m2p-python3-3.7.9-env/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "/nix/store/6b76c6ri67hi7z13iipilr7lnx4l8m2p-python3-3.7.9-env/lib/python3.7/site-packages/tensorflow/python/framework/graph_util_impl.py", line 270, in convert_variables_to_constants
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)
  File "/nix/store/6b76c6ri67hi7z13iipilr7lnx4l8m2p-python3-3.7.9-env/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "/nix/store/6b76c6ri67hi7z13iipilr7lnx4l8m2p-python3-3.7.9-env/lib/python3.7/site-packages/tensorflow/python/framework/graph_util_impl.py", line 182, in extract_sub_graph
    _assert_nodes_are_present(name_to_node, dest_nodes)
  File "/nix/store/6b76c6ri67hi7z13iipilr7lnx4l8m2p-python3-3.7.9-env/lib/python3.7/site-packages/tensorflow/python/framework/graph_util_impl.py", line 137, in _assert_nodes_are_present
    assert d in name_to_node, "%s is not in graph" % d
AssertionError: MaxPoolWithArgmax:1 is not in graph

gyenesvi commented 3 years ago

This seems to be a bug in TF: save_default_graph runs a TF utility, graph_util.convert_variables_to_constants, which fails here, and I suspect I know why: the GraphDef cannot properly enumerate its outputs. This only happens if MaxPoolWithArgmax is the last op. Anyway, you can work around it by adding a tf.identity on the second output; with that, it works for me.
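One plausible reading of the second traceback, offered as an assumption rather than a confirmed diagnosis: if save_default_graph still uses the name-stripping expression from the first traceback, only a ':0' suffix is removed, so the second output reaches graph_util.convert_variables_to_constants as 'MaxPoolWithArgmax:1'. That function takes node names (its parameter is output_node_names), and no node has that name. The pure-Python sketch below mimics the check; graph_nodes and missing_nodes are hypothetical, and 'Identity' is assumed to be the default name TF gives the node created by tf.identity. This reading would also explain why the tf.identity workaround helps: the identity node has a single output whose name ends in ':0'.

```python
def to_node_name(tensor_name):
    # Expression from the first traceback: strip ':0', leave any other suffix in place.
    return tensor_name[:-2] if tensor_name.endswith(':0') else tensor_name


def missing_nodes(tensor_names, graph_nodes):
    # Hypothetical helper: names that a "node must exist in graph" check would reject.
    return [n for n in map(to_node_name, tensor_names) if n not in graph_nodes]


# Hypothetical node names in the saved GraphDef.
graph_nodes = {"input", "MaxPoolWithArgmax"}

# Both outputs of MaxPoolWithArgmax requested directly.
direct = missing_nodes(["MaxPoolWithArgmax:0", "MaxPoolWithArgmax:1"], graph_nodes)
print(direct)  # ['MaxPoolWithArgmax:1']

# Second output routed through tf.identity (node name 'Identity' is an assumption).
with_identity = missing_nodes(["MaxPoolWithArgmax:0", "Identity:0"],
                              graph_nodes | {"Identity"})
print(with_identity)  # []
```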

I'll try to add a workaround to save_default_graph that automates this without messing up output names.

gyenesvi commented 3 years ago

I have added the workaround to save_default_graph; it should work now even if you don't add tf.identity manually.

dvorotnev commented 3 years ago

The model is saved and converted correctly now! Thank you very much!

gyenesvi commented 3 years ago

Great, closing the issue.