OpenPPL / ppq

PPL Quantization Tool (PPQ) is a powerful offline neural network quantization tool.
Apache License 2.0
1.58k stars 236 forks source link

关于量化配置的几点疑问 #460

Closed SunCrazy closed 1 year ago

SunCrazy commented 1 year ago
  1. 是只支持int8/fp8量化吗?是否支持uint8,int16?
  2. 如果weight和activation都想进行非对称per-tensor量化,要怎么配置?
  3. 是否支持更低bit的量化?例如2/3/4等

如果可以的话,能否麻烦提供一下示例,谢谢!

ZhangZhiPku commented 1 year ago

PPQ有着极其强大的量化配置功能,我相信这是市面上你能找到的最灵活的量化框架,但它也是高度复杂的。

  1. PPQ 支持int1-int16任意比特的线性量化,支持fp1-fp16任意比特的浮点量化,包括bf16等等,对于浮点量化你可以自由配置比特位的分配。对于更高比特精度的量化,会有很高概率出现数值溢出问题,就不推荐了。
  2. 我还没见过有人需要weight做非对称量化,如果w, a同时做非对称,推理应该贼几把难写,但PPQ确实可以这样量化...

您可以参考这个例子对量化器进行修改: https://github.com/openppl-public/ppq/blob/master/ProgramEntrance_2.py 您需要修改Quantizer的init_quant_config函数,他们负责创建量化控制信息TQC,您需要修改TQC的quant_min, quant_max, num_of_bits, policy等属性,其中policy属性以位图的方式控制量化策略。

我不建议您研究奇奇怪怪位宽的量化,例如2bit, 3bit或者混合精度推理。PPQ的量化逻辑是很复杂的,在网络的量化中我们会依次启动 QuantizeSimplifyPass, QuantizeFusionPass, ParameterQuantizePass, PassiveParameterQuantizePass, QuantAlignmentPass。这五大过程是PPQ量化的核心与精华所在,他们会像编译器那样对你的网络进行细致分析,如果你执行混合精度量化,他们的功能会受到影响,我难以保证他们的执行没有任何问题,混合量化在工程实现上是非常复杂的。如果您只是做学术研究,我建议您关闭上述优化过程,并在Quantizer中关闭所有算子的输出量化。

你可以在这里阅读量化优化过程相关的文档:https://github.com/openppl-public/ppq/tree/master/ppq/quantization/optim

SunCrazy commented 1 year ago

@ZhangZhiPku 非常感谢这么详细的回复。我这边尝试过之后,碰到下面的问题。

我的需求: 目标平台是onnxruntime,模型量化:w采用int8 sym per-channel,a采用int8 asym per-tensor, 同时bias需要量化到int32

为了避免自己写quantizer引入错误,我直接在onnxruntime的quantizer上进行了修改:ppq/quantization/quantizer/ORTQuantizer.py 如下:

from typing import Union

import torch
from ppq.core import (PASSIVE_OPERATIONS, OperationQuantizationConfig,
                      QuantizationPolicy, QuantizationProperty,
                      QuantizationStates, RoundingPolicy, TargetPlatform)
from ppq.IR import BaseGraph, Operation

from .base import BaseQuantizer

