johnsmith0031 / alpaca_lora_4bit

MIT License

Finetuning 2-bit Quantized Models #115

Open kuleshov opened 1 year ago

kuleshov commented 1 year ago

Hey @johnsmith0031, thank you for this great repo! I was wondering whether you have tried implementing the backward pass for 2-bit or 3-bit quantized models?

I would really like to try it as an experiment. If you have any existing work on 2-bit or 3-bit autograd, I would love to contribute to it and submit a PR to this repo. Or as a first step, I could run it and share experimental results.
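For context on what the backward pass involves here: in a LoRA setup the quantized base weights are frozen, so the quantized layer only has to propagate the gradient to its input, grad_x = grad_y · Wᵀ, with W dequantized on the fly. The pure-Python sketch below assumes a generic asymmetric scheme w = scale * (q - zero) with a single scale and zero point; all function names are hypothetical, not this repo's API. Under that view, 2-bit or 3-bit support mostly comes down to a `dequantize` that understands the new packing.

```python
from typing import List, Sequence

Matrix = List[List[float]]

def dequantize(q: Sequence[Sequence[int]], scale: float, zero: int) -> Matrix:
    # Generic asymmetric dequantization (assumed): w = scale * (q - zero).
    return [[scale * (v - zero) for v in row] for row in q]

def matmul(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> Matrix:
    # Plain matrix product; len(a[0]) must equal len(b).
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m: Sequence[Sequence[float]]) -> Matrix:
    return [list(col) for col in zip(*m)]

def quant_linear_forward(x, q, scale, zero) -> Matrix:
    # y = x @ W, where W (in_features x out_features) is rebuilt from integer codes.
    return matmul(x, dequantize(q, scale, zero))

def quant_linear_backward(grad_y, q, scale, zero) -> Matrix:
    # Base weights are frozen, so only grad_x = grad_y @ W^T is needed;
    # no gradient is computed for the packed codes, scale, or zero.
    return matmul(grad_y, transpose(dequantize(q, scale, zero)))

# Usage: x is (1 x 2), W is (2 x 2).
q = [[3, 0], [1, 2]]
print(quant_linear_forward([[2.0, 4.0]], q, 0.5, 1))   # [[2.0, 1.0]]
print(quant_linear_backward([[1.0, 1.0]], q, 0.5, 1))  # [[0.5, 0.5]]
```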

johnsmith0031 commented 1 year ago

I think we'd better wait until someone succeeds in quantizing a model to 2 bits without much performance loss (like QLoRA, though I'm not sure about its performance).

kuleshov commented 1 year ago

So, actually, we know how to do it in two bits! We're a team of researchers at Cornell and we have working prototypes of two-bit compression that achieves good perplexity at inference time. We would like to now explore finetuning and your codebase is very helpful to us.

There are actually two ways of doing it: one is a new algorithm that we will soon put on arXiv, but even vanilla GPTQ performs somewhat acceptably at 2 bits on the largest LLMs (see Table 7 in their paper).

Would you be interested in talking more about these experiments and exchanging code or ideas?

johnsmith0031 commented 1 year ago

Thanks for showing interest in my code! I've added 2-bit reconstruction functions to the CUDA kernel in another branch. Feel free to adjust the code as needed: https://github.com/johnsmith0031/alpaca_lora_4bit/tree/2bit_support
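A 2-bit layout is simpler than 3-bit because 16 two-bit codes fit exactly in one 32-bit word. The sketch below assumes code k sits at bits 2k..2k+1 and uses the generic dequantization w = scale * (q - zero); I haven't checked that this matches the bit order in the `2bit_support` branch, and the helper names are hypothetical.

```python
from typing import List, Sequence

def pack_2bit(codes: Sequence[int]) -> int:
    """Pack 16 unsigned 2-bit codes into one 32-bit word; code k goes to bits 2k..2k+1."""
    assert len(codes) == 16 and all(0 <= c < 4 for c in codes)
    word = 0
    for k, c in enumerate(codes):
        word |= c << (2 * k)
    return word

def unpack_2bit(word: int) -> List[int]:
    """Inverse of pack_2bit: shift each 2-bit field down and mask it."""
    return [(word >> (2 * k)) & 0x3 for k in range(16)]

def dequantize(codes: Sequence[int], scale: float, zero: int) -> List[float]:
    # Assumed asymmetric scheme: w = scale * (q - zero).
    return [scale * (c - zero) for c in codes]

codes = [0, 1, 2, 3] * 4
word = pack_2bit(codes)
assert unpack_2bit(word) == codes
print(hex(word))                              # 0xe4e4e4e4
print(dequantize([0, 1, 2, 3], 0.5, 2))       # [-1.0, -0.5, 0.0, 0.5]
```

A kernel would do the same shift-and-mask per code, just with one thread or vector lane per output element.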

NicoNico6 commented 1 year ago

Hello, I have been working on similar topics (2 bits and lower). However, I have noticed that the PPL calculated using the current main branch is consistently higher than the PPL from the original GPTQ-Triton. I'd like to understand the reason for this difference. Could you share any insight into whether it might be due to version mismatches or other factors? Any ideas or suggestions for investigating further would be appreciated. Thank you, @johnsmith0031
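One thing worth pinning down before comparing kernels is the metric itself. Perplexity is exp of the mean per-token negative log-likelihood, and two evaluation scripts can report different numbers from identical log-probs depending on how they aggregate evaluation windows. A minimal sketch (hypothetical helper names) contrasting pooled aggregation with averaging per-window perplexities:

```python
import math
from typing import Sequence

def perplexity(token_logprobs: Sequence[float]) -> float:
    """exp(mean negative log-likelihood) over the given tokens."""
    if not token_logprobs:
        raise ValueError("need at least one token")
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

def pooled_perplexity(windows: Sequence[Sequence[float]]) -> float:
    # Pool every token from every window, then take a single exp(mean NLL).
    return perplexity([lp for w in windows for lp in w])

def mean_window_perplexity(windows: Sequence[Sequence[float]]) -> float:
    # Average of per-window perplexities: not the same quantity in general.
    return sum(perplexity(w) for w in windows) / len(windows)

windows = [[math.log(0.5)] * 2, [math.log(0.25)] * 2]
print(pooled_perplexity(windows))       # about 2.828 (= sqrt(8))
print(mean_window_perplexity(windows))  # about 3.0
```

Checking that both branches use the same aggregation, context length, and dataset split rules those out before looking at the matmul path.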

kuleshov commented 1 year ago

@johnsmith0031 Thank you! The two-bit extension makes a lot of sense. I'm modifying it to support custom group sizes, and I can contribute the code back as a PR if you're interested.
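With custom group sizes, each run of `groupsize` consecutive input rows shares its own scale and zero point, so dequantization only needs the group index i // groupsize. A minimal sketch for one output column, assuming the generic w = scale * (q - zero) form; `dequantize_grouped` is a hypothetical name, not a function in this repo.

```python
from typing import List, Sequence

def dequantize_grouped(qcol: Sequence[int], scales: Sequence[float],
                       zeros: Sequence[int], groupsize: int) -> List[float]:
    """Dequantize one output column whose input rows are split into groups of `groupsize`."""
    n_groups = (len(qcol) + groupsize - 1) // groupsize  # last group may be partial
    assert len(scales) == n_groups and len(zeros) == n_groups
    out = []
    for i, q in enumerate(qcol):
        g = i // groupsize  # group index along the input dimension
        out.append(scales[g] * (q - zeros[g]))
    return out

print(dequantize_grouped([0, 1, 2, 3, 0, 1, 2, 3], [1.0, 0.5], [1, 2], groupsize=4))
# [-1.0, 0.0, 1.0, 2.0, -1.0, -0.5, 0.0, 0.5]
```

Smaller groups track the local weight range more closely, at the cost of storing more scales and zeros per column.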

The part that has been giving me more trouble was the 3-bit one. I'm not sure I understand the implementation well enough to figure out how to unpack and return the weight matrix in CUDA. If you happen to have played with that and you have code you could share, that would be helpful, but no worries if not.

johnsmith0031 commented 1 year ago

Replying to @NicoNico6 (PPL on the current main branch is consistently higher than with the original GPTQ-Triton):

I'm not sure of the reason, but it may come from the different matrix multiplication path. When batch_size * seq_len > 8, I first reconstruct the weight matrix in float16 and then use torch.matmul, which differs from both the original CUDA kernel and the Triton kernel.
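To illustrate why two such paths can disagree slightly, here is a toy sketch of a dispatch on batch_size * seq_len: above the threshold it materializes float16 weights and then multiplies; at or below it, it dequantizes on the fly in full precision and rounds once at the end. float16 is emulated with Python's `struct` 'e' format. The rounding points are assumptions chosen to show the effect, not a description of what the CUDA or Triton kernels actually do, and `quant_linear` is a hypothetical name.

```python
import struct
from typing import List, Sequence, Tuple

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def dequantize_fp16(qweight: Sequence[Sequence[int]], scale: float, zero: int) -> List[List[float]]:
    # Materialize the whole weight matrix in (emulated) float16.
    return [[to_fp16(scale * (q - zero)) for q in row] for row in qweight]

def quant_linear(x: Sequence[Sequence[float]], qweight: Sequence[Sequence[int]],
                 scale: float, zero: int, threshold: int = 8) -> Tuple[List[List[float]], str]:
    """x: (tokens, in_features); qweight: (in_features, out_features) integer codes."""
    n_tokens = len(x)  # plays the role of batch_size * seq_len
    out_features = len(qweight[0])
    if n_tokens > threshold:
        # Large batch: dequantize once to fp16, then an ordinary matmul.
        w = dequantize_fp16(qweight, scale, zero)
        out = [[to_fp16(sum(xi[k] * w[k][j] for k in range(len(xi))))
                for j in range(out_features)] for xi in x]
        return out, "dequant+matmul"
    # Small batch: dequantize each weight on the fly in full precision, round once.
    out = [[to_fp16(sum(xi[k] * scale * (qweight[k][j] - zero) for k in range(len(xi))))
             for j in range(out_features)] for xi in x]
    return out, "fused"

q = [[3, 0], [1, 2]]
print(quant_linear([[1.0, 1.0]], q, 0.1, 2))
print(quant_linear([[1.0, 1.0]] * 9, q, 0.1, 2)[0][0])
```

Since 0.1 is not exactly representable in float16, the two paths round at different points and can differ in the last bits; small per-layer differences like this can accumulate across a whole model.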

johnsmith0031 commented 1 year ago

Replying to @kuleshov (how to unpack and return the 3-bit weight matrix in CUDA):

Yes, 3-bit does seem to be more complicated than 4-bit and 2-bit...
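The extra complexity comes from 3 not dividing 32: 32 three-bit codes fill exactly three 32-bit words (96 bits), and two of them (indices 10 and 21) straddle a word boundary. Treating the three words as one little-endian 96-bit stream, the sketch below packs codes and then unpacks them the way a kernel would, reading most codes from a single word and stitching the straddling ones from two. I believe this bit order matches GPTQ-for-LLaMa's 3-bit `pack`, but that is an assumption to verify against the actual kernel; the helper names are hypothetical.

```python
from typing import List, Sequence

def pack_3bit(codes: Sequence[int]) -> List[int]:
    """Pack 32 unsigned 3-bit codes into three 32-bit words (little-endian bit stream)."""
    assert len(codes) == 32 and all(0 <= c < 8 for c in codes)
    stream = 0
    for k, c in enumerate(codes):
        stream |= c << (3 * k)  # code k occupies bits 3k..3k+2 of the 96-bit stream
    return [(stream >> (32 * w)) & 0xFFFFFFFF for w in range(3)]

def unpack_3bit(words: Sequence[int]) -> List[int]:
    """Kernel-style unpack: read each code from one word, or stitch it across two."""
    assert len(words) == 3
    out = []
    for k in range(32):
        w, off = divmod(3 * k, 32)
        if off <= 29:
            # All three bits live in the same word.
            out.append((words[w] >> off) & 0x7)
        else:
            # Code straddles words w and w + 1: low bits from w, the rest from w + 1.
            low_bits = 32 - off
            low = words[w] >> off  # words are 32-bit, so only low_bits bits remain
            high = words[w + 1] & ((1 << (3 - low_bits)) - 1)
            out.append(low | (high << low_bits))
    return out

codes = [k % 8 for k in range(32)]
assert unpack_3bit(pack_3bit(codes)) == codes
print([hex(w) for w in pack_3bit([7] * 32)])  # ['0xffffffff', '0xffffffff', '0xffffffff']
```

In CUDA the same branching can be unrolled, since the straddling positions (10 and 21) are fixed for every 3-word block.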