pytorch / ao


Understanding 8da4w #430

Open DzAvril opened 1 week ago

DzAvril commented 1 week ago

Hi there,

I'm new to quantization. From my understanding, "8da4w" means that the weights are pre-quantized to 4 bits, and the activations are quantized to 8 bits at runtime. Following this, the GEMM (General Matrix Multiply) operation between weights and activations is computed in the int8 data type. Do I have this correct?
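To make this concrete, here is roughly the picture I have in mind (plain PyTorch pseudocode of my own understanding, not torchao's implementation; I'm assuming group-wise int4 weight quantization done ahead of time and per-token int8 activation quantization done at runtime):

```python
import torch

def quantize_weight_int4_groupwise(w, group_size=32):
    # w: [out_features, in_features] float weight, quantized ahead of time.
    # Assumes in_features is divisible by group_size.
    out_f, in_f = w.shape
    w_grouped = w.reshape(out_f, in_f // group_size, group_size)
    scale = w_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 7.0
    w_int4 = torch.clamp(torch.round(w_grouped / scale), -8, 7).to(torch.int8)
    return w_int4.reshape(out_f, in_f), scale  # int4 values stored in an int8 container

def quantize_activation_int8_per_token(x):
    # x: [..., in_features] float activation, quantized dynamically at runtime,
    # with one scale per token (per row).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    x_int8 = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_int8, scale
```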

However, I'm confused by the code for Int8DynActInt4WeightQuantizer. The forward method of Int8DynActInt4WeightLinear calls a method named per_token_dynamic_quant, which can be found here. In this method, the input is first quantized to int8 and then immediately converted back to its original data type without further processing; I don't understand the purpose of this function. Furthermore, I ran a program using Int8DynActInt4WeightQuantizer and inspected the data types of x and w_dq in the method linear_forward_8da4w, which can be found here; both are float32. This seems to contradict my understanding of the computations involved in '8da4w'.
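To be concrete about the first point, my reading of per_token_dynamic_quant is that it boils down to something like this (my own simplified sketch, not the actual torchao code):

```python
import torch

def per_token_quant_dequant(x):
    # Quantize each token (row) to int8, then immediately dequantize back to the
    # original float dtype. Only the int8 rounding error is left behind.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    x_int8 = torch.clamp(torch.round(x / scale), -128, 127)
    return (x_int8 * scale).to(x.dtype)
```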

I realize that I'm likely missing some fundamental aspects of dynamic quantization. Could anyone kindly clarify this process for me?

Thank you!

supriyar commented 1 week ago

Following this, the GEMM (General Matrix Multiply) operation between weights and activations is computed in the int8 data type.

This probably depends on the specific backend. For 8da4w we've tested it with the ExecuTorch runtime (XNNPACK backend), which I believe does the computation directly in the integer bitwidths (8-bit activation x 4-bit weight).

@jerryzh168 can probably help confirm this and help answer the other questions.

jerryzh168 commented 1 week ago

It's true that we need integer compute to get a speedup; that's what we do in our int8_dynamic_activation_int8_weight API (running on CUDA): https://github.com/pytorch/ao/tree/main/torchao/quantization#a8w8-dynamic-quantization
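For example, something along these lines (exact import/function names can differ between torchao versions, so please check the linked README for the version you're on):

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
# Replace eligible nn.Linear weights with int8 dynamic-activation / int8-weight
# quantized versions; on CUDA this path runs the matmul with integer kernels.
quantize_(model, int8_dynamic_activation_int8_weight())
```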

But specifically for 8da4w, we don't expect immediate speed up after quantization in server since that is targeting to be used in ExecuTorch (https://github.com/pytorch/ao/tree/main/torchao/quantization#to-be-deprecated-a8w8-dynamic-quantization) and the requirement there is that we produce a representation for quantized model so that it can be matched and lowered to a specific library (e.g. xnnpack). Here is a bit more context on the reasoning behind producing a pattern for further downstream consumption: https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md