microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Training] Whether to support weight per_channel QAT #19241

Open hbwx24 opened 10 months ago

hbwx24 commented 10 months ago

Describe the issue

The model weights are quantized per channel. When running `TrainStep`, the following error is raised:

```
onnxruntime/orttraining/orttraining/training_api/module.cc:538 onnxruntime::common::Status onnxruntime::training::api::Module::TrainStep(const std::vector&, std::vector&) [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running FakeQuant node. Name:'FakeQuant_token_260' Status Message: /home/xin.wei/workdir/quant/onnx2torch/onnxruntime-source/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc:68 onnxruntime::common::Status onnxruntime::contrib::FakeQuant::Compute(onnxruntime::OpKernelContext*) const [with T = float] IsScalarOr1ElementVector(scale) was false. Quantization scale must be a scalar or 1D tensor of size 1.
```


### To reproduce

The model weights are quantized per channel:
 - weight_scale.shape = [64]
 - zero_point.shape = [64]

When using onnxruntime-training to run QAT, the error below is reported. Does onnxruntime-training support per-channel QAT?
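For context, per-channel fake quantization gives each output channel its own scale, which is exactly the 1-D scale tensor the FakeQuant kernel rejects. A minimal numpy sketch of the operation QAT needs to simulate (function name and the int8 range here are my assumptions for illustration, not onnxruntime's implementation):

```python
import numpy as np

# Per-channel fake quantization of a conv weight, as QAT simulates it.
# scale / zero_point have one entry per output channel (here 64), so they
# broadcast along axis 0 -- this is the shape that trips the FakeQuant
# kernel, which only accepts a scalar (per-tensor) scale.
def fake_quant_per_channel(w, scale, zero_point, qmin=-128, qmax=127):
    s = scale.reshape(-1, 1, 1, 1)      # [64] -> [64,1,1,1] for broadcasting
    zp = zero_point.reshape(-1, 1, 1, 1)
    q = np.clip(np.round(w / s) + zp, qmin, qmax)
    return (q - zp) * s                 # dequantize back to float

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 32, 3, 3)).astype(np.float32)
scale = np.abs(w).max(axis=(1, 2, 3)) / 127.0   # one scale per output channel
zero_point = np.zeros(64, dtype=np.int32)

w_fq = fake_quant_per_channel(w, scale, zero_point)
print(w_fq.shape)  # (64, 32, 3, 3)
```

The quantization error stays bounded by half of each channel's own scale, which is why per-channel QAT is preferred for conv weights with uneven channel magnitudes.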

- onnxruntime 1.16.3
- onnxruntime-extensions 0.9.0
- onnxruntime-gpu 1.16.3
- onnxruntime-training 1.16.3

```
orttraining/orttraining/training_api/module.cc:538 onnxruntime::common::Status onnxruntime::training::api::Module::TrainStep(const std::vector&, std::vector&) [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running FakeQuant node. Name:'FakeQuant_token_260' Status Message: /home/xin.wei/workdir/quant/onnx2torch/onnxruntime-source/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc:68 onnxruntime::common::Status onnxruntime::contrib::FakeQuant::Compute(onnxruntime::OpKernelContext*) const [with T = float] IsScalarOr1ElementVector(scale) was false. Quantization scale must be a scalar or 1D tensor of size 1.
```



### Urgency

_No response_

### ONNX Runtime Installation

Built from Source

### ONNX Runtime Version or Commit ID

1.16.3

### PyTorch Version

1.10

### Execution Provider

Default CPU, CUDA

### Execution Provider Library Version

_No response_
baijumeswani commented 10 months ago

@hbwx24 Thanks for trying out QAT. QAT with ONNX Runtime is in an experimental stage at this time.

Looking through my own TODOs in the repository, it seems that per-channel QAT is not supported yet.

I don't know if I can commit to having this feature completed soon, but I will try to address it before the next ONNX Runtime release (1.18).

hbwx24 commented 10 months ago

> @hbwx24 Thanks for trying out QAT. QAT with ONNX Runtime is in an experimental stage at this time.
>
> Looking through my own TODOs in the repository, it seems that per-channel QAT is not supported yet.
>
> I don't know if I can commit to having this feature completed soon, but I will try to address it before the next ONNX Runtime release (1.18).

Thank you very much

yzg216 commented 10 months ago

I have also run into this problem and need this feature urgently. If I try to implement it myself, could you point me to what needs to change? @baijumeswani
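Not an official fix, but one interim workaround worth sketching (with an obvious accuracy cost) is collapsing each per-channel scale into a single per-tensor scale before export, so that FakeQuant's scalar-only check passes. A numpy comparison of the extra error this introduces on a low-magnitude channel (purely illustrative; the shapes and int8 range are assumptions):

```python
import numpy as np

# Interim workaround sketch (an assumption, not an official fix): collapse
# the per-channel scale to one per-tensor scale so FakeQuant's scalar-only
# check passes. The cost: channels with a small dynamic range quantize much
# more coarsely, as this comparison shows.
rng = np.random.default_rng(1)
w = rng.standard_normal((64, 32, 3, 3)).astype(np.float32)
w[0] *= 0.01                                  # one low-magnitude channel

per_channel = np.abs(w).max(axis=(1, 2, 3)) / 127.0   # shape [64]
per_tensor = np.abs(w).max() / 127.0                  # single scalar

def fq(w, s):                                 # symmetric fake-quant, zp = 0
    return np.clip(np.round(w / s), -127, 127) * s

err_pc = np.abs(w[0] - fq(w[0], per_channel[0])).max()
err_pt = np.abs(w[0] - fq(w[0], per_tensor)).max()
print(f"channel-0 error: per-channel {err_pc:.6f}, per-tensor {err_pt:.6f}")
```

The proper fix would instead be to extend the FakeQuant training kernel (`fake_quant.cc`) to broadcast a 1-D scale over the channel axis, which is what the maintainer's comment above refers to.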

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.