onnx / keras-onnx

Convert tf.keras/Keras models to ONNX
Apache License 2.0

Can't convert keras.AdditiveAttention Layer #733

Status: Open. Opened by edmBernard 2 years ago.

edmBernard commented 2 years ago

I am trying to convert a model that uses the Keras AdditiveAttention layer (https://keras.io/api/layers/attention_layers/additive_attention/) to ONNX with keras2onnx.

I get the following error:

WARN: No corresponding ONNX op matches the tf.op node additive_attention/cond of type If
      The generated ONNX model needs run with the custom op supports.
WARN: No corresponding ONNX op matches the tf.op node keras_learning_phase of type PlaceholderWithDefault
      The generated ONNX model needs run with the custom op supports.
WARN: No corresponding ONNX op matches the tf.op node additive_attention/ReadVariableOp/resource of type Placeholder
      The generated ONNX model needs run with the custom op supports.
Traceback (most recent call last):
  File ".\script_attention.py", line 38, in <module>
    onnx_model = keras2onnx.convert_keras(model, model.name)
  File "D:\test\venv\lib\site-packages\keras2onnx\main.py", line 99, in convert_keras
    parse_graph(topology, tf_graph, target_opset, output_names, output_dict)
  File "D:\test\venv\lib\site-packages\keras2onnx\parser.py", line 907, in parse_graph
    ) if is_tf2 and is_tf_keras else _parse_graph_core(
  File "D:\test\venv\lib\site-packages\keras2onnx\parser.py", line 784, in _parse_graph_core_v2
    _on_parsing_tf_nodes(graph, layer_info.nodelist, varset, topology.debug_mode)
  File "D:\test\venv\lib\site-packages\keras2onnx\parser.py", line 328, in _on_parsing_tf_nodes
    var_type = infer_variable_type(i_, varset.target_opset)
  File "D:\test\venv\lib\site-packages\keras2onnx\_parser_tf.py", line 44, in infer_variable_type
    "Unable to find out a correct type for tensor type = {} of {}".format(tensor_type, tensor.name))
ValueError: Unable to find out a correct type for tensor type = 20 of additive_attention/ReadVariableOp/resource:0
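The `tensor type = 20` in the message is a raw TensorFlow `DataType` enum code. The minimal lookup below, with a subset of values copied from TensorFlow's `types.proto`, shows that 20 is `DT_RESOURCE`: a resource-variable handle rather than a numeric tensor. In this model it is most likely the handle for AdditiveAttention's trainable `scale` weight (created because `use_scale=True` is the default), which would explain why `infer_variable_type` finds no ONNX element type for it. The table and helper name are illustrative only, not part of keras2onnx.

```python
# Subset of TensorFlow's DataType enum (tensorflow/core/framework/types.proto).
# Illustrative lookup table, not a keras2onnx or TensorFlow API.
TF_DATATYPE_NAMES = {
    1: "DT_FLOAT",
    2: "DT_DOUBLE",
    3: "DT_INT32",
    9: "DT_INT64",
    10: "DT_BOOL",
    19: "DT_HALF",
    20: "DT_RESOURCE",
    21: "DT_VARIANT",
}


def describe_tf_dtype(code: int) -> str:
    """Return the DataType enum name for a raw code, or a placeholder if not listed."""
    return TF_DATATYPE_NAMES.get(code, f"<unlisted code {code}>")


print(describe_tf_dtype(20))  # DT_RESOURCE
```

With TensorFlow installed, `tf.dtypes.as_dtype(20)` gives the same answer (`tf.resource`).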

The following script reproduces the issue:

import json

import numpy as np
import tensorflow as tf
from tensorflow import keras  # assumed: tf.keras, matching the tf_keras code path in the traceback
import keras2onnx


def sa_ak_network(x):
    # Two 1x1 convolutions produce the query (f) and value (g) tensors.
    f = keras.layers.Conv2D(filters=x.shape[-1], kernel_size=1, strides=1)(x)
    g = keras.layers.Conv2D(filters=x.shape[-1], kernel_size=1, strides=1)(x)
    o = keras.layers.AdditiveAttention()([f, g])
    o = keras.layers.Reshape(x.shape[1:])(o)
    gamma_o = keras.layers.Conv2D(filters=x.shape[-1], kernel_size=1, strides=1, use_bias=False)(o)
    # Residual connection back to the input.
    return gamma_o + x

input_tmp = keras.Input(shape=(64, 64, 32), dtype=tf.float32)
output = sa_ak_network(input_tmp)

model = keras.Model(inputs=input_tmp, outputs=output)
model_json = json.loads(model.to_json())
with open("additive_keras.json", "w") as f:
    json.dump(model_json, f, indent=2)

# The model runs in TensorFlow...
res = model(np.ones((1, 64, 64, 32), dtype=np.float32))

# ...but the conversion raises the ValueError above.
onnx_model = keras2onnx.convert_keras(model, model.name)
keras2onnx.save_model(onnx_model, "additive_keras.onnx")

I tested several combinations of TensorFlow and keras2onnx versions and got the same error each time: TensorFlow 2.2, 2.3 and 2.5 with keras2onnx 1.6.5, 1.7 and master.
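For reference, the Keras documentation linked above describes AdditiveAttention as: expand query and key to `[batch, Tq, 1, dim]` and `[batch, 1, Tv, dim]`, compute `scores = reduce_sum(tanh(query + key), axis=-1)` (multiplied elementwise by a learned `scale` when `use_scale=True`), apply softmax, then `matmul` with `value`. The NumPy sketch below reproduces that forward pass for 3-D inputs under the stated assumptions: no masks, no dropout, and `scale` passed in as a plain `[dim]` array standing in for the layer's trainable weight. The function names are hypothetical.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max for numerical stability before exponentiating.
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def additive_attention(query, value, key=None, scale=None):
    """Sketch of AdditiveAttention's forward pass (no masks, no dropout).

    query: [batch, Tq, dim]; value: [batch, Tv, dim]; key defaults to value.
    scale: optional [dim] array standing in for the layer's trainable weight.
    """
    if key is None:
        key = value
    q = query[:, :, np.newaxis, :]  # [batch, Tq, 1, dim]
    k = key[:, np.newaxis, :, :]    # [batch, 1, Tv, dim]
    t = np.tanh(q + k)              # [batch, Tq, Tv, dim]
    if scale is not None:
        t = scale * t
    scores = t.sum(axis=-1)         # [batch, Tq, Tv]
    weights = softmax(scores, axis=-1)
    return weights @ value          # [batch, Tq, dim]


rng = np.random.default_rng(0)
query = rng.standard_normal((2, 4, 8))
value = rng.standard_normal((2, 5, 8))
out = additive_attention(query, value, scale=np.ones(8))
print(out.shape)  # (2, 4, 8)
```

Each step maps to a standard ONNX operator (Add, Tanh, Mul, ReduceSum, Softmax, MatMul). That suggests the conversion failure comes from how the training-phase `cond` and the resource variable are represented in the TF graph, not from the attention math itself.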