vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[RFC]: Int8 Activation Quantization #3975

Open tlrmchlsmth opened 2 months ago

tlrmchlsmth commented 2 months ago

Summary

Motivation and Scope

The high-level goal of this RFC is to speed up prefill by increasing the rate of computation using int8 tensor cores. We don't anticipate improving decode performance except at very large batch sizes, since decode inference time is dominated by loading the weights and is already well-served by weight-only quantization.

Int4 activation quantization is out of scope for this RFC, but we are interested in support for it. Successful int4 activation quantization (namely QuaRot) requires more work and more extensive modifications to the model definitions than int8 activation quantization, so it's natural to do this after int8 quantization.

For this RFC, we are focusing on support for Nvidia GPUs, and leaving other systems as out of scope.

Quantization Schemes and Zero Points

We are considering quantization of the form: $$\widehat X = \lfloor \frac{X}{s_x} \rceil + z_x$$ In this case, $X$ is floating point, and $\widehat X$ will be its int8 quantized representation. $s_x$ is the scale or tensor of scales, and $z_x$ is a zero point.
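To make this concrete, here is a minimal per-tensor int8 quantize/dequantize sketch in PyTorch following the formula above. This is illustrative only, not vLLM's implementation; the function names and the simple max-based scale choice are assumptions for the example.

```python
import torch

def quantize_per_tensor(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # x_hat = round(x / scale) + zero_point, clamped to the int8 range
    x_hat = torch.round(x / scale) + zero_point
    return x_hat.clamp(-128, 127).to(torch.int8)

def dequantize_per_tensor(x_hat: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # x ≈ scale * (x_hat - zero_point)
    return scale * (x_hat.to(torch.float32) - zero_point)

# Toy usage with a symmetric (zero_point = 0) per-tensor scale.
x = torch.randn(16, 32)
scale = x.abs().max().item() / 127
x_q = quantize_per_tensor(x, scale, zero_point=0)
x_dq = dequantize_per_tensor(x_q, scale, zero_point=0)
```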

There are several cases to consider, with performance and accuracy tradeoffs in each case.

In light of these considerations, this RFC proposes initially supporting the following cases.

For the weights:

For the activations:

Other cases are left as future work, out of scope for this RFC. Asymmetric w8a8 weights and asymmetric per-token activations can be handled by additional $\mathcal O(n^2)$ terms that are computed during inference. For asymmetric quantized weights where the activation is stored in a higher precision, such as w4a8, the zero points may be handled via a shift after the weights are up-converted to the activation's precision for computation, as sketched below.
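As a rough illustration of the w4a8 shift just mentioned (hypothetical code, not vLLM's kernels; the unpacked int4 layout, shapes, and function name are assumptions): because the activation is int8, the weight zero point can be folded in by shifting the weights once they are up-converted to int8, leaving a plain symmetric int8 GEMM with no runtime correction terms.

```python
import torch

def shift_w4_asym_to_int8(w_q: torch.Tensor, z_w: int) -> torch.Tensor:
    """w_q: unpacked int4 codes stored as uint8 in [0, 15]; z_w: weight zero point."""
    # Up-convert to int8 and subtract the zero point; the result fits in int8.
    return w_q.to(torch.int8) - z_w

# Toy usage with made-up shapes and zero point.
w_q = torch.randint(0, 16, (256, 512), dtype=torch.uint8)
w_int8 = shift_w4_asym_to_int8(w_q, z_w=8)   # values now in [-8, 7]
```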

Zero Point Correction Terms

This section is a zoom-in on the linear algebra for the zero point correction terms, to further motivate some of the decisions made above on support for asymmetric vs symmetric and per-token vs per-tensor cases.

Suppose we want to compute a quantized GEMM operation $C = AB$, where $A$ is $m \times k$, $B$ is $k \times n$, and $C$ is $m \times n$. In this setting, $A$ is the input activation matrix and $B$ is the weight matrix, known offline. We quantize the matrices as $C = s_C (\widehat C - z_C J_C)$, $B = s_B (\widehat B - z_B J_B)$, $A = s_A (\widehat A - z_A J_A)$. This is per-tensor quantization where $s_X$ is the scale of matrix $X$, $z_X$ is the zero point of $X$, and $J_X$ is the conformable matrix of all ones. Here we ignore rounding in the quantization for simplicity. Let's furthermore assume that $z_C = 0$ and $s_A = s_B = s_C = 1$ just to get them out of the way -- the scales of all matrices and the output's zero point are straightforward to deal with.

Let's substitute the above equations into $C = AB$ to see how to compute $\widehat C$:

$$C = AB$$

$$\widehat C = (\widehat A - z_A J_A) (\widehat B - z_B J_B)$$

$$\widehat C = \widehat A \widehat B - z_A J_A \widehat B - z_B \widehat A J_B + z_A z_B J_A J_B$$
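As a sanity check on this expansion, and to show which correction terms can be precomputed offline versus per token, here is a small NumPy sketch. It is illustrative only; the shapes, zero points, and variable names are made up for the example.

```python
import numpy as np

# Check the expansion above with per-tensor quantization, all scales = 1, z_C = 0.
# Key structure: J_A @ B_hat is rank-1 (each row is the column sums of B_hat), so
# the z_A term can be precomputed offline since B is the weight matrix. A_hat @ J_B
# depends on the row sums of A_hat, so the z_B term must be computed per token at
# inference time. J_A @ J_B is just k * ones.

rng = np.random.default_rng(0)
m, k, n = 4, 8, 5
z_A, z_B = 3, -2

A = rng.integers(-50, 50, size=(m, k))   # stand-in for the dequantized activation
B = rng.integers(-50, 50, size=(k, n))   # stand-in for the dequantized weight

# With unit scales, A_hat = A + z_A and B_hat = B + z_B.
A_hat = A + z_A
B_hat = B + z_B

col_sums_B_hat = B_hat.sum(axis=0, keepdims=True)   # (1, n), precomputable offline
row_sums_A_hat = A_hat.sum(axis=1, keepdims=True)   # (m, 1), computed per token

C_hat = (A_hat @ B_hat
         - z_A * np.ones((m, 1)) @ col_sums_B_hat   # z_A * J_A @ B_hat
         - z_B * row_sums_A_hat @ np.ones((1, n))   # z_B * A_hat @ J_B
         + z_A * z_B * k * np.ones((m, n)))         # z_A * z_B * J_A @ J_B

assert np.array_equal(C_hat, A @ B)
```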

A brief remark on each term:

robertgshaw2-neuralmagic commented 2 months ago

Motivation and Scope

The high-level goal of this RFC is to speed up prefill by increasing the rate of computation using int8 tensor cores. We don't anticipate improving decode performance except at very large batch sizes, since decode inference time is dominated by loading the weights and is already well-served by weight-only quantization.

Note: this will also be extremely important for the chunked prefill regime