lix19937 opened 4 months ago
1. Why does QAT need to calibrate the model? Can I skip calibration in https://github.com/NVIDIA/TensorRT/blob/main/tools/pytorch-quantization/examples/torchvision/classification_flow.py by setting num_calib_batch=0?
The solution procedure recommended by the TensorRT QAT tools (ref: http://arxiv.org/abs/2004.09602, Figure 5; see attachment qatworkflow.png) is:
1) Run calibration first, which obtains the max dynamic range of each tensor.
2) Fix that max (dynamic range) for each tensor and start training to fine-tune the network.
So the purpose of "calibration before fine-tune" is to get the amax of each tensor before training.
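As a concrete illustration, the calibration step before fine-tuning typically looks like this (a minimal sketch adapted from the quant_resnet50 tutorial; collect_stats and compute_amax are helper names from that tutorial, not library API):

```python
import torch
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn

def collect_stats(model, data_loader, num_batches):
    """Feed calibration data through the model and record tensor statistics."""
    # Put every quantizer into calibration mode: collect ranges, don't quantize yet.
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    with torch.no_grad():
        for i, (image, _) in enumerate(data_loader):
            model(image.cuda())
            if i + 1 >= num_batches:
                break

    # Switch back: quantize from now on, stop collecting statistics.
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

def compute_amax(model, **kwargs):
    """Load the collected amax into each quantizer; fine-tuning starts from these values."""
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
```

Skipping calibration (num_calib_batch=0) would leave the quantizers without calibrated amax values, so fine-tuning would not start from sensible quantization ranges.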
2. In https://github.com/NVIDIA/TensorRT/blob/main/tools/pytorch-quantization/examples/torchvision/classification_flow.py lines 152~154, why is the quantize parameter not specified? Why not set quantize=True?
3. What is the purpose of quant_modules.initialize()?
4. What is the purpose of quant_modules.deactivate()?
ref: https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/tutorials/quant_resnet50.html#adding-quantized-modules and TensorRT/quant_modules.py at release/8.2 · NVIDIA/TensorRT (github.com)
1) The purpose of quant_modules.initialize() is to enable automatic layer substitution via monkey-patching: every layer that can be quantized is replaced with its corresponding quant layer.
2) The purpose of quant_modules.deactivate() is to undo quant_modules.initialize(). You can observe the effect on the model using the sample code below:
```python
import torchvision
from pytorch_quantization import quant_modules

model_name = "resnet50"  # any torchvision classification model name

model_fp32 = torchvision.models.__dict__[model_name]()        # built from plain nn modules
quant_modules.initialize()                                    # monkey-patch nn.Conv2d -> quant_nn.QuantConv2d, etc.
model_fake_quant = torchvision.models.__dict__[model_name]()  # built from the substituted quant modules
quant_modules.deactivate()                                    # restore the original nn modules
model_fp32_2 = torchvision.models.__dict__[model_name]()      # plain nn modules again

print(model_fp32)
print(model_fp32_2)
print(model_fake_quant)
```

3) The quantization parameters (for example per-tensor vs. per-channel quantization) are specified by code like the following (static defaults on the quant module classes):
```python
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

if per_channel_quantization:
    quant_desc_input = QuantDescriptor(calib_method=calibrator)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
else:
    # axis=None forces per-tensor quantization for inputs and weights
    quant_desc_input = QuantDescriptor(calib_method=calibrator, axis=None)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    quant_desc_weight = QuantDescriptor(calib_method=calibrator, axis=None)
    quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_weight(quant_desc_weight)
    quant_nn.QuantLinear.set_default_quant_desc_weight(quant_desc_weight)
```
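For contrast, the per-channel branch relies on the toolkit's default conv weight descriptor, which uses axis=(0) (one scale per output channel; see QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL in tensor_quant.py). Spelled out explicitly, it would look like this sketch:

```python
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

# axis=(0) gives one amax/scale per output channel, which is what
# TensorRT uses for INT8 convolution weights.
quant_desc_weight = QuantDescriptor(num_bits=8, axis=(0))
quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)
```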
Why is self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) in https://github.com/NVIDIA/TensorRT/blob/main/tools/pytorch-quantization/examples/torchvision/models/classification/resnet.py#L237 not quantized?
Please refer to: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks
MaxPooling commutes with quantization, i.e. Q(maxpooling(x)) == maxpooling(Q(x)). So when Q is propagated backward and DQ is propagated forward, the Q/DQ pair can move across the MaxPooling layer, and when the ONNX model is imported into TensorRT, TensorRT ensures the MaxPooling is also quantized and runs in INT8 mode.
For more documentation, please refer to: Quantizing Resnet50 — pytorch-quantization master documentation (nvidia.com)
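A quick numerical check of the commutation property (a small sketch using the toolkit's fake_tensor_quant; the equality holds because fake quantization is elementwise and monotonic, so it commutes with max):

```python
import torch
from pytorch_quantization import tensor_quant

x = torch.randn(1, 3, 8, 8)
pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

amax = x.abs().max()  # share one quantization range on both paths
q_then_pool = pool(tensor_quant.fake_tensor_quant(x, amax))
pool_then_q = tensor_quant.fake_tensor_quant(pool(x), amax)

print(torch.equal(q_then_pool, pool_then_q))  # True: Q(maxpool(x)) == maxpool(Q(x))
```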
Hi haoran, thanks for your reply.