onnx / keras-onnx

Convert tf.keras/Keras models to ONNX
Apache License 2.0

tensorflow max_pool_with_argmax op does not return indices #699

Open · mpaillassa opened this issue 3 years ago

mpaillassa commented 3 years ago

Hello, I think there is a bug in the conversion of the TensorFlow max_pool_with_argmax op. When I run the converted model with onnxruntime, the op returns the pooled values twice instead of the pooled values and the argmax indices. Here is code that reproduces the bug:

import tensorflow as tf
import numpy as np

class Bug(tf.keras.Model):
    def __init__(self):
        super().__init__()

    def call(self, inputs):
        # v: pooled maxima, i: flattened argmax indices (int64)
        v, i = tf.nn.max_pool_with_argmax(inputs, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        # Return only the indices, cast to float32 (see the load error below)
        return tf.cast(i, tf.float32)

b = Bug()
np.random.seed(0)
inp = np.random.uniform(0, 10, (2, 4, 4, 3))  # NHWC input, batch of 2
keras_indices = b(inp)  # reference result computed by TensorFlow

import onnx
import keras2onnx

# Convert the (already called) subclassed model and save it
onnx_model = keras2onnx.convert_keras(b, target_opset=12)
onnx.save_model(onnx_model, "test.onnx")

import onnxruntime as rt

session = rt.InferenceSession("test.onnx")
inp_info, out_info = session.get_inputs()[0], session.get_outputs()[0]
onnx_indices = session.run([out_info.name], {inp_info.name: inp.astype(np.float32)})[0]

# Channel 2 of the first batch element: the input, then the ONNX "indices"
print(inp[0,:,:,2])
print(onnx_indices[0,:,:,2])

The script prints the input slice followed by the ONNX Runtime output:

[[6.02763376 6.45894113 9.63662761 5.2889492 ]
 [0.71036058 8.32619846 9.78618342 7.80529176]
 [1.43353287 4.1466194  4.56150332 6.17635497]
 [9.43748079 4.37031954 6.66766715 1.28926298]]
[[8.326199  9.786183 ]
 [9.437481  6.6676674]]
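For reference, the indices TensorFlow should return for this input can be recomputed directly. TensorFlow documents that a maximum at position [b, y, x, c] is flattened to (y * width + x) * channels + c, or ((b * height + y) * width + x) * channels + c when include_batch_in_index=True. The NumPy sketch below applies that formula; tf_argmax_indices is a hypothetical helper, and it assumes a k x k window with stride k and spatial dimensions divisible by k, so 'SAME' padding adds nothing (true for the 4x4 input above).

```python
import numpy as np

def tf_argmax_indices(x, k=2, include_batch_in_index=False):
    """Hypothetical reference for tf.nn.max_pool_with_argmax indices.

    Assumes NHWC input, a k x k window with stride k, and H, W divisible
    by k, so 'SAME' padding adds nothing.
    """
    n, h, w, c = x.shape
    out = np.empty((n, h // k, w // k, c), dtype=np.int64)
    for b in range(n):
        for i in range(h // k):
            for j in range(w // k):
                for ch in range(c):
                    window = x[b, i * k:(i + 1) * k, j * k:(j + 1) * k, ch]
                    dy, dx = np.unravel_index(np.argmax(window), window.shape)
                    y, xx = i * k + dy, j * k + dx
                    # TF's documented flattening: (y * width + x) * channels + c
                    flat = (y * w + xx) * c + ch
                    if include_batch_in_index:
                        # ((b * height + y) * width + x) * channels + c
                        flat += b * h * w * c
                    out[b, i, j, ch] = flat
    return out

np.random.seed(0)
inp = np.random.uniform(0, 10, (2, 4, 4, 3))  # same input as the report
print(tf_argmax_indices(inp)[0, :, :, 2])
# [[17 20]
#  [38 44]]
```

Each index points back into the flattened input of its own batch element; for example, inp[0].reshape(-1)[17] is about 8.326, the first pooled value printed above. So for this slice the ONNX output should have been these indices (as floats, because of the cast in call), not the pooled maxima.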

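Note that I had to cast the indices to float; otherwise, loading the model in onnxruntime fails with the following error:

FAIL : Load model from test.onnx failed:Type Error: Type (tensor(int64)) of output arg (Identity:0) of node (bug/MaxPoolWithArgmax_transpose_2_1) does not match expected type (tensor(float)).

This mismatch seems to point to the same bug: the graph output is declared as int64, but the transpose node that should emit the indices produces float data, presumably the pooled values.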

This was observed on Ubuntu 20.04 with TensorFlow 2.4.1, onnx 1.8.1, keras2onnx 1.8.0, onnxruntime 1.7.0 and onnxconverter-common 1.8.0.
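The node name in the error above (bug/MaxPoolWithArgmax_transpose_2_1) suggests that the converter maps the op onto ONNX MaxPool wrapped in NHWC/NCHW transposes. ONNX MaxPool does have an optional second output, Indices, but the ONNX spec flattens those indices over the entire NCHW input (batch and channel included, row-major when storage_order=0), whereas TensorFlow flattens over NHWC and, by default, excludes the batch. A faithful conversion would therefore have to remap the index values, not only transpose them. The NumPy sketch below illustrates that arithmetic; it is not keras2onnx's actual implementation, and onnx_maxpool_indices and onnx_to_tf_indices are hypothetical helpers that assume a k x k window with stride k and no padding.

```python
import numpy as np

def onnx_maxpool_indices(x_nchw, k=2):
    """Hypothetical reference for ONNX MaxPool's Indices output (storage_order=0).

    Indices are flattened over the whole NCHW input, including batch and
    channel. Assumes a k x k window, stride k, no padding.
    """
    n, c, h, w = x_nchw.shape
    ho, wo = h // k, w // k
    # Gather each window into the last axis: (n, c, ho, wo, k * k)
    win = x_nchw.reshape(n, c, ho, k, wo, k).transpose(0, 1, 2, 4, 3, 5)
    a = win.reshape(n, c, ho, wo, k * k).argmax(axis=-1)
    y = np.arange(ho)[:, None] * k + a // k
    x = np.arange(wo)[None, :] * k + a % k
    nn = np.arange(n)[:, None, None, None]
    cc = np.arange(c)[None, :, None, None]
    return ((nn * c + cc) * h + y) * w + x

def onnx_to_tf_indices(onnx_idx, nchw_shape, include_batch_in_index=False):
    """Hypothetical remap from ONNX (NCHW, flattened) to TF (NHWC) argmax indices."""
    n, c, h, w = nchw_shape
    flat = onnx_idx.astype(np.int64)
    x = flat % w
    y = (flat // w) % h
    ch = (flat // (h * w)) % c
    b = flat // (c * h * w)
    tf_idx = (y * w + x) * c + ch
    if include_batch_in_index:
        tf_idx = tf_idx + b * h * w * c
    # ONNX indices are laid out (N, C, Ho, Wo); TF expects (N, Ho, Wo, C)
    return np.transpose(tf_idx, (0, 2, 3, 1))

np.random.seed(0)
inp = np.random.uniform(0, 10, (2, 4, 4, 3))  # NHWC, as in the report
onnx_idx = onnx_maxpool_indices(inp.transpose(0, 3, 1, 2))
print(onnx_to_tf_indices(onnx_idx, (2, 3, 4, 4))[0, :, :, 2])
# [[17 20]
#  [38 44]]
```

For the printed slice this gives [[17, 20], [38, 44]], the NHWC positions of 8.326, 9.786, 9.437 and 6.667 within the first batch element.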