MPolaris / onnx2tflite

Tool for onnx->keras or onnx->tflite. Hope this tool can help you.
Apache License 2.0

Failed to convert ONNX model #70

Open · huanyingjun opened this issue 2 months ago

huanyingjun commented 2 months ago

Hi, I'm trying to convert the model at https://huggingface.co/BAAI/bge-small-en-v1.5/tree/main/onnx and ran into the following error:

2024-07-05 16:21:59.085678: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Checking 0/1...
Traceback (most recent call last):
  File "/home/test/work/onnx2tflite/converter.py", line 140, in <module>
    run()
  File "/home/test/work/onnx2tflite/converter.py", line 123, in run
    onnx_converter(
  File "/home/test/work/onnx2tflite/converter.py", line 46, in onnx_converter
    keras_model = keras_builder(model_proto, native_groupconv)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/test/work/onnx2tflite/utils/builder.py", line 81, in keras_builder
    res = tf_operator(tf_tensor, onnx_weights, node_inputs, op_attr, outputs=node_outputs)(_inputs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/test/work/onnx2tflite/layers/deformation_layers.py", line 72, in __call__
    return tf.gather(inputs, self.indices, axis=self.axis)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/test/anaconda3/envs/onnx2tflite/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/test/anaconda3/envs/onnx2tflite/lib/python3.11/site-packages/keras/src/layers/core/tf_op_layer.py", line 119, in handle
    return TFOpLambda(op)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/test/anaconda3/envs/onnx2tflite/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
    ^^^^^^^^^^^^^^^
TypeError: Exception encountered when calling layer "tf.gather" (type TFOpLambda).

Value passed to parameter 'indices' has DataType float32 not in list of allowed values: int16, int32, int64

Call arguments received by layer "tf.gather" (type TFOpLambda):
  • params=tf.Tensor(shape=(30522, 384), dtype=float32)
  • indices=tf.Tensor(shape=(1, 512), dtype=float32)
  • validate_indices=None
  • axis=0
  • batch_dims=0
  • name=None

Could you please take a look? Thanks.
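The last lines of the traceback carry the actual failure: `tf.gather` accepts only `int16`, `int32`, or `int64` indices, but the token ids (shape `(1, 512)`) arrived as `float32`. The `params` tensor of shape `(30522, 384)` looks like the word-embedding table, so this Gather node is presumably the embedding lookup. The NumPy sketch below is an illustration, not onnx2tflite code: it mirrors `tf.gather(params, indices, axis=0)` with `np.take` and casts whole-number float ids to integers first. `embedding_lookup` is a hypothetical helper. In TensorFlow the analogous change would presumably be wrapping the indices in `tf.cast(..., tf.int32)`, but whether that alone lets the rest of this model convert is untested.

```python
import numpy as np


def embedding_lookup(table: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """Gather rows of `table` (vocab_size, hidden) for token `ids` of any shape.

    Mirrors tf.gather(params, indices, axis=0): the result has shape
    ids.shape + (hidden,). Gather only accepts integer indices, so float ids
    (as in the traceback) are cast after checking they hold whole numbers.
    """
    if not np.issubdtype(ids.dtype, np.integer):
        # Refuse values like 1.5 instead of silently truncating them.
        # float32 represents integers exactly only up to 2**24, which is
        # enough for a 30522-token vocabulary.
        if not np.all(np.mod(ids, 1) == 0):
            raise ValueError("token ids must be whole numbers")
        ids = ids.astype(np.int64)
    return np.take(table, ids, axis=0)
```

The output shape follows the gather rule `indices.shape + params.shape[1:]`, so the shapes in the traceback would produce a `(1, 512, 384)` result.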

MPolaris commented 2 months ago

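I took a look, and I don't think this conversion is very likely to succeed. TensorFlow Lite doesn't support dynamic shapes, so you could try fixing the input shapes first and then converting again. Because of channel-alignment issues, onnx2tflite's support for Transformer architectures is still fairly weak. There is a solution, but it takes a lot of engineering work, so I can't say when it will be done.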