microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Static quantization of a self-attention module does not work #17278

Open imzhuhl opened 1 year ago

imzhuhl commented 1 year ago

Describe the issue

I am testing the inference performance of a model based on multi-head self-attention. After turning on static quantization, I found that performance dropped instead of improving. I then wrote a simple test and found that the self-attention graph looks strange after static quantization.

Here is a simple reproduction:

import math
import time
import numpy as np
import torch
import torch.nn as nn
import onnx
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, quantize_static, CalibrationDataReader
from onnxruntime.quantization.quant_utils import QuantFormat

# Multi-head self-attention implemented in PyTorch
class SelfAttn(nn.Module):
    def __init__(self, hidden_size, num_attn_heads):
        super().__init__()

        attn_head_size = int(hidden_size / num_attn_heads)
        all_head_size = num_attn_heads * attn_head_size

        self.query = nn.Linear(hidden_size, all_head_size)
        self.key = nn.Linear(hidden_size, all_head_size)
        self.value = nn.Linear(hidden_size, all_head_size)
        self.attn_head_size = attn_head_size
        self.num_attn_heads = num_attn_heads

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attn_heads, self.attn_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attn_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()  # (batch, seq_len, num_heads, dim)
        new_context_layer_shape = context_layer.size()[:-2] + (self.num_attn_heads * self.attn_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)  # (batch, seq_len, hidden_size)

        return context_layer

def export_model():
    model = SelfAttn(128, 2)
    x = torch.randn(2, 10, 128)
    tuple_input = (x,)

    torch.onnx.export(
        model,
        tuple_input,
        f="attn.onnx",
        input_names=['input_tensor'],
        output_names=['output_tensor'],
        dynamic_axes={'input_tensor': {0: 'batch_size', 1: 'seq_len'},
                      'output_tensor': {0: 'batch_size', 1: 'seq_len'}}
    )
    quantize_dynamic("attn.onnx", "attn_dynamic_quant.onnx")

class FakeDataReader(CalibrationDataReader):
    def __init__(self, model_path):
        session = onnxruntime.InferenceSession(model_path)
        self.input_name = session.get_inputs()[0].name

        fake_data = torch.randn(4, 1, 10, 128).numpy()
        self.datasize = fake_data.shape[0]
        self.fake_data = iter(
            [{self.input_name: fake_data[i]} for i in range(self.datasize)]
        )

    def get_next(self):
        return next(self.fake_data, None)

def export_static_quant_model():
    # Requires attn_preprocess.onnx, produced by the preprocess command shown below.
    dr = FakeDataReader("attn_preprocess.onnx")
    quantize_static("attn_preprocess.onnx", "attn_static_quant.onnx", dr, quant_format=QuantFormat.QDQ)

def run(name, model_path, x):
    ort_session = onnxruntime.InferenceSession(model_path)

    ort_inputs = {ort_session.get_inputs()[0].name: x}

    # Warm-up runs
    for _ in range(5):
        outputs = ort_session.run(output_names=None, input_feed=ort_inputs)

    t0 = time.time()
    for _ in range(100):
        outputs = ort_session.run(output_names=None, input_feed=ort_inputs)
    t1 = time.time()
    print(f"{name}: {t1 - t0}")

    return outputs[0]

def performance():
    x = torch.randn(4, 10, 128).numpy()
    run("attn", "./attn.onnx", x)
    run("attn_dynamic_quant", "./attn_dynamic_quant.onnx", x)
    run("attn_static_quant", "./attn_static_quant.onnx", x)

if __name__ == "__main__":
    export_model()
    export_static_quant_model()
    performance()

I wrote a simple self-attention module and exported it to an ONNX model and a dynamic-quant model. Then I ran the onnxruntime preprocessing tool:

python -m onnxruntime.quantization.preprocess --input attn.onnx --output attn_preprocess.onnx

Then I generate the static-quant model. Finally, I run all three models and measure inference time; the static-quant model takes the longest.
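(For completeness: the preprocess step can also be called from Python between export_model() and export_static_quant_model(). This is just a sketch; it assumes the installed onnxruntime version exposes quant_pre_process, which the CLI command above wraps.)

from onnxruntime.quantization.shape_inference import quant_pre_process

# Equivalent to the CLI command above; writes attn_preprocess.onnx next to attn.onnx.
quant_pre_process("attn.onnx", "attn_preprocess.onnx")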

In my understanding, onnxruntime optimizes the graph during session initialization. It uses the function TransformGraph to optimize the graph, including fusing QDQ nodes. So I printed the graph after optimization:

// onnxruntime/core/session/inference_session.cc: Initialize()

ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, saving_ort_format));
std::cout << "After transform:\n" << graph;
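As a side note, the optimized graph can also be dumped from Python without rebuilding from source. This is only a sketch; it assumes the standard SessionOptions.optimized_model_filepath option, and the output file name is my own choice:

import onnx
import onnxruntime

# Save the model as it looks after session-initialization graph transforms.
sess_options = onnxruntime.SessionOptions()
sess_options.optimized_model_filepath = "attn_static_quant_optimized.onnx"
onnxruntime.InferenceSession("attn_static_quant.onnx", sess_options)

# Inspect the transformed graph, e.g. the first few nodes.
print(onnx.load("attn_static_quant_optimized.onnx").graph.node[:5])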

Here is one of the MatMul nodes:

("/key/MatMul", MatMul, "", 13) : ("input_tensor_DequantizeLinear_Output": tensor(float),"onnx::MatMul_91_DequantizeLinear_Output": tensor(float),) -> ("/key/MatMul_output_0": tensor(float),)

You can see that the inputs of this MatMul node are all fp32 tensors, so I think it runs as an fp32 GEMM rather than an int8 GEMM.

I have two questions:

  1. Am I quantizing the model correctly?
  2. Why does onnxruntime not call an int8 GEMM?
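In case it helps, this is how I check which kernels actually execute and how long each one takes. A rough sketch using ORT's built-in profiler; the exact field names in the JSON trace may differ between versions:

import json
import numpy as np
import onnxruntime

sess_options = onnxruntime.SessionOptions()
sess_options.enable_profiling = True
session = onnxruntime.InferenceSession("attn_static_quant.onnx", sess_options)

x = np.random.randn(4, 10, 128).astype(np.float32)
session.run(None, {session.get_inputs()[0].name: x})

# end_profiling() returns the path of the JSON trace written by the runtime.
profile_path = session.end_profiling()
with open(profile_path) as f:
    for event in json.load(f):
        if event.get("cat") == "Node":
            print(event.get("name"), event.get("args", {}).get("op_name"), event.get("dur"))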

Here is the whole graph:

After transform:
Inputs:
   "input_tensor": tensor(float)
Nodes:
   ("input_tensor_QuantizeLinear", QuantizeLinear, "", 13) : ("input_tensor": tensor(float),"ortshared_1_0_1_12_token_175": tensor(float),"input_tensor_zero_point": tensor(int8),) -> ("input_tensor_QuantizeLinear_Output": tensor(int8),)
   ("key.bias_DequantizeLinear", DequantizeLinear, "", 13) : ("key.bias_quantized": tensor(int8),"ortshared_1_0_1_13_token_177": tensor(float),"key.bias_zero_point": tensor(int8),) -> ("key.bias_DequantizeLinear_Output": tensor(float),)
   ("onnx::MatMul_90_DequantizeLinear", DequantizeLinear, "", 13) : ("onnx::MatMul_90_quantized": tensor(int8),"ortshared_1_0_1_6_token_168": tensor(float),"onnx::MatMul_90_zero_point": tensor(int8),) -> ("onnx::MatMul_90_DequantizeLinear_Output": tensor(float),)
   ("onnx::MatMul_91_DequantizeLinear", DequantizeLinear, "", 13) : ("onnx::MatMul_91_quantized": tensor(int8),"ortshared_1_0_1_8_token_170": tensor(float),"onnx::MatMul_91_zero_point": tensor(int8),) -> ("onnx::MatMul_91_DequantizeLinear_Output": tensor(float),)
   ("onnx::MatMul_92_DequantizeLinear", DequantizeLinear, "", 13) : ("onnx::MatMul_92_quantized": tensor(int8),"ortshared_1_0_1_1_token_163": tensor(float),"onnx::MatMul_92_zero_point": tensor(int8),) -> ("onnx::MatMul_92_DequantizeLinear_Output": tensor(float),)
   ("query.bias_DequantizeLinear", DequantizeLinear, "", 13) : ("query.bias_quantized": tensor(int8),"ortshared_1_0_1_14_token_178": tensor(float),"query.bias_zero_point": tensor(int8),) -> ("query.bias_DequantizeLinear_Output": tensor(float),)
   ("value.bias_DequantizeLinear", DequantizeLinear, "", 13) : ("value.bias_quantized": tensor(int8),"ortshared_1_0_1_0_token_162": tensor(float),"value.bias_zero_point": tensor(int8),) -> ("value.bias_DequantizeLinear_Output": tensor(float),)
   ("input_tensor_DequantizeLinear", DequantizeLinear, "", 13) : ("input_tensor_QuantizeLinear_Output": tensor(int8),"ortshared_1_0_1_12_token_175": tensor(float),"input_tensor_zero_point": tensor(int8),) -> ("input_tensor_DequantizeLinear_Output": tensor(float),)
   ("/key/MatMul", MatMul, "", 13) : ("input_tensor_DequantizeLinear_Output": tensor(float),"onnx::MatMul_91_DequantizeLinear_Output": tensor(float),) -> ("/key/MatMul_output_0": tensor(float),)
   ("/query/MatMul", MatMul, "", 13) : ("input_tensor_DequantizeLinear_Output/duplicated": tensor(float),"onnx::MatMul_90_DequantizeLinear_Output": tensor(float),) -> ("/query/MatMul_output_0": tensor(float),)
   ("/value/MatMul", MatMul, "", 13) : ("input_tensor_DequantizeLinear_Output/duplicated_token_0": tensor(float),"onnx::MatMul_92_DequantizeLinear_Output": tensor(float),) -> ("/value/MatMul_output_0": tensor(float),)
   ("/key/MatMul_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/key/MatMul_output_0": tensor(float),"ortshared_1_0_1_10_token_173": tensor(float),"qdq_s8_to_u8_zp_conversion_token_188": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_189": tensor(uint8),)
   ("/query/MatMul_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/query/MatMul_output_0": tensor(float),"ortshared_1_0_1_15_token_179": tensor(float),"qdq_s8_to_u8_zp_conversion": tensor(uint8),) -> ("qdq_s8_to_u8_quant": tensor(uint8),)
   ("/value/MatMul_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/value/MatMul_output_0": tensor(float),"ortshared_1_0_1_4_token_166": tensor(float),"qdq_s8_to_u8_zp_conversion_token_202": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_203": tensor(uint8),)
   ("/key/MatMul_output_0_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant_token_189": tensor(uint8),"ortshared_1_0_1_10_token_173": tensor(float),"qdq_s8_to_u8_zp_conversion_token_188": tensor(uint8),) -> ("/key/MatMul_output_0_DequantizeLinear_Output": tensor(float),)
   ("/query/MatMul_output_0_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant": tensor(uint8),"ortshared_1_0_1_15_token_179": tensor(float),"qdq_s8_to_u8_zp_conversion": tensor(uint8),) -> ("/query/MatMul_output_0_DequantizeLinear_Output": tensor(float),)
   ("/value/MatMul_output_0_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant_token_203": tensor(uint8),"ortshared_1_0_1_4_token_166": tensor(float),"qdq_s8_to_u8_zp_conversion_token_202": tensor(uint8),) -> ("/value/MatMul_output_0_DequantizeLinear_Output": tensor(float),)
   ("/key/Add", Add, "", 14) : ("key.bias_DequantizeLinear_Output": tensor(float),"/key/MatMul_output_0_DequantizeLinear_Output": tensor(float),) -> ("/key/Add_output_0": tensor(float),)
   ("/query/Add", Add, "", 14) : ("query.bias_DequantizeLinear_Output": tensor(float),"/query/MatMul_output_0_DequantizeLinear_Output": tensor(float),) -> ("/query/Add_output_0": tensor(float),)
   ("/value/Add", Add, "", 14) : ("value.bias_DequantizeLinear_Output": tensor(float),"/value/MatMul_output_0_DequantizeLinear_Output": tensor(float),) -> ("/value/Add_output_0": tensor(float),)
   ("/key/Add_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/key/Add_output_0": tensor(float),"ortshared_1_0_1_16_token_180": tensor(float),"qdq_s8_to_u8_zp_conversion_token_190": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_191": tensor(uint8),)
   ("/query/Add_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/query/Add_output_0": tensor(float),"ortshared_1_0_1_5_token_167": tensor(float),"qdq_s8_to_u8_zp_conversion_token_182": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_183": tensor(uint8),)
   ("/value/Add_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/value/Add_output_0": tensor(float),"ortshared_1_0_1_17_token_181": tensor(float),"qdq_s8_to_u8_zp_conversion_token_204": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_205": tensor(uint8),)
   ("/Reshape_1", Reshape, "", 14) : ("qdq_s8_to_u8_quant_token_191": tensor(uint8),"ortshared_7_1_4_0_token_171": tensor(int64),) -> ("qdq_s8_to_u8_quant_token_193": tensor(uint8),)
   ("/Reshape", Reshape, "", 14) : ("qdq_s8_to_u8_quant_token_183": tensor(uint8),"ortshared_7_1_4_0_token_171": tensor(int64),) -> ("qdq_s8_to_u8_quant_token_185": tensor(uint8),)
   ("/Reshape_2", Reshape, "", 14) : ("qdq_s8_to_u8_quant_token_205": tensor(uint8),"ortshared_7_1_4_0_token_171": tensor(int64),) -> ("qdq_s8_to_u8_quant_token_207": tensor(uint8),)
   ("/Transpose_2", Transpose, "", 13) : ("qdq_s8_to_u8_quant_token_193": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_195": tensor(uint8),)
   ("/Transpose", Transpose, "", 13) : ("qdq_s8_to_u8_quant_token_185": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_187": tensor(uint8),)
   ("/Transpose_1", Transpose, "", 13) : ("qdq_s8_to_u8_quant_token_207": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_209": tensor(uint8),)
   ("/MatMul_output_0_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant_token_197": tensor(uint8),"ortshared_1_0_1_7_token_169": tensor(float),"qdq_s8_to_u8_zp_conversion_token_196": tensor(uint8),) -> ("/MatMul_output_0_DequantizeLinear_Output": tensor(float),)
   ("/Div", Div, "", 14) : ("/MatMul_output_0_DequantizeLinear_Output": tensor(float),"ortshared_1_0_1_3_token_165": tensor(float),) -> ("/Div_output_0": tensor(float),)
   ("/Div_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/Div_output_0": tensor(float),"ortshared_1_0_1_2_token_164": tensor(float),"qdq_s8_to_u8_zp_conversion_token_198": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_199": tensor(uint8),)
   ("/Transpose_3", Transpose, "", 13) : ("qdq_s8_to_u8_quant_token_211": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_213": tensor(uint8),)
   ("/Reshape_3", Reshape, "", 14) : ("qdq_s8_to_u8_quant_token_213": tensor(uint8),"ortshared_7_1_3_0_token_176": tensor(int64),) -> ("qdq_s8_to_u8_quant_token_215": tensor(uint8),)
   ("output_tensor_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant_token_215": tensor(uint8),"ortshared_1_0_1_11_token_174": tensor(float),"qdq_s8_to_u8_zp_conversion_token_214": tensor(uint8),) -> ("output_tensor": tensor(float),)
   ("input_tensor_DequantizeLinear/duplicated", DequantizeLinear, "", 13) : ("input_tensor_QuantizeLinear_Output": tensor(int8),"ortshared_1_0_1_12_token_175": tensor(float),"input_tensor_zero_point": tensor(int8),) -> ("input_tensor_DequantizeLinear_Output/duplicated": tensor(float),)
   ("input_tensor_DequantizeLinear/duplicated_token_1", DequantizeLinear, "", 13) : ("input_tensor_QuantizeLinear_Output": tensor(int8),"ortshared_1_0_1_12_token_175": tensor(float),"input_tensor_zero_point": tensor(int8),) -> ("input_tensor_DequantizeLinear_Output/duplicated_token_0": tensor(float),)
   ("/MatMul", QLinearMatMul, "", 10) : ("qdq_s8_to_u8_quant_token_187": tensor(uint8),"ortshared_1_0_1_5_token_167": tensor(float),"qdq_s8_to_u8_zp_conversion_token_186": tensor(uint8),"qdq_s8_to_u8_quant_token_195": tensor(uint8),"ortshared_1_0_1_16_token_180": tensor(float),"qdq_s8_to_u8_zp_conversion_token_194": tensor(uint8),"ortshared_1_0_1_7_token_169": tensor(float),"qdq_s8_to_u8_zp_conversion_token_196": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_197": tensor(uint8),)
   ("/Softmax", QLinearSoftmax, "com.microsoft", 1) : ("qdq_s8_to_u8_quant_token_199": tensor(uint8),"ortshared_1_0_1_2_token_164": tensor(float),"qdq_s8_to_u8_zp_conversion_token_198": tensor(uint8),"ortshared_1_0_1_9_token_172": tensor(float),"qdq_s8_to_u8_zp_conversion_token_200": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_201": tensor(uint8),)
   ("/MatMul_1", QLinearMatMul, "", 10) : ("qdq_s8_to_u8_quant_token_201": tensor(uint8),"ortshared_1_0_1_9_token_172": tensor(float),"qdq_s8_to_u8_zp_conversion_token_200": tensor(uint8),"qdq_s8_to_u8_quant_token_209": tensor(uint8),"ortshared_1_0_1_17_token_181": tensor(float),"qdq_s8_to_u8_zp_conversion_token_208": tensor(uint8),"ortshared_1_0_1_11_token_174": tensor(float),"qdq_s8_to_u8_zp_conversion_token_210": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_211": tensor(uint8),)
Outputs:
   "output_tensor": tensor(float)

To reproduce

Run the Python code above.

Urgency

No response

Platform

Linux

OS Version

ubuntu

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

main 61a79436e22892bdd91a905389f12e0aee68132e

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

yufenglee commented 1 year ago

We recommend using dynamic quantization for transformer models on CPU. If you use static quantization, you can limit the ops to quantize (op_types_to_quantize) to MatMul only. See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#method-selection and https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#transformer-based-models
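A minimal sketch of that suggestion, reusing the FakeDataReader from the repro above and assuming the op_types_to_quantize argument of quantize_static:

from onnxruntime.quantization import quantize_static
from onnxruntime.quantization.quant_utils import QuantFormat

# Quantize only the MatMul ops; everything else (Reshape/Transpose/Softmax/...) stays in float.
dr = FakeDataReader("attn_preprocess.onnx")
quantize_static(
    "attn_preprocess.onnx",
    "attn_static_quant_matmul_only.onnx",
    dr,
    quant_format=QuantFormat.QDQ,
    op_types_to_quantize=["MatMul"],
)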

imzhuhl commented 1 year ago

Thank you for your reply.

"recommend to use dynamic quantization for transformer models on CPU"

What is the reason for this? Is it because the current support for static quantization of transformer models is not good? Or is it that, in practice, dynamic quantization gives better prediction results?

The reason I think there is a problem with this static quantization is that the fully connected layers that compute "query", "key", and "value" are not quantized correctly.