class OnnxruntimeQuantizer(BaseQuantizer):
    def __init__(
        self, graph: BaseGraph
    ) -> Union[torch.Tensor, list, dict]:
        super().__init__(graph=graph)
        self._num_of_bits = 8
        # self._quant_min = 0
        # self._quant_max = int(pow(2, self._num_of_bits) - 1)

        self._quant_min = -128
        self._quant_max = 127

    def init_quantize_config(
        self, operation: Operation) -> OperationQuantizationConfig:
        base_quant_config = self.create_default_quant_config(
            policy=self.quantize_policy, rounding=self.rounding_policy,
            op=operation, num_of_bits=self._num_of_bits, exponent_bits=0,
            quant_max=self._quant_max, quant_min=self._quant_min,
            observer_algorithm='percentile')

        if operation.type in {'Conv', 'ConvTranspose', 'Gemm'}:
            # set all parameters within Conv, ConvTranspose, Gemm to per-channel quant-config.
            assert operation.num_of_input > 0, 'Seems you got a Conv layer with no parameters.'

            # first parameter must exits, for conv layer it will be conv_weight
            # layout: [out_channel, in_channel, kernel_size, kernel_size]
            if operation.type in {'Conv', 'ConvTranspose'}:
                conv_weight_config = base_quant_config.input_quantization_config[1]
                conv_weight_config._quant_min = -128
                conv_weight_config._quant_max = 127
                conv_weight_config.policy = QuantizationPolicy(
                    QuantizationProperty.SYMMETRICAL +
                    QuantizationProperty.LINEAR +
                    QuantizationProperty.PER_CHANNEL
                )
                conv_weight_config.channel_axis = (1 if operation.type == 'ConvTranspose' else 0)
                conv_weight_config.observer_algorithm = 'minmax'
            # first parameter must exits, for gemm layer it will be gemm_weight
            # layout: [in_dim, out_dim]
            elif operation.type in {'Gemm'}:
                gemm_weight_config = base_quant_config.input_quantization_config[1]
                gemm_weight_config._quant_min = -128
                gemm_weight_config._quant_max = 127
                gemm_weight_config.policy = QuantizationPolicy(
                    QuantizationProperty.SYMMETRICAL +
                    QuantizationProperty.LINEAR +
                    QuantizationProperty.PER_CHANNEL
                )
                gemm_weight_config.channel_axis = 0
                gemm_weight_config.observer_algorithm = 'minmax'
            # if operation has bias
            if operation.num_of_input > 2:
                bias_config = base_quant_config.input_quantization_config[-1]
                bias_config.policy = QuantizationPolicy(
                    QuantizationProperty.SYMMETRICAL +
                    QuantizationProperty.LINEAR +
                    QuantizationProperty.PER_CHANNEL
                )
                # bias_config.state = QuantizationStates.FP32

                bias_config.num_of_bits = 30
                bias_config.quant_max = int(pow(2, 30 - 1))
                bias_config.quant_min = - int(pow(2, 30 - 1))
                bias_config.state = QuantizationStates.PASSIVE_INIT
                bias_config.channel_axis = 0
                bias_config.observer_algorithm = 'minmax'

        return base_quant_config

    @ property
    def quant_operation_types(self) -> set:
        QUANTTYPE = {
            'Conv', 'ConvTranspose', 'Gemm', 'Relu', 'PRelu',
            'Clip', 'Pad', 'Resize', 'MaxPool', 'AveragePool',
            'GlobalMaxPool', 'GlobalAveragePool', 'Softmax',
            'Mul', 'Add', 'Max', 'Sub', 'Div', 'Reshape',
            'LeakyRelu', 'Concat', 'Sigmoid', 'Interp',
            'ReduceMean', 'Transpose', 'Slice', 'Flatten',
            'HardSwish', 'HardSigmoid', 'MatMul'}
        QUANTTYPE.update(PASSIVE_OPERATIONS)
        return QUANTTYPE

    @ property
    def quantize_policy(self) -> QuantizationPolicy:
        return QuantizationPolicy(
            QuantizationProperty.ASYMMETRICAL +
            QuantizationProperty.LINEAR +
            QuantizationProperty.PER_TENSOR)

    @ property
    def rounding_policy(self) -> RoundingPolicy:
        return RoundingPolicy.ROUND_HALF_EVEN

    @ property
    def activation_fusion_types(self) -> set:
        return {'Relu', 'Clip', 'Sigmoid',
                'Swish', 'Mish', 'LeakyRelu'}

这里面的修改点主要是两个:

  1. 将默认的quantize_policy改成QuantizationProperty.ASYMMETRICAL
  2. 添加了bias的量化

按照上述修改后,导出TargetPlatform.ONNXRUNTIME的模型时,生成的quantized.onnx不对:

image

weight/bias虽然被量化了,但是a没有被量化。

接下来我进行了另外一个修改,将默认的quant_min和quant_max修改如下:

