mit-han-lab / smoothquant

[ICML 2023] SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models
https://arxiv.org/abs/2211.10438
MIT License
1.1k stars 127 forks

Question: why is no explicit scaling needed for activation X? #79

Open ghost opened 3 months ago

ghost commented 3 months ago

Hi All,

I've read the paper and the source code, but I still have some questions. According to the paper, we need to scale both the activation X and the weights W as follows:

Y = (X diag(s)^-1) * (diag(s) W)

However, in the "smooth.py" file, I can only see the weights W being scaled by multiplying with "diag(s)":

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

  1. I couldn't find where the activation X is scaled, i.e., where the diag(s)^-1 factor is applied. I first assumed X was scaled at inference time, but in the example notebooks the models are smoothed and then used directly for inference as usual. So where is X scaled?

  2. The paper says: "Considering input X is usually produced from previous linear operations (e.g., linear layers, layer norms, etc.), we can easily fuse the smoothing factor into previous layers' parameters offline, which does not incur kernel call overhead from an extra scaling. For some other cases, when the input is from a residual add, we can add an extra scaling to the residual branch similar to Wei et al. (2022)." I can't find which part of the code fuses the scaling of X into the previous layer's parameters, other than by cancelling the whole scaling procedure.

  3. Why do we scale the LayerNorm parameters the same way X should be scaled? Is this related to the previous question?

    ln.weight.div_(scales)
    ln.bias.div_(scales)

  4. Why don't we need ln.bias.div_(scales) for Llama-like models?

  5. Which part of the code handles the scaling for residual connections?
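
For questions 1 to 3, a small numerical check may help. The sketch below is not the repository's code: the layer sizes and the random `scales` vector are made up for illustration. It only assumes that a LayerNorm feeds directly into a Linear layer, and checks that dividing the LayerNorm's weight and bias by s while multiplying the Linear weight by s leaves the output unchanged, so the diag(s)^-1 factor never has to be applied to X at inference time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden, out_features = 8, 4          # illustrative sizes, not from the repo
ln = nn.LayerNorm(hidden)
fc = nn.Linear(hidden, out_features)

# Give the LayerNorm non-trivial affine parameters so the check is meaningful.
with torch.no_grad():
    ln.weight.uniform_(0.5, 1.5)
    ln.bias.uniform_(-0.5, 0.5)

x = torch.randn(2, 3, hidden)
with torch.no_grad():
    y_before = fc(ln(x))

# Hypothetical per-channel smoothing factors s_j > 0 (one per input channel).
scales = torch.rand(hidden) + 0.5

with torch.no_grad():
    # X diag(s)^-1: folded into the LayerNorm's affine parameters.
    ln.weight.div_(scales)
    ln.bias.div_(scales)
    # diag(s) W: nn.Linear stores its weight as (out, in), so scale the columns.
    fc.weight.mul_(scales.view(1, -1))
    y_after = fc(ln(x))

# The two outputs should agree up to floating-point rounding.
max_abs_diff = (y_before - y_after).abs().max().item()
print(max_abs_diff)
```

The same argument would apply to a normalization layer that has only a weight and no bias (RMSNorm-style, as used in Llama-like models): there is simply no bias term to divide.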

ghost commented 3 months ago

@Guangxuan-Xiao can any of the authors explain this to me?

ghost commented 3 months ago

I think I've got the idea: the scaling for X is absorbed into the layer norm. One more question: the fake_quant file does NOT actually change the data type to int8, right? If so, how can we actually convert the model to the int8 data type?
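
To make the distinction concrete, here is a minimal sketch of simulated ("fake") quantization next to a real int8 conversion. Both helper names are hypothetical and written for this illustration; they are not taken from the repository. The sketch assumes symmetric per-tensor absmax quantization. The fake version snaps values to the int8 grid but keeps the tensor in floating point, while the real version stores an int8 tensor plus a scale.

```python
import torch

def fake_quantize_absmax(t: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    """Hypothetical helper: round to the int8 grid but keep the float dtype."""
    q_max = 2 ** (n_bits - 1) - 1                  # 127 for 8 bits
    scale = t.abs().max().clamp(min=1e-5) / q_max
    return (t / scale).round() * scale

def real_quantize_absmax(t: torch.Tensor, n_bits: int = 8):
    """Hypothetical helper: return an actual int8 tensor and its scale."""
    q_max = 2 ** (n_bits - 1) - 1
    scale = t.abs().max().clamp(min=1e-5) / q_max
    q = (t / scale).round().clamp(-q_max, q_max).to(torch.int8)
    return q, scale

torch.manual_seed(0)
w = torch.randn(4, 8)

w_fake = fake_quantize_absmax(w)
w_int8, w_scale = real_quantize_absmax(w)

print(w_fake.dtype)   # still a floating-point tensor
print(w_int8.dtype)   # an integer tensor

# Dequantizing the int8 tensor reproduces the fake-quantized values.
w_dequant = w_int8.float() * w_scale
```

Storing int8 tensors alone does not make inference faster or smaller in practice; the matrix multiplications also need kernels that operate on int8 inputs, which this sketch does not cover.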