axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

QA_LORA #690

Open thistleknot opened 11 months ago

thistleknot commented 11 months ago

⚠️ Please check that this feature request hasn't been suggested before.

🔖 Feature description

Quantization-aware LoRA training (QA-LoRA), which improves on standard QLoRA.

https://huggingface.co/papers/2309.14717

Recent years have witnessed a rapid development of large language models (LLMs). Despite the strong ability in many language-understanding tasks, the heavy computational burden largely restricts the application of LLMs, especially when one needs to deploy them onto edge devices. In this paper, we propose a quantization-aware low-rank adaptation (QA-LoRA) algorithm. The motivation lies in the imbalanced degrees of freedom of quantization and adaptation, and the solution is to use group-wise operators which increase the degree of freedom of quantization while decreasing that of adaptation. QA-LoRA is easily implemented with a few lines of code, and it equips the original LoRA with two-fold abilities: (i) during fine-tuning, the LLM's weights are quantized (e.g., into INT4) to reduce time and memory usage; (ii) after fine-tuning, the LLM and auxiliary weights are naturally integrated into a quantized model without loss of accuracy. We apply QA-LoRA to the LLaMA and LLaMA2 model families and validate its effectiveness in different fine-tuning datasets and downstream scenarios. Code will be made available at https://github.com/yuhuixu1993/qa-lora.
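The "degrees of freedom" trade-off can be made concrete by counting parameters for a single linear layer. This is a rough sketch under assumptions not taken from the paper: hypothetical layer sizes (4096 x 4096, rank 16, group size 32), one (scale, zero) pair per quantization group per output column, and a single-group (L = 1) layout as the comparison baseline. It is not a description of how QLoRA itself quantizes.

```python
def qalora_param_counts(D_in, D_out, D_int, group_size):
    """Per-layer counts of LoRA weights and quantization factors (illustrative only)."""
    if D_in % group_size != 0:
        raise ValueError("D_in must be divisible by group_size")
    L = D_in // group_size  # number of quantization groups along D_in
    return {
        # adaptation side: QA-LoRA's lora_A has one column per group, not per input feature
        "lora_A_lora": D_int * D_in,
        "lora_A_qalora": D_int * L,
        "lora_B": D_out * D_int,  # identical in both methods
        # quantization side (assumed layout): one (scale, zero) pair per group per output column
        "quant_factors_L1": 2 * D_out,  # hypothetical baseline with a single group
        "quant_factors_grouped": 2 * L * D_out,
    }


counts = qalora_param_counts(D_in=4096, D_out=4096, D_int=16, group_size=32)
for name, n in counts.items():
    print(f"{name:>22}: {n:,}")
```

With these sizes, lora_A shrinks from 65,536 to 2,048 entries while the quantization factors grow from 8,192 to 1,048,576: fewer free parameters on the adaptation side, more on the quantization side.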

✔️ Solution

Algorithm 1 QA-LoRA Pseudocode in the PyTorch-like style
import torch
from torch import nn

# D_in, D_out, D_int: the input, output, and low-rank adaptation dimensions
# L: the quantization group numbers of weights W (D_in // L is the group size)
# s: the coefficient for adaptation; N: the bit width of quantization
# alpha, beta: group-wise scale and zero factors of W (one per group per output column)
QA = nn.AvgPool1d(D_in // L)
lora_A = nn.Parameter(torch.empty((D_int, L)))
lora_B = nn.Parameter(torch.empty((D_out, D_int)))

def qalora_forward(x, W, lora_A, lora_B):
    W_tilde = pre_quantization(W, alpha, beta)
    result = x @ W_tilde
    # QA averages x within each group; multiplying by the group size turns it into a group sum
    result += (QA(x) * (D_in // L)) @ lora_A.transpose(0, 1) @ lora_B.transpose(0, 1) * s
    return result

def pre_quantization(W, alpha, beta):
    # quantize to integers W_hat, then dequantize back to W_tilde
    W_hat = torch.round(W / alpha) + beta
    return alpha * (W_hat - beta)

def merge_with_quantization(beta, lora_A, lora_B):
    # fold the adapter into the zero factors; the integer weights W_hat stay unchanged
    beta_new = beta - s * (lora_B @ lora_A).transpose(0, 1) / alpha
    return beta_new
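The pseudocode leaves the quantizer and the tensor shapes implicit. This NumPy sketch checks the property claimed for the merge step: after folding the adapter into the zero factors, the unchanged integer weights reproduce the adapted forward pass. Several choices here are assumptions, not taken from the paper: toy sizes, a min-max asymmetric quantizer with clamping, alpha/beta of shape (L, D_out), and group pooling written as reshape + mean in place of nn.AvgPool1d. NumPy stands in for PyTorch only to keep the check lightweight.

```python
import numpy as np

# Hypothetical toy sizes chosen for the check
D_in, D_out, D_int = 8, 3, 2
L = 4                  # number of quantization groups along D_in
G = D_in // L          # group size
N = 4                  # bit width
s = 0.5                # adaptation coefficient

rng = np.random.default_rng(0)
W = rng.normal(size=(D_in, D_out))
lora_A = rng.normal(size=(D_int, L))
lora_B = rng.normal(size=(D_out, D_int))
x = rng.normal(size=(5, D_in))


def groupwise_minmax(W, L, N):
    """Assumed min-max asymmetric quantizer: one (alpha, beta) per group per output column."""
    Wg = W.reshape(L, -1, W.shape[1])        # (L, G, D_out), consecutive rows form a group
    lo, hi = Wg.min(axis=1), Wg.max(axis=1)  # (L, D_out)
    alpha = (hi - lo) / (2**N - 1)
    beta = np.round(-lo / alpha)             # integer zero point
    return alpha, beta


def expand(p, G):
    """Repeat each group's row G times so it lines up with the rows of W."""
    return np.repeat(p, G, axis=0)           # (L, D_out) -> (D_in, D_out)


def quantize(W, alpha, beta, G, N):
    W_hat = np.round(W / expand(alpha, G)) + expand(beta, G)
    return np.clip(W_hat, 0, 2**N - 1)       # clamp to the N-bit range (assumption)


def dequantize(W_hat, alpha, beta, G):
    return expand(alpha, G) * (W_hat - expand(beta, G))


def qalora_forward(x, W_hat, alpha, beta):
    result = x @ dequantize(W_hat, alpha, beta, G)
    pooled = x.reshape(x.shape[0], L, G).mean(axis=2) * G  # group mean * G = group sum
    return result + pooled @ lora_A.T @ lora_B.T * s


def merge_with_quantization(beta, alpha):
    return beta - s * (lora_B @ lora_A).T / alpha


alpha, beta = groupwise_minmax(W, L, N)
W_hat = quantize(W, alpha, beta, G, N)
beta_new = merge_with_quantization(beta, alpha)
merged_out = x @ dequantize(W_hat, alpha, beta_new, G)
adapted_out = qalora_forward(x, W_hat, alpha, beta)
print(np.allclose(merged_out, adapted_out))  # prints True
```

The check works because every input feature in a group shares one lora_A column, so the adapter's contribution is constant within each group and can be absorbed by that group's zero factor. Note that beta_new is generally no longer an integer, so in this sketch the merged zero factors are kept in floating point; only W_hat stays at N bits.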

❓ Alternatives

No response

📝 Additional Context

No response


thistleknot commented 10 months ago

https://www.youtube.com/watch?v=Nmtc_4nIww0

thistleknot commented 10 months ago

https://web.archive.org/web/20231013080809/https://github.com/yuhuixu1993/qa-lora

thistleknot commented 10 months ago

https://github.com/eltociear/qa-lora