onnx / tensorflow-onnx

Convert TensorFlow, Keras, TensorFlow.js and TFLite models to ONNX
Apache License 2.0

tf2onnx.convert.from_keras not mapping simple GRU correctly #1684

Status: Open. aawce opened this issue 3 years ago.

aawce commented 3 years ago

I have a simple Keras model (see below) with a GRU layer, which keras2onnx mapped correctly to the ONNX GRU operator. With tf2onnx.convert.from_keras I instead get a very large, complicated ONNX model containing loops, initializers and other ops, and the layer is not mapped to an ONNX GRU. This breaks several inference backends.

Urgency: high. We have many customer models with GRU/LSTM/RNN layers that need to move to the latest release now that keras2onnx is deprecated.

System information

To Reproduce: in Python, build the model below and run onnxmodel, _ = tf2onnx.convert.from_keras(model)

    from tensorflow.keras.models import Model, load_model
    from tensorflow.keras.layers import Input, Dense, GRU, Dropout, Activation

    # dense-gru
    # shapes taken from the model.summary() output below
    in_shape = (16, 192)  # (timesteps, features)
    batch_size = 32
    outlen = 4

    model_in = Input(tuple(in_shape), batch_size=batch_size)
    x = Dense(192, activation='relu')(model_in)
    x = Dropout(0.5)(x)
    x = GRU(32, return_sequences=True)(x)
    x = Dropout(0.5)(x)
    model_out = Dense(outlen, activation='softmax')(x)
    model = Model(inputs=model_in, outputs=model_out)
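To see quickly whether a conversion produced a fused recurrent node or the Loop-based expansion described above, one can count the op types in the converted graph. The sketch below is an illustrative helper, not part of tf2onnx: the names `op_histogram`, `has_fused_rnn` and `convert_and_report` are hypothetical. It relies only on the fact that `tf2onnx.convert.from_keras` returns a `(ModelProto, external_tensor_storage)` pair and that each ONNX node exposes `op_type`. Only top-level nodes are counted; anything nested inside a Loop body is not inspected.

```python
from collections import Counter

# ONNX op types for fused recurrent layers; a fused Keras GRU would show up as "GRU".
FUSED_RNN_OPS = {"GRU", "LSTM", "RNN"}


def op_histogram(graph):
    """Count op types of the top-level nodes of an ONNX GraphProto.

    Nodes inside Loop/Scan subgraphs are not visited.
    """
    return Counter(node.op_type for node in graph.node)


def has_fused_rnn(graph):
    """Return True if any top-level node is a fused ONNX GRU/LSTM/RNN op."""
    return any(op in FUSED_RNN_OPS for op in op_histogram(graph))


def convert_and_report(keras_model, opset=None):
    """Convert with tf2onnx and print the op histogram (requires tensorflow and tf2onnx)."""
    import tf2onnx  # imported lazily so the helpers above need only the stdlib

    # from_keras returns (ModelProto, external_tensor_storage)
    onnx_model, _ = tf2onnx.convert.from_keras(keras_model, opset=opset)
    hist = op_histogram(onnx_model.graph)
    for op, count in hist.most_common():
        print(f"{op:20s} {count}")
    print("fused RNN op present:", has_fused_rnn(onnx_model.graph))
    return onnx_model
```

Calling `convert_and_report(model)` on the model above should make the difference visible: per this report, keras2onnx produced a `GRU` node, while tf2onnx produces Loop-based subgraphs instead.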

Note: this model trains and works fine with keras2onnx v1.7, which produces an .onnx model whose structure is similar to the Keras definition (with the dropouts removed). Output of model.summary():

    Model: "functional_1"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    input_1 (InputLayer)         [(32, 16, 192)]           0
    dense (Dense)                (32, 16, 192)             37056
    dropout (Dropout)            (32, 16, 192)             0
    gru (GRU)                    (32, 16, 32)              21696
    dropout_1 (Dropout)          (32, 16, 32)              0
    dense_1 (Dense)              (32, 16, 4)               132
    =================================================================
    Total params: 58,884
    Trainable params: 58,884
    Non-trainable params: 0
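The Param # column can be checked by hand, which helps when comparing the Keras GRU with the weights of an ONNX GRU node. The sketch below assumes the Keras defaults in TF2 (`reset_after=True`, so each of the three gates keeps separate input and recurrent biases). That layout, 3·h·in + 3·h·h + 6·h values, has the same size as the ONNX GRU inputs W [1, 3h, in], R [1, 3h, h] and B [1, 6h]; to our understanding it corresponds to `linear_before_reset=1`. The function names are illustrative only.

```python
def dense_params(n_in, units):
    # kernel (n_in x units) + bias (units)
    return n_in * units + units


def gru_params(n_in, units, reset_after=True):
    # 3 gates (update z, reset r, candidate h), each with an input kernel
    # (n_in x units) and a recurrent kernel (units x units).
    kernels = 3 * (n_in * units + units * units)
    # reset_after=True keeps separate input and recurrent biases (2 * 3 * units);
    # reset_after=False keeps a single bias vector of 3 * units.
    biases = (6 if reset_after else 3) * units
    return kernels + biases


features, dense_units, gru_units, outlen = 192, 192, 32, 4
layers = {
    "dense": dense_params(features, dense_units),
    "gru": gru_params(dense_units, gru_units),
    "dense_1": dense_params(gru_units, outlen),
}
total = sum(layers.values())
print(layers, total)
```

This reproduces the summary: 37056 + 21696 + 132 = 58884 parameters.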


guschmue commented 3 years ago

Yes, we are aware of it and will fix it.

TomWildenhain-Microsoft commented 3 years ago

#1688 will solve the issue for GRUCell, but unfortunately the pattern for a plain GRU layer is different.

aawce commented 2 years ago

Hi, what is the ETA for a fix?
Thanks