microsoft / TransformerCompression

For releasing code related to compression methods for transformers, accompanying our publications
MIT License

QuaRot: cascade into quarot main #143

Closed. nailimixaM closed this 2 months ago.

nailimixaM commented 4 months ago

Tested PPL for A16W4, with and without rotation, on Llama-2 7B and Llama-2 13B.

Key code differences from fake_quant: a GPU implementation of the grid search for weight clipping, which halves the time to apply weight RTN on the 7B (30s) and 13B (1min) models compared to spcl's fake_quant; however, it is untested on 70B and may cause GPU OOM. The previous commit has the equivalent, more memory-efficient fake_quant method. What is very strange is that that method takes 10 minutes on 7B here, whereas it takes only 1 minute in the spcl repo.
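For context, here is a minimal sketch of the kind of batched, on-GPU grid search over clipping ratios described above; the function names, the candidate-ratio grid, and the L2 error criterion are illustrative assumptions rather than the code in this branch:

```python
import torch


def rtn_quantize(w, scale, n_bits=4):
    """Symmetric round-to-nearest fake-quantization with a per-row scale."""
    qmax = 2 ** (n_bits - 1) - 1
    return torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale


def clip_search(w, n_bits=4, n_grid=50):
    """Grid-search a per-output-channel clipping ratio that minimises the L2
    error between each weight row and its fake-quantized version. Every
    candidate is evaluated for all rows at once, so the whole search stays on
    the GPU (hypothetical sketch, not the repo's implementation)."""
    qmax = 2 ** (n_bits - 1) - 1
    w_absmax = w.abs().amax(dim=1, keepdim=True)                  # [out, 1]
    best_err = torch.full((w.shape[0],), float("inf"), device=w.device)
    best_scale = w_absmax / qmax
    for ratio in torch.linspace(0.5, 1.0, n_grid, device=w.device):
        scale = (ratio * w_absmax) / qmax
        err = (rtn_quantize(w, scale, n_bits) - w).pow(2).sum(dim=1)
        better = err < best_err
        best_err = torch.where(better, err, best_err)
        best_scale = torch.where(better.unsqueeze(1), scale, best_scale)
    return best_scale
```

Because each candidate ratio is scored with a handful of batched tensor ops over the whole weight matrix, the Python loop only runs for the length of the grid and the heavy work stays on the GPU, which is where the speed-up over a per-group CPU search would come from.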

@sashkboos could you have a look at this PR/branch (e.g. at the "add weight clipping" commit) and see if you can spot anything in the RTN logic that could be causing these accuracy differences? The other possibility is that there is a bug in the "no rotation" logic, and that we're actually doing things to the network that we're not meant to be doing, causing the big differences here.

sashkboos commented 4 months ago

Thanks @nailimixaM

Before diving into the performance issue, I would like to understand the difference between these numbers and the results in the paper. Do you follow the same arguments (including clipping) as the paper experiments?

nailimixaM commented 4 months ago

> Thanks @nailimixaM
>
> Before diving into the performance issue, I would like to understand the difference between these numbers and the results in the paper. Do you follow the same arguments (including clipping) as the paper experiments?

Yes: symmetric, per-channel, and weight clipping are all set to true.
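For reference, "symmetric, per-channel" here means one scale per output channel, derived from that row's (possibly clipped) absolute maximum, with no zero point; roughly:

$$
s_i = \frac{\alpha_i \max_j |W_{ij}|}{2^{b-1}-1},
\qquad
\hat{W}_{ij} = s_i \,\operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{W_{ij}}{s_i}\right),\,-2^{b-1},\,2^{b-1}-1\right),
$$

where $b$ is the bit width (4 here) and $\alpha_i \in (0, 1]$ is the per-channel clipping ratio chosen by the grid search.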

nailimixaM commented 4 months ago

With bugfix in clamping:

Llama-2 7B

  • with rotation: 6.83 vs 6.76 in paper
  • without rotation: 8.07 vs 6.99 in paper

Llama-2 13B

  • with rotation: 5.59 vs 5.48 in paper
  • without rotation: 13.00 vs 6.32 in paper [still very anomalous]

I suspect the discrepancy in the no-rotation ppls comes from the fact that I was evaluating quantization on a fused-but-not-rotated model here, whereas in fake_quant this is done on the raw Llama model. The QuaRot MLP and attention modules I have written assume that the model has fused layernorms. Emulating quantization without rotation would require some work: either adding new modules, or modifying the QuaRot ones with a switch specifying whether the layernorms have been fused or not.
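For anyone unfamiliar with the fusion step, here is a rough sketch of what "fused layernorms" means, assuming an RMSNorm-style norm whose learned scale can be folded into the linear layers that consume its output; the helper below is illustrative, not the actual API in this repo:

```python
import torch


@torch.no_grad()
def fuse_rmsnorm_scale(norm, linears):
    """Fold an RMSNorm's learned per-feature scale gamma into the input
    dimension of the linear layers that consume its output, so the norm
    itself becomes scale-free:

        Linear(RMSNorm(x) * gamma) == Linear'(RMSNorm(x)),  with  W' = W * gamma
    """
    for linear in linears:
        # linear.weight has shape [out, hidden]; gamma broadcasts over rows.
        linear.weight.mul_(norm.weight.to(linear.weight.dtype))
    norm.weight.fill_(1.0)
```

After fusion, the scales that used to live in the norm multiply the weight columns instead, which changes the weight distribution that RTN sees on the unrotated model.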

sashkboos commented 4 months ago

> I suspect the discrepancy in the no-rotation ppls comes from the fact that I was evaluating quantization on a fused-but-not-rotated model here, whereas in fake_quant this is done on the raw Llama model. The QuaRot MLP and attention modules I have written assume that the model has fused layernorms. Emulating quantization without rotation would require some work: either adding new modules, or modifying the QuaRot ones with a switch specifying whether the layernorms have been fused or not.

Makes sense. So that's why you got higher ppl values: your fused model has a worse weight distribution :-)
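To illustrate the distribution point: the rotation multiplies the fused weights by an orthogonal matrix whose inverse is absorbed by the neighbouring layer, which spreads outliers across channels and shrinks the ranges RTN has to cover. A toy demonstration with a random orthogonal matrix (QuaRot itself uses Hadamard-based rotations; the sizes and outlier pattern below are made up):

```python
import torch

torch.manual_seed(0)

# Toy weight matrix with a few heavy outlier columns, mimicking the kind of
# per-channel distribution a fused-but-unrotated model can end up with.
w = torch.randn(1024, 1024)
w[:, :8] *= 20.0

# Random orthogonal matrix via QR; any orthogonal Q preserves the network's
# function once Q^-1 is folded into the adjacent layer.
q, _ = torch.linalg.qr(torch.randn(1024, 1024))
w_rot = w @ q

print("max |w| before rotation:", w.abs().max().item())      # dominated by outliers
print("max |w| after rotation: ", w_rot.abs().max().item())  # much smaller range
```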

nailimixaM commented 4 months ago

Fixed quantization without rotation. A16W4 results for Llama-2 7B and Llama-2 13B have been updated.

The ~0.1 ppl discrepancies with rotation have me a bit stumped, but at this point I think we can crack on with KV-cache and activation quantization to get end-to-end RTN working. If large discrepancies appear there, it will be worth investigating further.
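Since activation and KV-cache quantization are the next step, for context this is roughly the shape of a per-token symmetric RTN fake-quantizer for activations; a sketch under assumed conventions (per-token scales, symmetric, no zero point), not the quantizer this repo ships:

```python
import torch


def fake_quant_activations(x, n_bits=8):
    """Symmetric per-token RTN fake-quantization: one scale per token
    (i.e. per slice over the last dimension), no zero point."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale
```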

@jameshensman @pashminacameron @msdmkats I'm setting this ready to review now, see above for latest ppls.