microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Training] On device training doesn't work with INT8 Models #19078

Open IzanCatalan opened 9 months ago

IzanCatalan commented 9 months ago

Describe the issue

I am retraining some ONNX models from the ONNX Model Zoo repo, in particular the quantized ResNet50 with the INT8 data type. However, when I generate the training artifacts as shown in the onnxruntime-training-examples repo, I get the following error:

Traceback (most recent call last):
  File "prepare_for_training.py", line 38, in <module>
    artifacts.generate_artifacts(
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/artifacts.py", line 152, in generate_artifacts
    _ = training_block(*[output.name for output in model.graph.output])
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/onnxblock/onnxblock.py", line 204, in __call__
    self._training_model, self._eval_model = _training_graph_utils.build_gradient_graph(
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/onnxblock/_training_graph_utils.py", line 127, in build_gradient_graph
    optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))
RuntimeError: /home/onnxruntime/orttraining/orttraining/core/optimizer/qdq_fusion.cc:104 std::pair<bool, onnxruntime::Node*> onnxruntime::{anonymous}::CheckForQDQPatternMatch(onnxruntime::Graph&, onnxruntime::Node&, const onnxruntime::InlinedHashSet<std::basic_string_view<char> >&) graph_utils::IsSupportedOptypeVersionAndDomain(*dequantize_node_ptr, "DequantizeLinear", {10, 13}) && graph_utils::IsSupportedProvider(*dequantize_node_ptr, compatible_execution_providers) was false. Expected that every QuantizeLinear node be followed by a unique DequantizeLinear node. Actual: QuantizeLinear (data_QuantizeLinear) is followed by QLinearConv(fused resnetv17_conv0_fwd_quant).
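The message says the training optimizer expects every `QuantizeLinear` node to feed a unique `DequantizeLinear` node (QDQ format), while this model feeds `QuantizeLinear` straight into fused operators such as `QLinearConv` (QOperator format). The following is a minimal sketch for listing the offending nodes. It assumes only that each node exposes `name`, `op_type`, `input`, and `output`, as `onnx.NodeProto` does. `Node` and `find_unpaired_quantize_nodes` are hypothetical helpers that mirror the wording of the error; they are not ONNX Runtime's actual check.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Sequence, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for onnx.NodeProto with only the fields read below."""
    name: str
    op_type: str
    input: List[str] = field(default_factory=list)
    output: List[str] = field(default_factory=list)


def find_unpaired_quantize_nodes(nodes: Sequence[Node]) -> List[Tuple[str, List[str]]]:
    """Return (name, consumer op types) for each QuantizeLinear node whose output
    is not read by exactly one DequantizeLinear node."""
    # Index which nodes read each tensor.
    consumers: Dict[str, List[Node]] = {}
    for node in nodes:
        for tensor in node.input:
            consumers.setdefault(tensor, []).append(node)

    unpaired = []
    for node in nodes:
        if node.op_type != "QuantizeLinear":
            continue
        reader_ops = [reader.op_type for reader in consumers.get(node.output[0], [])]
        if reader_ops != ["DequantizeLinear"]:
            unpaired.append((node.name, reader_ops))
    return unpaired


if __name__ == "__main__":
    # Node names taken from the error message; tensor names are made up.
    graph = [
        Node("data_QuantizeLinear", "QuantizeLinear", ["data", "data_scale", "data_zp"], ["data_q"]),
        Node("resnetv17_conv0_fwd_quant", "QLinearConv", ["data_q", "data_scale", "data_zp"], ["conv0_q"]),
    ]
    print(find_unpaired_quantize_nodes(graph))  # [('data_QuantizeLinear', ['QLinearConv'])]
```

With a real model, the same function could be called as `find_unpaired_quantize_nodes(onnx_model.graph.node)`. An empty result means every quantize node is followed by exactly one dequantize node.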

How can I solve this? Is there any way to retrain a model or do transfer learning with ORT?

For reference, my code looks like this:

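import sys

import onnx
from onnxruntime.training import artifacts

# Assumption: the input model path is the first command-line argument
# and the artifact output directory is the second.
onnx_model = onnx.load(sys.argv[1])

# Freeze the BatchNorm running statistics; mark every other initializer as trainable.
frozen_params = []
requires_grad = []
for init in onnx_model.graph.initializer:
    if init.name.endswith("running_mean") or init.name.endswith("running_var"):
        frozen_params.append(init.name)
    elif init.name not in frozen_params:
        requires_grad.append(init.name)

print(len(requires_grad), len(frozen_params))
print(frozen_params)

# Generate the training artifacts.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory=sys.argv[2],
)

# Also expose the original model output (not just the loss) in the eval graph.
eval_model = onnx.load(f"{sys.argv[2]}/eval_model.onnx")
eval_model.graph.output.append(onnx_model.graph.output[0])
onnx.save(eval_model, f"{sys.argv[2]}/eval_model2.onnx")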
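In an INT8 model, most of the initializers that the loop above places in `requires_grad` (quantized weights, INT32 biases, zero points) are integer tensors, and gradients cannot be computed for them. The following is a minimal sketch of a stricter split. It assumes the numeric codes of ONNX's `TensorProto.DataType` enum (FLOAT = 1, FLOAT16 = 10, DOUBLE = 11, BFLOAT16 = 16). `Init` and `split_initializers` are hypothetical names. This split would not by itself be expected to avoid the QDQ fusion error, and float quantization scales would still be marked trainable.

```python
from dataclasses import dataclass
from typing import Iterable, List, Tuple

# Floating-point codes from onnx.TensorProto.DataType.
FLOAT_TYPES = {1, 10, 11, 16}  # FLOAT, FLOAT16, DOUBLE, BFLOAT16


@dataclass
class Init:
    """Hypothetical stand-in for onnx.TensorProto: only name and data_type are read."""
    name: str
    data_type: int


def split_initializers(inits: Iterable[Init]) -> Tuple[List[str], List[str]]:
    """Return (requires_grad, frozen_params) name lists.

    BatchNorm running statistics and every non-float tensor are frozen;
    the remaining float initializers are marked trainable."""
    requires_grad, frozen = [], []
    for init in inits:
        is_bn_stat = init.name.endswith(("running_mean", "running_var"))
        if is_bn_stat or init.data_type not in FLOAT_TYPES:
            frozen.append(init.name)
        else:
            requires_grad.append(init.name)
    return requires_grad, frozen


if __name__ == "__main__":
    # Hypothetical initializers from one quantized conv layer.
    inits = [
        Init("conv0_weight_quantized", 3),  # INT8
        Init("conv0_weight_scale", 1),  # FLOAT
        Init("bn0_running_mean", 1),  # FLOAT, but a BatchNorm statistic
        Init("conv0_bias_quantized", 6),  # INT32
    ]
    print(split_initializers(inits))
```

With a real model this could be called as `split_initializers(onnx_model.graph.initializer)`, since `onnx.TensorProto` exposes both `name` and `data_type`.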

To reproduce

I am running ONNX Runtime built from source with CUDA 11.2, GCC 9.5, CMake 3.27, and Python 3.8 on Ubuntu 20.04.

Urgency

As soon as possible

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

onnxruntime-training 1.17.0+cu112

PyTorch Version

None

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.2

baijumeswani commented 9 months ago

Training will not work with a quantized model. How would you expect training to work with an INT8 model? Backpropagation can only happen with floats.
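The following small numeric example shows why integer values block backpropagation. INT8 quantization computes q = clamp(round(x / scale) + zero_point, -128, 127). Because `round` is piecewise constant, its derivative is zero almost everywhere. The sketch assumes scale = 0.1 and zero_point = 0 and estimates derivatives with central differences. `quantize`, `dequantize`, and `numerical_grad` are illustrative helpers, not ONNX Runtime APIs.

```python
from typing import Callable


def quantize(x: float, scale: float = 0.1, zero_point: int = 0) -> int:
    """Affine INT8 quantization: round (ties to even), shift, saturate to [-128, 127]."""
    q = round(x / scale) + zero_point
    return max(-128, min(127, q))


def dequantize(q: int, scale: float = 0.1, zero_point: int = 0) -> float:
    """Map an INT8 value back to float."""
    return (q - zero_point) * scale


def numerical_grad(f: Callable[[float], float], x: float, eps: float = 1e-4) -> float:
    """Central-difference estimate of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


if __name__ == "__main__":
    x = 0.73  # not on a rounding boundary
    # Quantize -> dequantize round trip: both x +/- eps map to q = 7, so the slope is 0.
    print(numerical_grad(lambda v: dequantize(quantize(v)), x))  # 0.0
    # Plain float identity for comparison: slope is approximately 1.
    print(numerical_grad(lambda v: v, x))
```

A zero gradient means an optimizer step computed through the integer path would never change the weight. This is why training schemes for quantized models keep float weights around.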

The error you're hitting is the result of ONNX Runtime trying to convert your graph into a quantization-aware training (QAT) graph. QAT in ONNX Runtime is still experimental, and we do not fully support it yet.

IzanCatalan commented 9 months ago

Thanks for the reply, @baijumeswani. Yes, you are totally right: backpropagation cannot be done. I was just hoping there might be a way to retrain a model with ORT using, as you said, QAT or post-training quantization. Will the support you mentioned be available soon, or is it a long-term plan?

Anyway, if I need to retrain some models in INT8, which, as you said, is currently impossible with ORT, do you have any thoughts on how I could do it (using QAT, for instance), even with a different framework or AI engine? Any help clarifying this would be highly appreciated.

Thank you.

baijumeswani commented 9 months ago

Yes, we will add some form of support for training a (fake) quantized model in the near to mid term, and maybe you can benefit from that. It is expected to be released in ONNX Runtime 1.18. I will keep you posted.
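Training a fake-quantized model commonly means three things: keep float "shadow" weights, run the forward pass through a quantize-then-dequantize (fake-quant) step, and pass gradients straight through the rounding (the straight-through estimator, STE). The following is a minimal pure-Python sketch of that general idea on a one-weight linear model. It assumes scale = 0.05, zero_point = 0, a learning rate of 0.1, and made-up data. It does not reflect the API ONNX Runtime plans for 1.18. `fake_quant` and `train_qat` are hypothetical names.

```python
from typing import Sequence


def fake_quant(w: float, scale: float = 0.05, zero_point: int = 0) -> float:
    """Quantize to INT8 and immediately dequantize, staying in float."""
    q = max(-128, min(127, round(w / scale) + zero_point))
    return (q - zero_point) * scale


def train_qat(xs: Sequence[float], ys: Sequence[float], w: float = 0.0,
              lr: float = 0.1, steps: int = 200) -> float:
    """Fit y = w * x by MSE with a fake-quantized weight, using the STE."""
    n = len(xs)
    for _ in range(steps):
        wq = fake_quant(w)  # forward pass sees the quantized weight
        # d(MSE)/d(wq). The STE treats d(wq)/d(w) as 1, so this gradient
        # is applied directly to the float shadow weight w.
        grad = sum(2 * (wq * x - y) * x for x, y in zip(xs, ys)) / n
        w -= lr * grad
    return w


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0]
    ys = [1.5 * x for x in xs]  # hypothetical data with true weight 1.5
    w = train_qat(xs, ys)
    print(w, fake_quant(w))  # float shadow weight, and the INT8-representable weight used at inference
```

The float weight ends up slightly off 1.5 while its fake-quantized value lands on the representable 1.5. Only the quantized value would be exported for INT8 inference. Without the STE, the gradient through `round` would be zero, and `w` would never move.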

I am not aware of any framework that offers training of quantized models on the device. Sorry about that.