KhronosGroup / NNEF-Tools

The NNEF Tools repository contains tools to generate and consume NNEF documents
https://www.khronos.org/nnef

Tflite to NNEF conversion produces bias variables of rank 1 #147

Closed: matthieucoquet closed this issue 3 years ago

matthieucoquet commented 3 years ago

I converted a simple graph from Keras (tf 2.4) -> tflite -> NNEF:

version 1.0;

graph main(external1) -> (linear2)
{
    external1 = external<scalar>(shape = [1, 64]);
    variable1 = variable<scalar>(shape = [32], label = 'variable1');
    variable2 = variable<scalar>(shape = [128], label = 'variable2');
    variable3 = variable<scalar>(shape = [32, 128], label = 'variable3');
    variable4 = variable<scalar>(shape = [128, 64], label = 'variable4');
    reshape1 = reshape(external1, shape = [1, 64]);
    linear1 = linear(reshape1, variable4, variable2);
    relu1 = relu(linear1);
    reshape2 = reshape(relu1, shape = [1, 128]);
    linear2 = linear(reshape2, variable3, variable1);
}

If I use the C++ parser on it and call infer_shapes, it produces an exception (out of range access to vector) during:

check(bias[1] == filter[0], "bias channels (%d) does not match filter count (%d)", (int)bias[1], (int)filter[0]); 

So I think the biases generated by the conversion are wrong (they should be [1, 32] and [1, 128]). Is that correct?

Or is it a bug in infer_shapes?
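A minimal Python sketch of the shape rule behind that check, assuming NNEF `linear` takes an input of shape [N, K], a filter of shape [C, K] and a bias of shape [1, C]. The function name `linear_output_shape` is hypothetical and this is not the NNEF-Tools code; it only mirrors the quoted `check`, so a rank-1 bias such as [32] fails with an IndexError, much like the C++ parser's out-of-range vector access.

```python
def linear_output_shape(input_shape, filter_shape, bias_shape):
    """Hypothetical sketch of shape inference for NNEF linear(input, filter, bias).

    Assumes input is [N, K], filter is [C, K] and bias is [1, C].
    The bias check mirrors the quoted C++ check, so a rank-1 bias
    raises IndexError when bias_shape[1] is read.
    """
    if input_shape[1] != filter_shape[1]:
        raise ValueError("input channels do not match filter channels")
    # bias_shape[1] does not exist for a rank-1 bias such as [32]
    if bias_shape[1] != filter_shape[0]:
        raise ValueError("bias channels do not match filter count")
    # Output keeps the batch size and has one channel per filter row
    return [input_shape[0], filter_shape[0]]


# The two layers of the graph above, with rank-2 biases
print(linear_output_shape([1, 64], [128, 64], [1, 128]))
print(linear_output_shape([1, 128], [32, 128], [1, 32]))
```

With the corrected [1, C] biases the layers give [1, 128] and then [1, 32]; passing the converted bias shape [32] instead raises IndexError at the bias check.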

Here is how I generated the tflite:

import tensorflow as tf
import os
import shutil

input = tf.keras.Input(shape=(64,), name="input", dtype=tf.float32)
hidden = tf.keras.layers.Dense(128, activation=tf.keras.activations.relu, name="hidden")(input)
output = tf.keras.layers.Dense(32, activation=None, name="output")(hidden)
model = tf.keras.Model(inputs=input, outputs=output)

# Save in tf format
directory = os.path.join("test")
model.save(os.path.join(directory, "tf"))

# Also save in tflite format
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
directory = os.path.join(directory, 'tflite')
os.makedirs(directory, exist_ok=True)  # don't fail if the folder already exists
with open(os.path.join(directory, 'model.tflite'), 'wb') as f:
    f.write(tflite_model)

matthieucoquet commented 3 years ago

The NNEF output is correct if I go through ONNX (using keras-onnx) or tf (using this) instead of tflite, so I have found a way to export my model.

Should I close the issue? From tf.keras, tflite is the easiest export path, since tf's saved_model.pb doesn't work with the nnef_tools.convert script.

gyenesvi commented 3 years ago

I believe this is an issue with the TFLite conversion; the correct shape should be [1, 32], as you say. I'll investigate this, so there's no need to close the issue. It's understandable that it works through ONNX, since that's a separate converter.
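What a corrected conversion has to do for each fully connected layer can be sketched as a shape mapping. Assumptions: TFLite's FULLY_CONNECTED op stores its weights as [units, input_size] and its bias as [units], which matches the filters and the rank-1 variables in the graph above. `tflite_fc_to_nnef_shapes` is a hypothetical helper for illustration, not the fix that was pushed to NNEF-Tools.

```python
def tflite_fc_to_nnef_shapes(weights_shape, bias_shape):
    """Hypothetical helper: map TFLite FULLY_CONNECTED shapes to NNEF linear shapes.

    Assumes TFLite weights are [units, input_size] and bias is [units].
    NNEF linear can use the weights as the filter ([C, K]) unchanged,
    but expects the bias as [1, C].
    """
    units, input_size = weights_shape
    if list(bias_shape) != [units]:
        raise ValueError(f"unexpected bias shape {bias_shape} for {units} units")
    filter_shape = [units, input_size]
    nnef_bias_shape = [1, units]  # add the leading singleton dimension
    return filter_shape, nnef_bias_shape
```

For the output layer, `tflite_fc_to_nnef_shapes([32, 128], [32])` gives `([32, 128], [1, 32])`. The bias values themselves need no change: adding a leading dimension of size 1 keeps the same elements in the same order.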

gyenesvi commented 3 years ago

Indeed, the conversion was wrong. I have pushed a fix; can you check again with TFLite?

matthieucoquet commented 3 years ago

It works well now. Thank you!