tlrmchlsmth opened 2 months ago
Motivation and Scope
The high-level goal of this RFC is to speed up Prefill by increasing the rate of computation by using int8 tensor cores. We don't anticipate improving decode performance except for very large batch sizes, as inference time in that case is dominated by loading the weights and is already well-served by weight-only quantization.
Note: this will also be extremely important for chunked prefill regime
Int4 activation quantization is out of scope for this RFC, but we are interested in support for it. Successful int4 activation quantization (namely QuaRot) requires more work and more extensive modifications to the model definitions than int8 activation quantization, so it's natural to do this after int8 quantization.
For this RFC, we are focusing on support for Nvidia GPUs, and leaving other systems as out of scope.
Quantization Schemes and Zero Points
We are considering quantization of the form: $$\widehat X = \lfloor \frac{X}{s_x} \rceil + z_x$$ In this case, $X$ is floating point, and $\widehat X$ will be its int8 quantized representation. $s_x$ is the scale or tensor of scales, and $z_x$ is a zero point.
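As a concrete sketch of this formula, here is a minimal NumPy round-trip (the function names and the per-tensor scale choice are illustrative, not part of this RFC):

```python
import numpy as np

def quantize_int8(X, s_x, z_x=0):
    # The formula above: round(X / s_x) + z_x, clamped to the int8 range.
    X_hat = np.round(X / s_x) + z_x
    return np.clip(X_hat, -128, 127).astype(np.int8)

def dequantize_int8(X_hat, s_x, z_x=0):
    # Approximate inverse: s_x * (X_hat - z_x).
    return s_x * (X_hat.astype(np.float32) - z_x)

X = np.array([[-1.0, 0.5, 2.0]], dtype=np.float32)
s = np.abs(X).max() / 127.0        # symmetric per-tensor scale (z_x = 0)
X_hat = quantize_int8(X, s)
X_rec = dequantize_int8(X_hat, s)  # recovers X up to rounding error of at most s/2 per entry
```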
There are several cases to consider, with performance and accuracy tradeoffs in each case.
In light of these considerations, this RFC proposes initially supporting the following cases.
For the weights:
For the activations:
Other cases are left as future work, out of scope for this RFC. Asymmetric w8a8 weights and asymmetric per-token activations can be handled by additional $\mathcal O(n^2)$ correction terms that are computed during inference. For asymmetric quantized weights where the activation is stored in a higher precision, such as w4a8, the zero points may be handled via a shift after the weights are up-converted to the activation's precision for computation.
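To make the granularity distinction concrete, here is a small sketch of symmetric scale computation at per-tensor versus per-token granularity (illustrative code, not vLLM's implementation; in the symmetric case $z_x = 0$):

```python
import numpy as np

def symmetric_scale_per_tensor(X):
    # One scale for the whole tensor: s = max|X| / 127.
    return np.abs(X).max() / 127.0

def symmetric_scale_per_token(X):
    # One scale per row (token) of the activation matrix, shape (m, 1).
    return np.abs(X).max(axis=1, keepdims=True) / 127.0

A = np.array([[0.1, -0.2, 0.05],
              [10.0, -20.0, 5.0]], dtype=np.float32)

s_tensor = symmetric_scale_per_tensor(A)  # scalar, dominated by the large-magnitude row
s_token = symmetric_scale_per_token(A)    # shape (2, 1): each token keeps its own range
A_hat = np.clip(np.round(A / s_token), -128, 127).astype(np.int8)
```

With a single per-tensor scale, the small-magnitude first row would be crushed into a handful of int8 values; per-token scales preserve each row's dynamic range at the cost of a per-row rescale in the epilogue.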
Zero Point Correction Terms
This section is a zoom-in on the linear algebra for the zero point correction terms, to further motivate some of the decisions made above on support for asymmetric vs symmetric and per-token vs per-tensor cases.
Suppose we want to compute a quantized GEMM operation $C = AB$, where $A$ is $m \times k$, $B$ is $k \times n$, and $C$ is $m \times n$. In this setting, $A$ is the input activation matrix and $B$ is the weight matrix, which is known offline. We quantize the matrices as $C = s_C (\widehat C - z_C J_C)$, $B = s_B (\widehat B - z_B J_B)$, $A = s_A (\widehat A - z_A J_A)$. This is per-tensor quantization, where $s_X$ is the scale of matrix $X$, $z_X$ is the zero point of $X$, and $J_X$ is the all-ones matrix with the same shape as $X$. Here we are ignoring any rounding for simplicity. Let's furthermore assume that $z_C = 0$ and $s_A = s_B = s_C = 1$ just to get them out of the way -- the scales of all matrices and the output's zero point are pretty easy to deal with.
Let's substitute the above equations into $C = AB$ to see how to compute $\widehat C$:

$$\widehat C = (\widehat A - z_A J_A) (\widehat B - z_B J_B)$$

$$\widehat C = \widehat A \widehat B - z_A J_A \widehat B - z_B \widehat A J_B + z_A z_B J_A J_B$$
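This expansion can be sanity-checked numerically. A NumPy sketch (the shapes and zero-point values are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
A_hat = rng.integers(-128, 128, size=(4, 8)).astype(np.int32)
B_hat = rng.integers(-128, 128, size=(8, 3)).astype(np.int32)
z_A, z_B = 5, -3
J_A = np.ones_like(A_hat)  # all-ones, same shape as A_hat
J_B = np.ones_like(B_hat)  # all-ones, same shape as B_hat

# Left side: multiply the dequantized (shifted) matrices directly.
lhs = (A_hat - z_A * J_A) @ (B_hat - z_B * J_B)

# Right side: the four-term expansion.
rhs = (A_hat @ B_hat
       - z_A * (J_A @ B_hat)
       - z_B * (A_hat @ J_B)
       + z_A * z_B * (J_A @ J_B))

assert np.array_equal(lhs, rhs)
```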
A brief remark on each term:
$\widehat A \widehat B$: will be computed by our quantized GEMM kernel.
$z_A z_B J_A J_B$: If per-tensor quantization is used, every entry of $z_A z_B J_A J_B$ is the same, equal to $k \, z_A z_B$, so it depends only on $k$ and the zero points of $A$ and $B$ and can be folded in ahead of time.
$z_A J_A \widehat B$: A few remarks on this one. Every row of $J_A \widehat B$ is the vector of column sums of $\widehat B$, and $\widehat B$ is known offline, so the column sums can always be precomputed. If $z_A$ is also static (per-tensor, known ahead of time), the entire term can be precomputed; if $z_A$ is determined dynamically at runtime, we still only need to scale the precomputed column sums by $z_A$.
$z_B \widehat A J_B$: This term depends on the activation matrix, so it must be computed at runtime if asymmetric weight quantization ($z_B \neq 0$) is used. Every column of $\widehat A J_B$ is the vector of row sums of $\widehat A$, which costs $\mathcal O(mk)$ to compute.
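Putting the remarks above together, here is a sketch of how the correction terms split into offline and runtime work for the per-tensor asymmetric case (NumPy stand-in for the int8 GEMM kernel; all names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
m, k, n = 4, 8, 3
A_hat = rng.integers(-128, 128, size=(m, k)).astype(np.int32)  # activations
B_hat = rng.integers(-128, 128, size=(k, n)).astype(np.int32)  # weights
z_A, z_B = 7, -2

# Offline (weights and per-tensor zero points known ahead of time):
col_sums_B = B_hat.sum(axis=0)  # one row of J_A @ B_hat; every row is identical
const_term = z_A * z_B * k      # every entry of z_A z_B J_A J_B

# Runtime (depends on the activations):
row_sums_A = A_hat.sum(axis=1)  # one column of A_hat @ J_B; every column is identical

C_hat = (A_hat @ B_hat                 # the quantized GEMM kernel's output
         - z_A * col_sums_B[None, :]   # z_A J_A B_hat, broadcast over rows
         - z_B * row_sums_A[:, None]   # z_B A_hat J_B, broadcast over columns
         + const_term)

# Reference: multiply the zero-point-shifted matrices directly.
ref = (A_hat - z_A) @ (B_hat - z_B)
assert np.array_equal(C_hat, ref)
```

The broadcast structure is the point: neither correction requires materializing a full $m \times n$ matrix, which is why the symmetric cases (where these terms vanish) and the per-tensor asymmetric case are the cheapest to support.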