onnx / tensorflow-onnx

Convert TensorFlow, Keras, Tensorflow.js and Tflite models to ONNX
Apache License 2.0

QDQ node for weight tensor of Conv2D undergoes constant folding (enabled for node using tf type=FakeQuantWithMinMaxVarsPerChannel) #1972

Open rado82 opened 2 years ago

rado82 commented 2 years ago

I am experimenting with quantization-aware training (QAT) on a sample model. It looks like the QDQ node for the weight tensor of the Conv operation is always folded away during ONNX generation.

Package versions: TensorFlow 2.8.2, tf2onnx 1.11.1, TensorFlow Model Optimization Toolkit 0.7.2.

I am using the TF Model Optimization Toolkit to insert fake-quantization nodes and tf2onnx to convert the frozen graph from .pb to ONNX. The Conv2D weight tensor is always constant-folded during tf2onnx conversion, even though the visualization of the frozen graph clearly shows a fake-quant node inserted for the weights.
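To make "folded" concrete, here is a minimal plain-Python sketch of the arithmetic a per-channel fake-quant node applies to weights. It is based on our reading of the min/max "nudging" in TensorFlow's fake-quant kernels, not on TF's own code, and the helper names (`nudge`, `fake_quant`, `fake_quant_per_channel`) are hypothetical. For simplicity each inner list stands for one output channel, whereas the real op works along the last dimension of the weight tensor. When tf2onnx constant-folds the node, the Conv weight initializer ends up holding (roughly) the float output of this computation, so the min/max ranges no longer appear as a QDQ pair.

```python
import math
from typing import List, Sequence, Tuple


def nudge(min_val: float, max_val: float, num_bits: int = 8,
          narrow_range: bool = False) -> Tuple[float, float, float]:
    """Shift [min_val, max_val] so that 0.0 maps exactly onto an integer code.

    Returns (nudged_min, nudged_max, scale).
    """
    if max_val <= min_val:
        raise ValueError("max_val must be greater than min_val")
    quant_min = 1 if narrow_range else 0      # narrow_range drops the lowest code
    quant_max = (1 << num_bits) - 1           # e.g. 255 for 8 bits
    scale = (max_val - min_val) / (quant_max - quant_min)
    zero_point_from_min = quant_min - min_val / scale
    if zero_point_from_min < quant_min:
        nudged_zero_point = quant_min
    elif zero_point_from_min > quant_max:
        nudged_zero_point = quant_max
    else:
        # Non-negative here, so floor(x + 0.5) rounds halves away from zero.
        nudged_zero_point = math.floor(zero_point_from_min + 0.5)
    nudged_min = (quant_min - nudged_zero_point) * scale
    nudged_max = (quant_max - nudged_zero_point) * scale
    return nudged_min, nudged_max, scale


def fake_quant(x: float, min_val: float, max_val: float, num_bits: int = 8,
               narrow_range: bool = False) -> float:
    """Quantize then dequantize one value; the result is still a float."""
    nudged_min, nudged_max, scale = nudge(min_val, max_val, num_bits, narrow_range)
    clamped = min(max(x, nudged_min), nudged_max)
    steps = math.floor((clamped - nudged_min) / scale + 0.5)
    return steps * scale + nudged_min


def fake_quant_per_channel(channels: Sequence[Sequence[float]],
                           mins: Sequence[float], maxs: Sequence[float],
                           num_bits: int = 8,
                           narrow_range: bool = False) -> List[List[float]]:
    """Apply fake_quant with a separate (min, max) pair for each output channel."""
    if not len(channels) == len(mins) == len(maxs):
        raise ValueError("need one (min, max) pair per channel")
    return [[fake_quant(v, lo, hi, num_bits, narrow_range) for v in channel]
            for channel, lo, hi in zip(channels, mins, maxs)]


if __name__ == "__main__":
    # Hypothetical weights for two output channels and their learned ranges.
    weights = [[0.4, 300.0, -1.3], [-300.0, 0.6, 12.25]]
    folded = fake_quant_per_channel(weights, mins=[-128.0, -64.0], maxs=[127.0, 63.5])
    # Plain floats like these are roughly what a folded Conv weight initializer
    # holds; the per-channel min/max values are gone from the graph.
    print(folded)
```

The point of the sketch is that folding is lossless for the forward pass (the quantized-then-dequantized floats are preserved) but loses the quantization parameters that an ONNX QDQ pair would otherwise carry.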

To reproduce, see this Colab notebook: https://colab.research.google.com/drive/1Y_LhhWtJejv5teHgQslMPQdwebyHY1GD?usp=sharing

Netron visualization of the .pb file (fed as input to tf2onnx): [image]

Netron visualization of the generated ONNX model: [image]

Checking previous issues here, I found this. However, that earlier issue uses tf.quantize_and_dequantize_v2, whereas I am using the TF Model Optimization Toolkit, which relies on other TF quantization APIs.

hwangdeyu commented 2 years ago

Hi @rado82, thanks for the issue. It seems you want the Conv weights not to be constant-folded when the node type is FakeQuantWithMinMaxVarsPerChannel. It would be very helpful if you could provide the TF code that produced the .pb model file.

mbrookhart commented 2 years ago

The TF code is in the Colab notebook linked in the original post, isn't it?

I just found this issue; I'm seeing similar behavior with a QAT model saved as a GraphDef. With the latest tf2onnx I get many "INFO - folding node using tf type=FakeQuantWithMinMaxVars" messages for my model. If I edit the script to disable TF constant-node folding here and here, I get this error instead:

ValueError: make_sure failure: Unable to convert node FakeQuantWithMinMaxArgs with narrow_range=1

I'll try to see if I can reduce this to a unit test. Thanks!
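As background for the narrow_range error above, here is a hedged plain-Python sketch of an ONNX QuantizeLinear → DequantizeLinear pair for uint8. Per the ONNX operator spec, QuantizeLinear computes round(x / scale) + zero_point with round-half-to-even and saturates to the full type range ([0, 255] for uint8), and DequantizeLinear computes (q - zero_point) * scale. A narrow-range fake quant restricts codes to [1, 255] instead, which a plain Q/DQ pair does not express. This only illustrates that difference; it does not claim to reproduce tf2onnx's converter or explain exactly why it raises the error. The helper names and the example scale/zero point are hypothetical.

```python
UINT8_MIN, UINT8_MAX = 0, 255


def quantize_linear(x: float, scale: float, zero_point: int,
                    qmin: int = UINT8_MIN, qmax: int = UINT8_MAX) -> int:
    """ONNX-style QuantizeLinear for one value: round half to even, then saturate."""
    # Python's round() on a float rounds half to even, matching the ONNX spec.
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize_linear(q: int, scale: float, zero_point: int) -> float:
    """ONNX-style DequantizeLinear for one value."""
    return (q - zero_point) * scale


def qdq(x: float, scale: float, zero_point: int,
        qmin: int = UINT8_MIN, qmax: int = UINT8_MAX) -> float:
    """Quantize then dequantize, optionally with a restricted code range."""
    q = quantize_linear(x, scale, zero_point, qmin, qmax)
    return dequantize_linear(q, scale, zero_point)


if __name__ == "__main__":
    # Hypothetical parameters: scale 1.0, zero point 128, which is what an 8-bit
    # narrow-range fake quant with min=-127, max=127 nudges to (by our reading).
    scale, zp = 1.0, 128
    for x in (0.6, 126.7, -300.0):
        plain = qdq(x, scale, zp)            # full uint8 range [0, 255]
        narrow = qdq(x, scale, zp, qmin=1)   # codes restricted to [1, 255]
        print(f"x={x:>7}: plain QDQ={plain:>6}  narrow-range={narrow:>6}")
```

In-range inputs agree, but the last input saturates to code 0 in the plain pair (giving -128.0) while the narrow range stops at -127.0. Exact ties can also differ, since TF's fake quant rounds halves up while ONNX rounds them to even.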