microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

yolov3-tiny model float16 quantization (InvalidArgument: [ONNXRuntimeError] ) #12152

Open jiyoungAn opened 2 years ago

jiyoungAn commented 2 years ago

Hi everyone,

I quantized the yolov3-tiny model to float16 and ran it in onnxruntime:

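import numpy as np
import onnxruntime as ort

model_path = './tiny_yolov3_fp16.onnx'  # same path as in the error message

# Disable graph optimizations so the converted model is loaded as-is
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession(model_path, sess_options, providers=['CPUExecutionProvider'])

# Dummy float16 tensors for the model's two inputs
image = np.ones(shape=(1, 3, 416, 416), dtype=np.float16)
image_size = np.random.rand(1, 2).astype('float16')
input_name1 = sess.get_inputs()[0].name
input_name2 = sess.get_inputs()[1].name
output = sess.run(None, {input_name1: image, input_name2: image_size})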

But I get this error: InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from ./tiny_yolov3_fp16.onnx failed:This is an invalid model. Type Error: Type 'tensor(float16)' of input parameter (yolo_evaluation_layer_1/concat_6:0_btc) of operator (NonMaxSuppression) in node (yolonms_layer_1/non_max_suppression/NonMaxSuppressionV3) is invalid.

Is there a way to fix this NonMaxSuppression problem? I would really appreciate any ideas. Thank you for your time.

Model file: tiny_yolov3_fp16.zip

chenfucn commented 2 years ago

Currently, this operator only supports float32:

https://github.com/onnx/onnx/blob/main/docs/Operators.md#NonMaxSuppression

You are more than welcome to implement a float16 version.

tianleiwu commented 2 years ago

Another solution is to exclude the operator from the float16 conversion by passing op_block_list=['NonMaxSuppression'] to the following function: https://github.com/microsoft/onnxconverter-common/blob/0a401de9ee410bf3f65fb3dd3d13d4eab7e91a10/onnxconverter_common/float16.py#L91
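
For illustration, the suggestion above might look roughly like the sketch below. It assumes the original float32 model is still available on disk and that the onnx and onnxconverter-common packages are installed. The helper name convert_keeping_nms_fp32 and the file names in the usage note are hypothetical and not taken from this thread.

```python
def convert_keeping_nms_fp32(src_path, dst_path):
    """Convert a float32 ONNX model to float16, leaving NonMaxSuppression in float32.

    Hypothetical helper for illustration only; it is not part of any library.
    """
    # Imported inside the function so the helper can be defined
    # even before the packages are installed.
    import onnx
    from onnxconverter_common import float16

    model = onnx.load(src_path)

    # Operators named in op_block_list are skipped by the converter,
    # so NonMaxSuppression keeps the float32 inputs it requires.
    # Note: an explicit list is expected to replace the converter's
    # default block list rather than extend it (assumption; check
    # float16.py for the defaults in your installed version).
    model_fp16 = float16.convert_float_to_float16(
        model,
        op_block_list=["NonMaxSuppression"],
    )

    onnx.save(model_fp16, dst_path)
    return model_fp16
```

A call such as convert_keeping_nms_fp32("tiny_yolov3.onnx", "tiny_yolov3_fp16.onnx") would then write the converted model, which can be loaded with ort.InferenceSession as in the original post. Whether the resulting model runs correctly end to end has not been verified here.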