KhronosGroup / NNEF-Tools

The NNEF Tools repository contains tools to generate and consume NNEF documents
https://www.khronos.org/nnef

Incorrect conversion of tf.nn.local_response_normalization function #137

Closed dvorotnev closed 4 years ago

dvorotnev commented 4 years ago

I am trying to save and convert a simple Python network:

import tensorflow as tf
import nnef_tools.io.tf.graphdef as graphdef

def testnet_local_response_normalization():
    x = tf.placeholder(tf.float32, shape=[3, 4, 5, 6], name='input')
    return tf.nn.local_response_normalization(x, depth_radius=2, bias=1, alpha=1/5, beta=0.5)

tf.reset_default_graph()
with tf.Session() as sess:
    result = testnet_local_response_normalization()
    sess.run(tf.global_variables_initializer())
    graphdef.save_default_graph("model.pb", session=sess, outputs={result: "output"})

using these commands:

python ./test.py 
python -m nnef_tools.convert --input-format=tf --output-format=nnef --input-model=./model.pb --output-model=model.nnef

the conversion result is:

version 1.0;

graph G(external1) -> (copy1)
{
    external1 = external<scalar>(shape = [3, 4, 5, 6]);
    local_response_normalization1 = local_response_normalization(external1, size = [1, 2, 1, 1], alpha = 0.20000000298023224, beta = 0.5, bias = 1.0);
    copy1 = copy(local_response_normalization1);
}

But the documentation of the tf.nn.local_response_normalization function differs from the NNEF documentation of the local_response_normalization operation in several ways:

  1. In TF, normalization is done over the channel dimension, which has size 6 in the Python example.
  2. In TF, the normalization window is set by depth_radius, not by the full normalization size (so size = 2 * depth_radius + 1).
  3. In TF, the squared sum is not divided by the normalization size.
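The following NumPy sketch illustrates these differences as I read the two docs (it is my own reading, not converter code): NNEF averages the squared sum over the window, so setting size = 2 * depth_radius + 1 and scaling alpha by size makes the two definitions agree.

```python
import numpy as np

# Sketch comparing TF and NNEF LRN semantics (my reading of the docs,
# not the converter's code). Both normalize over the last (channel) axis.

def lrn_tf(x, depth_radius, bias, alpha, beta):
    # TF: out = x / (bias + alpha * sum(x^2 over window)) ** beta,
    # window = 2*depth_radius + 1 channels; the sum is NOT averaged.
    out = np.empty_like(x)
    channels = x.shape[-1]
    for c in range(channels):
        lo, hi = max(0, c - depth_radius), min(channels, c + depth_radius + 1)
        sqr_sum = np.sum(x[..., lo:hi] ** 2, axis=-1)
        out[..., c] = x[..., c] / (bias + alpha * sqr_sum) ** beta
    return out

def lrn_nnef(x, size, bias, alpha, beta):
    # NNEF: the squared sum is averaged over the window (divided by the
    # constant `size`, with zero padding at the borders).
    radius = size // 2
    out = np.empty_like(x)
    channels = x.shape[-1]
    for c in range(channels):
        lo, hi = max(0, c - radius), min(channels, c + radius + 1)
        sqr_avg = np.sum(x[..., lo:hi] ** 2, axis=-1) / size
        out[..., c] = x[..., c] / (bias + alpha * sqr_avg) ** beta
    return out

x = np.random.rand(3, 4, 5, 6).astype(np.float32)
# TF's depth_radius=2, alpha=1/5 should map to NNEF's size=5, alpha=1.0:
tf_out = lrn_tf(x, depth_radius=2, bias=1.0, alpha=1 / 5, beta=0.5)
nnef_out = lrn_nnef(x, size=5, bias=1.0, alpha=1.0, beta=0.5)
assert np.allclose(tf_out, nnef_out, atol=1e-5)
```

This also explains why alpha = 0.2 in the Python example should become alpha = 1.0 in the converted NNEF graph: the averaging divisor of 5 is compensated by multiplying alpha by 5.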

Given these differences, the conversion result should instead be:

version 1.0;

graph G(external1) -> (transpose2)
{
    external1 = external<scalar>(shape = [3, 4, 5, 6]);
    transpose1 = transpose(external1, axes = [0, 3, 1, 2]);
    local_response_normalization1 = local_response_normalization(transpose1, size = [1, 5, 1, 1], alpha = 1.0000000149011612, beta = 0.5, bias = 1.0);
    copy1 = copy(local_response_normalization1);
    transpose2 = transpose(copy1, axes = [0, 2, 3, 1]);
}

I have created a pull request that fixes these differences, but I am not sure about the correctness of the tensor transposition.

gyenesvi commented 4 years ago

Thanks for catching these differences. I was aware of them, but they slipped through in the latest converter migration. I have added a review to the pull request.

gyenesvi commented 4 years ago

In the meantime I have realized that this is not the right approach to the transposition problem. The NNEF version of LRN is independent of the dimension order: it can do the normalization over any (number of) dimensions, so there is no need to force the channel to be the second dimension. Instead, the converter needs to check which dimension is the channel and generate the size array accordingly.
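That approach could be sketched roughly as follows (a hypothetical helper for illustration, not the actual NNEF-Tools converter code): put 2 * depth_radius + 1 at the channel axis of an otherwise all-ones size array.

```python
# Hypothetical helper sketching the suggested approach; the real
# converter code in NNEF-Tools may differ.
def nnef_lrn_size(rank, channel_axis, depth_radius):
    size = [1] * rank                          # no normalization on other axes
    size[channel_axis] = 2 * depth_radius + 1  # full window at the channel axis
    return size

# NHWC graph (channel is the last axis), depth_radius = 2:
print(nnef_lrn_size(4, -1, 2))   # -> [1, 1, 1, 5]
# NCHW graph (channel is axis 1):
print(nnef_lrn_size(4, 1, 2))    # -> [1, 5, 1, 1]
```

With this, no transpose pair is needed around the LRN op; the size array simply follows whatever layout the graph already uses.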

Furthermore, the NNEF -> TF direction has to be fixed as well, along with the TFLite versions for consistency.

It's easier for me to fix these myself than to explain how to rework the pull request. In the future, could you please just file the bugs first and let me fix them quickly if they are small things?

gyenesvi commented 4 years ago

I have added a fix that should handle the size properly; can you check it now?

dvorotnev commented 4 years ago

I have checked the new fix. The conversion result is now:

version 1.0;

graph G(external1) -> (copy1)
{
    external1 = external<scalar>(shape = [3, 4, 5, 6]);
    local_response_normalization1 = local_response_normalization(external1, size = [1, 1, 1, 5], alpha = 1.0000000149011612, beta = 0.5, bias = 1.0);
    copy1 = copy(local_response_normalization1);
}

This matches the TF documentation. Thank you!

gyenesvi commented 4 years ago

Great! Also, if you turn on --optimize, the copy at the end should disappear.