microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Feature Request] support FP8 calibration method and quantization #21090

Open · zccyman opened this issue 3 months ago

zccyman commented 3 months ago

Description:

I set the weight and activation types to QuantType.QFLOAT8E4M3FN when calling quantize_static, but I get the following error:

Traceback (most recent call last):
  File "/home/developer/workspace/code/onnxruntime/examples/only_one_conv/run.py", line 102, in <module>
    main()
  File "/home/developer/workspace/code/onnxruntime/examples/only_one_conv/run.py", line 71, in main
    quantize_static(
  File "/root/miniconda3/envs/ort/lib/python3.10/site-packages/onnxruntime/quantization/quantize.py", line 441, in quantize_static
    raise ValueError("Only Distribution calibration method is supported for float quantization.")
ValueError: Only Distribution calibration method is supported for float quantization.

So I think FP8 calibration and quantization are not supported yet?
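
For reference, a minimal sketch of a call that reproduces this. The model paths, input name/shape, and the RandomDataReader below are placeholders, not the actual script from the traceback:

import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static

# Placeholder calibration reader: feeds a few random batches.
class RandomDataReader(CalibrationDataReader):
    def __init__(self, input_name, shape, num_batches=8):
        self.batches = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(num_batches)]
        )

    def get_next(self):
        # Return None when exhausted, per the CalibrationDataReader contract.
        return next(self.batches, None)

# Raises ValueError: the default calibration method (MinMax) is rejected
# when the activation/weight types are float8.
quantize_static(
    "model.onnx",
    "model_quant.onnx",
    RandomDataReader("input", (1, 3, 224, 224)),
    activation_type=QuantType.QFLOAT8E4M3FN,
    weight_type=QuantType.QFLOAT8E4M3FN,
)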

yufenglee commented 3 months ago

@zccyman, you can set the calibration method to CalibrationMethod.Distribution:

https://github.com/microsoft/onnxruntime/blob/01279d889671363008e91fd76917864b36b51cdd/onnxruntime/python/tools/quantization/quantize.py#L440C32-L440C62
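
For example, something like this (reusing the placeholder reader sketched above):

from onnxruntime.quantization import CalibrationMethod, QuantType, quantize_static

# Passing calibrate_method explicitly satisfies the check in quantize.py:
# Distribution is the only calibration method accepted for float quantization.
quantize_static(
    "model.onnx",
    "model_quant.onnx",
    RandomDataReader("input", (1, 3, 224, 224)),
    activation_type=QuantType.QFLOAT8E4M3FN,
    weight_type=QuantType.QFLOAT8E4M3FN,
    calibrate_method=CalibrationMethod.Distribution,
)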

zccyman commented 3 months ago

> @zccyman, you can set the calibration method to CalibrationMethod.Distribution:
>
> https://github.com/microsoft/onnxruntime/blob/01279d889671363008e91fd76917864b36b51cdd/onnxruntime/python/tools/quantization/quantize.py#L440C32-L440C62

That's right, it works. But I think I've found a bug with MaxPool during FP8 quantization. Here is the error I get:

benchmarking quant model...
Traceback (most recent call last):
  File "/home/developer/workspace/code/onnxruntime/examples/only_one_conv/run.py", line 110, in <module>
    main()
  File "/home/developer/workspace/code/onnxruntime/examples/only_one_conv/run.py", line 106, in main
    benchmark(output_model_path)
  File "/home/developer/workspace/code/onnxruntime/examples/only_one_conv/run.py", line 15, in benchmark
    session = onnxruntime.InferenceSession(model_path)
  File "/root/miniconda3/envs/ort/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/root/miniconda3/envs/ort/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 483, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(float8e4m3fn)' of input parameter (input_QuantizeLinear_Output) of operator (MaxPool) in node (/MaxPool) is invalid.
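
Until MaxPool supports float8 inputs, one possible workaround is to keep that node in float by excluding it from quantization. A sketch, assuming the quantize_static call from above; nodes_to_exclude is an existing quantize_static parameter, and the node name "/MaxPool" is taken from the traceback:

from onnxruntime.quantization import CalibrationMethod, QuantType, quantize_static

# Exclude the offending node so no float8e4m3fn tensor reaches MaxPool;
# the excluded node keeps float inputs and outputs.
quantize_static(
    "model.onnx",
    "model_quant.onnx",
    RandomDataReader("input", (1, 3, 224, 224)),
    activation_type=QuantType.QFLOAT8E4M3FN,
    weight_type=QuantType.QFLOAT8E4M3FN,
    calibrate_method=CalibrationMethod.Distribution,
    nodes_to_exclude=["/MaxPool"],
)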