
Support QuaRot quantization scheme #6444

EwoutH opened this issue 7 months ago

EwoutH commented 7 months ago

A new, interesting quantization scheme was published, which not only reduces memory consumption (like current quantization schemes) but also reduces computation.

QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs

We introduce QuaRot, a new Quantization scheme based on Rotations, which is able to quantize LLMs end-to-end, including all weights, activations, and KV cache in 4 bits. QuaRot rotates LLMs in a way that removes outliers from the hidden state without changing the output, making quantization easier. This computational invariance is applied to the hidden state (residual) of the LLM, as well as to the activations of the feed-forward components, aspects of the attention mechanism and to the KV cache. The result is a quantized model where all matrix multiplications are performed in 4-bits, without any channels identified for retention in higher precision. Our quantized LLaMa2-70B model has losses of at most 0.29 WikiText-2 perplexity and retains 99% of the zero-shot performance. Code is available at: this https URL.

I think it would be interesting to see if this technique, or parts of it, could be adopted in llama.cpp, to speed up inference of quantized models.

ggerganov commented 7 months ago

It's an interesting approach that we should explore. As far as I understood, the model weights are pre-processed (rotated) and then the inference is augmented with extra (two?) operations to restore the effects of the rotation. We can start by implementing the latter and try to use existing QuaRot models to evaluate.
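
For intuition, here is a toy sketch of that invariance (a hypothetical single linear layer with illustrative names, not llama.cpp or QuaRot code): an orthogonal rotation fused into the weight offline, plus the matching rotation applied to the activations online, reproduces the original output.

```python
# Toy illustration: rotating the weights offline by an orthogonal Q and the
# activations online by the same Q leaves the layer output unchanged.
import torch

torch.set_default_dtype(torch.float64)
torch.manual_seed(0)
d_in, d_out, batch = 8, 4, 3

W = torch.randn(d_out, d_in)   # original weight
x = torch.randn(batch, d_in)   # original activations

# Random orthogonal Q via QR (QuaRot uses a randomized Hadamard matrix instead).
Q, _ = torch.linalg.qr(torch.randn(d_in, d_in))

W_rot = W @ Q                  # pre-processed offline, stored in the checkpoint
x_rot = x @ Q                  # the extra online operation at inference time

y_ref = x @ W.T
y_rot = x_rot @ W_rot.T        # (xQ)(WQ)^T = x Q Q^T W^T = x W^T

print(torch.allclose(y_ref, y_rot))  # True: the rotation is "invisible" in the output
```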

PinkysBrain commented 7 months ago

Similar to "Training Transformers with 4-bit Integers", except that work only used Hadamard transforms. Hadamard alone might be enough: QuaRot did an ablation for Q alone, but not for Hadamard alone.

sashkboos commented 7 months ago

@ggerganov

Thanks for your interest in our work. I am the main author of QuaRot.

I would be happy to discuss/plan for this and help to integrate it into the repo. I think the general steps of QuaRot are:

  1. First, you need to convert your model into an RMSNorm model. This means that you fuse all operations other than the normalization itself (dividing by the norm) into the previous and next linear weights, and use this class as the normalization module.
  2. Then, you need to fuse the Hadamard transformations into the weights. We use a single random Hadamard matrix for this, which is generated by this function. Note that this is not a Walsh-Hadamard matrix (for which a fast kernel exists) but a randomized Hadamard matrix. However, as it is fused into the weight matrices, it adds no overhead (so you can fuse any other orthogonal matrix you like). You fuse the first one into the embedding module and the inverse of the last one into the lm_head (a sketch of this fusion follows the list). Note that you should use a single matrix for the whole network (otherwise you have to rotate the shortcuts, as is the case in SliceGPT).
  3. With the above steps, the inputs of all attention and MLP blocks will be ready for quantization, but not the activations that feed out_proj and down_proj (as there is a non-linearity in the middle). For the down_proj layer, we use a single online Hadamard transformation. If the dimension is a power of two, we use Dao's great repo for that transformation. Otherwise (which is the case in LLaMa-2 models), we have to use the Kronecker construction of the Hadamard transformation (check the paper for details; a sketch of this Kronecker trick follows below). This part is implemented here. We should thank the QuIP# paper for the great work that inspired us here.
  4. Finally, for the out_proj layer, we fuse the randomized Hadamard transformation into v_proj (code is here). However, we do it per head, as the MatMul in the attention is done over each head. We then apply an online Hadamard transformation across the number of heads just before out_proj to build the Kronecker construction again (again, if the number of heads is a power of two, you can do it fast using Dao's kernel; otherwise, you have to multiply by a Hadamard matrix of size num_heads online).
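
As referenced in step 2, here is a rough sketch of fusing a single shared rotation into the embedding and lm_head (assumptions: simplified shapes and illustrative names, with a QR-based orthogonal matrix standing in for the randomized Hadamard matrix):

```python
# Rough sketch of steps 1-2: fuse ONE shared rotation into the embedding table
# and its inverse into lm_head, so hidden states live in the rotated basis while
# the network output stays identical.
import torch

torch.set_default_dtype(torch.float64)

def random_orthogonal(d: int) -> torch.Tensor:
    # Any fixed orthogonal matrix works here once it is fused into the weights.
    q, _ = torch.linalg.qr(torch.randn(d, d))
    return q

d_model, vocab = 16, 100
embed = torch.randn(vocab, d_model)    # token embedding table
lm_head = torch.randn(vocab, d_model)  # output projection

Q = random_orthogonal(d_model)         # a single matrix shared across the network

embed_rot = embed @ Q                  # hidden states now start in the rotated basis
lm_head_rot = lm_head @ Q              # absorbs the inverse rotation (Q^-1 = Q^T)

h = embed[3]                           # some hidden state in the original basis
h_rot = embed_rot[3]                   # the same state in the rotated basis
print(torch.allclose(lm_head @ h, lm_head_rot @ h_rot))  # True: logits unchanged
```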

With the above steps, you can quantize the model easily. Optionally, you can apply another rotation for quantizing the Keys in the attention module (please check the Method section in our paper).
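
As referenced in step 3, here is a minimal sketch of the Kronecker trick for applying a large Hadamard-style transform online without materializing it (assumptions: illustrative sizes, with SciPy's Sylvester matrices standing in for QuaRot's special-size Hadamard factors):

```python
# Minimal sketch of the Kronecker trick: apply H_a (x) H_b to a length a*b vector
# with two small multiplications instead of building the full matrix.
import numpy as np
from scipy.linalg import hadamard  # Sylvester (power-of-two) Hadamard matrices

a, b = 4, 8
H_a = hadamard(a) / np.sqrt(a)     # small dense factor (a special-size matrix in QuaRot)
H_b = hadamard(b) / np.sqrt(b)     # power-of-two factor (fast kernel in practice)

x = np.random.randn(a * b)

# Reference: build the full (a*b x a*b) transform explicitly.
y_ref = np.kron(H_a, H_b) @ x

# Kronecker trick: reshape to (a, b) and multiply each side by the small factors.
y = (H_a @ x.reshape(a, b) @ H_b.T).reshape(-1)

print(np.allclose(y_ref, y))       # True
```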

Please let me know if you need any help/support from our side through my email: saleh.ashkboos@inf.ethz.ch

sorasoras commented 7 months ago

@JohannesGaessler I remember you have "int8 tensor core matrix multiplication" (https://github.com/ggerganov/llama.cpp/pull/4801). It might be useful for this quantization scheme.

JohannesGaessler commented 7 months ago

Before I invest time into a specialized kernel for a given quantization method I would like to first see evidence that it's better than those methods currently on master.

sorasoras commented 7 months ago

Before I invest time into a specialized kernel for a given quantization method I would like to first see evidence that it's better than those methods currently on master.

I don't think we have any quant that could take advantage of INT8/INT4 tensor cores yet. QuaRot builds quants that specifically run at INT4/INT6/INT8 computationally.

We got quite a bit of a speed boost:

QuaRot: a method which uses Hadamard matrices to eliminate outliers in the activations and KV cache of pre-trained LLMs, enabling end-to-end 4-bit quantization for the first time (to the best of our knowledge). Quantizing LLAMA2-70B to 4 bits with QuaRot maintains 99% of the downstream task performance of the FP16 baseline, with a 2.16x speedup on RTX 3090 GPUs during the prefill stage (and up to 3.39x memory saving during the decoding stage). Quantizing all LLAMA-2 models to 6- and 8-bits is lossless.

Anyway, I just wanted to let you know.

ddh0 commented 7 months ago

Quantizing all LLAMA-2 models to 6- and 8-bits is lossless

😮‍💨

PinkysBrain commented 6 months ago

QuIP# seems interesting in and of itself. Weights are dequantized before use, so it can't use 4-bit math, which is the main attraction of QuaRot, but it does perform better.

Also, it doesn't try to fold and preserve the rotated space whenever possible. For inference it's just transform → matmul → detransform. QuaRot is really elegant, but hard to follow in comparison. The QuIP# way will be easier to drop into the existing code.
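
For concreteness, a schematic of that transform → matmul → detransform pattern (assumptions: plain QR-based orthogonal matrices and a naive round-to-grid quantizer stand in for QuIP#'s Hadamard transforms and lattice codebook; names are illustrative):

```python
# Schematic of transform -> matmul -> detransform around a quantized weight.
import torch

torch.set_default_dtype(torch.float64)
torch.manual_seed(0)
d_in, d_out = 64, 32

W = torch.randn(d_out, d_in)
U, _ = torch.linalg.qr(torch.randn(d_out, d_out))  # output-side transform
V, _ = torch.linalg.qr(torch.randn(d_in, d_in))    # input-side transform

# Offline: rotate the weight so its entries are "incoherent", then quantize it.
W_tilde = U @ W @ V.T
scale = W_tilde.abs().max() / 7
W_q = (W_tilde / scale).round().clamp(-8, 7)       # naive symmetric 4-bit grid

# Online: transform the input, matmul, detransform the output. The weight is
# dequantized (W_q * scale) before the matmul -- hence no 4-bit math here.
x = torch.randn(d_in)
y = U.T @ ((W_q * scale) @ (V @ x))                # ~= W @ x up to quantization error

print((y - W @ x).abs().mean())                    # small residual error
```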

github-actions[bot] commented 4 months ago

This issue was closed because it has been inactive for 14 days since being marked as stale.

EwoutH commented 4 months ago

Can we reopen this?

github-actions[bot] commented 3 months ago

This issue was closed because it has been inactive for 14 days since being marked as stale.

liuzechun commented 2 months ago

Another relevant work you might be interested in.

SpinQuant: LLM Quantization with Learned Rotations (https://arxiv.org/pdf/2405.16406)

In this work, we identify a collection of applicable rotation parameterizations that lead to identical outputs in full-precision Transformer architectures, and find that some random rotations lead to much better quantization than others, with an up to 13 points difference in downstream zeroshot reasoning performance. As a result, we propose SpinQuant that optimizes (or learns) the rotation matrices with Cayley optimization on a small validation set. With 4-bit quantization of weight, activation, and KV-cache, SpinQuant narrows the accuracy gap on zero-shot reasoning tasks with full precision to merely 2.9 points on the LLaMA-2 7B model, surpassing LLM-QAT by 19.1 points and SmoothQuant by 25.0 points. SpinQuant also outperforms concurrent work QuaRot, which applies random rotations to remove outliers. In particular, for LLaMA-2 7B/LLaMA-3 8B models that are hard to quantize, SpinQuant reduces the gap to full precision by 30.2%/34.1% relative to QuaRot.
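
For context, here is a minimal sketch of the Cayley parameterization that such rotation learning can rest on (illustrative only, not SpinQuant's actual optimizer or objective):

```python
# Minimal sketch of learning a rotation via the Cayley transform.
import torch

torch.set_default_dtype(torch.float64)
d = 8
M = torch.randn(d, d, requires_grad=True)     # unconstrained rotation parameters

def cayley(M: torch.Tensor) -> torch.Tensor:
    # Any skew-symmetric A yields an orthogonal R = (I + A)^-1 (I - A).
    A = M - M.T
    I = torch.eye(d)
    return torch.linalg.solve(I + A, I - A)

R = cayley(M)
print(torch.allclose(R @ R.T, torch.eye(d)))  # True: R is orthogonal

# R is differentiable in M, so a quantization-error proxy measured after rotating
# the weights can be backpropagated to improve the rotation.
W = torch.randn(d, d)
loss = ((W @ R) - (W @ R).round()).pow(2).sum()   # toy round-to-integer error
loss.backward()
print(M.grad.norm() > 0)                          # gradient flows into the rotation
```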

This work learns the rotation matrices and achieves even better results than QuaRot with fewer online Hadamard rotations. 😝

If you're interested we can chat and we can provide support if needed!

github-actions[bot] commented 4 weeks ago

This issue was closed because it has been inactive for 14 days since being marked as stale.

AndrewNLauder commented 2 weeks ago

Is llama.cpp planning/able to support SpinQuant? According to Meta, SpinQuant + QLoRA are enabling really great things, and it would be great not to have to use Meta's llama-stack to take advantage of them. https://ai.meta.com/blog/meta-llama-quantized-lightweight-models/