self._quant_min = 0
self._quant_max = int(pow(2, self._num_of_bits) - 1)

其他保持不变,最终量化后的模型时正确的,如下:

image

从结果看起来,如果指定为非对称量化(asym),是不是a的数据范围必须是[0, 255],也就是只能数据类型为uint8,不能是int8?

SunCrazy commented 1 year ago

另外,还有一个问题是:在指定平台为onnxruntime时,如果把bias量化为int32,那么最终export模型的时候会打印这种log:

[Warning] Exported Onnx Model is not executable, following Op has onnxruntime-unsupported quant policy:
[Warning] conv5_6/sep (bitwidth != 8)
[Warning] conv5_4/dw (bitwidth != 8)
[Warning] conv5_2/dw (bitwidth != 8)
[Warning] conv1 (bitwidth != 8)
[Warning] conv5_6/dw (bitwidth != 8)
[Warning] conv5_1/sep (bitwidth != 8)
[Warning] conv5_3/sep (bitwidth != 8)
[Warning] conv4_2/sep (bitwidth != 8)
[Warning] conv2_1/dw (bitwidth != 8)
[Warning] conv5_1/dw (bitwidth != 8)
[Warning] conv3_2/dw (bitwidth != 8)
[Warning] conv3_1/dw (bitwidth != 8)
[Warning] conv6/sep (bitwidth != 8)
[Warning] conv5_2/sep (bitwidth != 8)
[Warning] fc7 (bitwidth != 8)
...

而实际上,onnxruntime是可以正确执行这种量化后模型的

ZhangZhiPku commented 1 year ago

别别别,onnx标准里面那个qdq节点就必须是int8的,不能是其他位宽,这onnxruntime为啥能执行我也不知道 https://github.com/onnx/onnx/blob/main/docs/Operators.md

至于前面那个模型导出的问题, PPQ不会给你省掉activation的量化,但他的确不接受在非对称的情况下以-128, 127作为quant_min, quant_max,后面的calibration逻辑应该会出错。于此同时onnx qdq标准中int8必须是-128-127,uint8必须0-255,我记得导出的时候会检查,可能是那块检查逻辑发现你是非对称的,然后范围不是0~255,直接不导出了,他会给你警告。 针对这种需要调整量化策略的需求,你可以不用PPQ里面预制好的导出逻辑,他们都是有一些乱七八糟要求的,你可以自己写一个导出器,导出到txt或者json啥的就行。

SunCrazy commented 1 year ago

Hi,我看了一下上面链接里onnx ops的spec里的说明,除了uint8/int8,还允许DequantizeLinear是int32的:

image

我上面的case里面,由于只会把bias量化为int32,所以最后生成的模型里面对bias只会保留一个DequantizeLinear: image

SunCrazy commented 1 year ago

还有一个问题是,我原始模型中conv后面是有relu的,我看当我设置a的量化方式为asym+per-tensor,w的量化方式为sym+per-channel,且量化bias为int32之后(也就是上面提到的改动),生成出来的quantized onnx里面relu没有了,我想应该是某个优化pass把relu融合进qdq里面了。有什么办法可以避免融合而保留relu吗?

原始onnx模型: image

quantized onnx: image

ZhangZhiPku commented 1 year ago

非对称模式下relu和clip可以省略,很多推理框架压根不认这个非对称模式下的relu和clip,你可以在ppq.parser.OnnxruntimeExporter中找到移除它们的代码,你可以手动关闭那个过程。 https://github.com/openppl-public/ppq/blob/master/ppq/parser/onnxruntime_exporter.py

SunCrazy commented 1 year ago

非对称模式下relu和clip可以省略,很多推理框架压根不认这个非对称模式下的relu和clip,你可以在ppq.parser.OnnxruntimeExporter中找到移除它们的代码,你可以手动关闭那个过程。 https://github.com/openppl-public/ppq/blob/master/ppq/parser/onnxruntime_exporter.py

感谢大佬,已经找到了。 这两天试用了一下ppq,效果非常好,👍