artidoro / qlora

QLoRA: Efficient Finetuning of Quantized LLMs
https://arxiv.org/abs/2305.14314
MIT License
9.96k stars, 820 forks

Should base model be dequantized when merging LoRA weights with base model? #254

Open jinyongyoo opened 1 year ago

jinyongyoo commented 1 year ago

Hi, I have a question about merging LoRA weights with a quantized base model. When we want to merge the LoRA weights back into the original model for inference, we can use the merge_and_unload method. However, this isn't possible when the base model is quantized (as seen in #28).

So a common workaround I've seen is to load the base model without quantization and then merge the LoRA weights. But shouldn't this create a training/inference mismatch, since the LoRA weights were trained against the quantized model, which differs from the unquantized base model (quantization is inherently lossy)? I was wondering whether such a workaround would degrade the performance of the final model.

Another workaround I can think of is to dequantize the quantized base model and then add the LoRA weights. This would get rid of the training/inference mismatch. Has there been any attempt to dequantize the base model and add the LoRA weights?
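To make the mismatch concrete, here is a toy sketch of the two merge paths. It assumes a simple symmetric rounding quantizer, which is not the NF4 scheme QLoRA uses, and every number in it is invented for illustration. The point is only that merging into the original weights differs from what training saw by exactly the quantization error.

```python
def quantize(weights, max_code=7):
    """Toy symmetric quantizer: scale by absmax, round to integer codes."""
    scale = max(abs(w) for w in weights) / max_code
    codes = [round(w / scale) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Map integer codes back to floats; the rounding error is not recovered."""
    return [c * scale for c in codes]


# Hypothetical base weights and LoRA update (standing in for W and B @ A).
base = [0.70, -0.31, 0.12, 0.04]
lora_delta = [0.02, 0.01, -0.03, 0.04]

codes, scale = quantize(base)
seen_in_training = dequantize(codes, scale)  # what the adapter was trained against

# Workaround 1: merge into the unquantized base weights.
merged_full_precision = [w + d for w, d in zip(base, lora_delta)]
# Workaround 2: merge into the dequantized weights.
merged_dequantized = [w + d for w, d in zip(seen_in_training, lora_delta)]

# The gap between the two is the per-weight quantization error.
mismatch = [abs(a - b) for a, b in zip(merged_full_precision, merged_dequantized)]
print(mismatch)
```

Whether a gap of this kind hurts the final model in practice is an empirical question that this sketch does not answer.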

eugene-yh commented 1 year ago

There is a mathematical hack to dequantize the base model. See my post here: https://github.com/artidoro/qlora/issues/28#issuecomment-1691551954

jinyongyoo commented 1 year ago

Thanks! I ended up using the dequantize_4bit function from bnb to dequantize the linear weights, but I think this approach is neat.

ChrisHayduk commented 1 year ago

@jinyongyoo Would you mind sharing the code that you used to dequantize the model? How did you apply dequantize_4bit?

jinyongyoo commented 1 year ago

Not sure if this is the 100% correct way to do it.

dequantize_4bit(module.weight.data, quant_state=module.weight.quant_state), where module is an instance of bnb.nn.Linear4bit. That should give you the weight, which you can use to create a torch.nn.Linear.

ChrisHayduk commented 1 year ago

@jinyongyoo Awesome, thank you! And you just looped through every module of the model, checked whether it was of type bnb.nn.Linear4bit, and, if it was, replaced that module with the dequantized version?

jinyongyoo commented 1 year ago

yes
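The loop described in this thread might look roughly like the sketch below. The helper names (swap_modules, dequantize_linear4bit, dequantize_model) are hypothetical and not taken from the poster's code. It assumes the model was loaded in 4-bit on a GPU and has no adapter attached yet. It has not been checked against every bitsandbytes or PEFT version.

```python
from typing import Any, Callable


def swap_modules(root: Any,
                 should_swap: Callable[[Any], bool],
                 convert: Callable[[Any], Any]) -> int:
    """Recursively replace matching child modules in place; return the count."""
    swapped = 0
    # Materialize the children first so setattr does not disturb the iteration.
    for name, child in list(root.named_children()):
        if should_swap(child):
            setattr(root, name, convert(child))
            swapped += 1
        else:
            swapped += swap_modules(child, should_swap, convert)
    return swapped


def dequantize_linear4bit(module: Any) -> Any:
    """Build a plain torch.nn.Linear from a bnb.nn.Linear4bit (sketch)."""
    # Imported lazily so the traversal above can be exercised without a GPU.
    import torch
    from bitsandbytes.functional import dequantize_4bit

    weight = dequantize_4bit(module.weight.data,
                             quant_state=module.weight.quant_state)
    linear = torch.nn.Linear(module.in_features, module.out_features,
                             bias=module.bias is not None,
                             device=weight.device, dtype=weight.dtype)
    linear.weight = torch.nn.Parameter(weight, requires_grad=False)
    if module.bias is not None:
        linear.bias = torch.nn.Parameter(module.bias.data.to(weight.dtype),
                                         requires_grad=False)
    return linear


def dequantize_model(model: Any) -> int:
    """Swap every bnb.nn.Linear4bit in `model` for a dequantized torch.nn.Linear."""
    import bitsandbytes as bnb

    return swap_modules(model,
                        lambda m: isinstance(m, bnb.nn.Linear4bit),
                        dequantize_linear4bit)
```

After dequantize_model(model) runs, the LoRA adapter could presumably be loaded onto the result and merged with merge_and_unload, since the layers are then ordinary torch.nn.Linear modules. It is worth comparing outputs before and after on a few inputs as a sanity check.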