turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Why isn't there a more popular training approach based on exl2? #309

Open laoda513 opened 5 months ago

laoda513 commented 5 months ago

I'm a bit confused about the basis for training. Currently, the fastest inference solution is exllamav2. Besides using the original GPTQ model, it's possible to train directly with AutoGPTQ and alpaca_lora_4bit, but there seems to be no way to train directly on EXL2. From what I see, mainstream training either uses the original model and then converts it to GPTQ/EXL2, or loads the original model with BnB 4-bit. The trained LoRA weights are then loaded together with the converted GPTQ/EXL2 model for inference in exllamav2.

However, loading with BnB 4-bit or GPTQ 4-bit should lose more of the original weights' precision than EXL2 at 5/6 bits. If LoRA weights trained against those weights are loaded together with the more accurate EXL2 model, there should be some discrepancy. So I'm quite curious why there isn't a more popular training approach based on EXL2.

turboderp commented 5 months ago

It would be possible to create a drop-in replacement linear layer module to let you use EXL2 tensors in PyTorch and HF Transformers. If you added support for back propagation as well, you'd be able to train with the quantized weights frozen (LoRAs, soft prompts etc.). But the benefit over BnB 8-bit wouldn't be huge. If you're on a setup where you can fit 53 GB of weights (along with however many GB for activations and gradients), but you couldn't fit 70 GB, then I guess there's that. But any difference in inference speed wouldn't translate to training/finetuning.
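For what it's worth, here is a minimal sketch of that "drop-in linear layer" idea in plain PyTorch. This is not exllamav2's actual API: `dequant_matmul` and `q_handle` are hypothetical stand-ins for an EXL2 kernel and its packed weight. The point is just that the frozen base path only has to provide a gradient for its input, while the LoRA tensors are ordinary trainable parameters:

```python
import torch

def dequant_matmul(x, q_handle):
    # Placeholder for a quantized forward kernel. Here q_handle is just a dense
    # (out, in) tensor so the sketch runs; a real EXL2 handle would stay packed.
    return x @ q_handle.t()

class QuantLinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, q_handle):
        ctx.save_for_backward(q_handle)
        return dequant_matmul(x, q_handle)

    @staticmethod
    def backward(ctx, grad_out):
        (q_handle,) = ctx.saved_tensors
        # The weight is frozen, so only the input needs a gradient:
        # dL/dx = dL/dy @ W (in practice a second, transposed kernel call).
        grad_x = grad_out @ q_handle
        return grad_x, None

class QuantLinearWithLoRA(torch.nn.Module):
    def __init__(self, q_handle, in_features, out_features, r=16, alpha=16):
        super().__init__()
        self.q_handle = q_handle  # frozen quantized base weight, no grad
        self.lora_A = torch.nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = torch.nn.Parameter(torch.zeros(out_features, r))
        self.scale = alpha / r

    def forward(self, x):
        base = QuantLinearFn.apply(x, self.q_handle)
        # LoRA path uses regular PyTorch ops, so autograd handles its backward.
        lora = (x @ self.lora_A.t()) @ self.lora_B.t() * self.scale
        return base + lora
```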

laoda513 commented 5 months ago

> It would be possible to create a drop-in replacement linear layer module to let you use EXL2 tensors in PyTorch and HF Transformers. If you added support for back propagation as well, you'd be able to train with the quantized weights frozen (LoRAs, soft prompts etc.). But the benefit over BnB 8-bit wouldn't be huge. If you're on a setup where you can fit 53 GB of weights (along with however many GB for activations and gradients), but you couldn't fit 70 GB, then I guess there's that. But any difference in inference speed wouldn't translate to training/finetuning.

"Sorry, I'm not quite clear on this. Are you suggesting that in terms of training effectiveness, there wouldn't be a significant difference between using EXL2 6bit and BnB 8-bit, even though the latter consumes more memory (from 53GB to 70GB)

turboderp commented 5 months ago

The 53 GB vs 70 GB is just the difference in the weights. That 2-bit difference sits next to, say, 48 bits per parameter for gradients, and then you probably have large batches on top, maybe huge attention matrices if you're doing long-context finetuning... that sort of thing.
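To put rough numbers on those figures (illustrative arithmetic only, assuming 70B parameters and decimal GB; real totals also depend on batch size, context length, optimizer choice, etc.):

```python
# Back-of-the-envelope memory arithmetic behind the 53 GB vs 70 GB comparison.
params = 70e9                                  # Llama-2 70B

weights_6bpw = params * 6 / 8 / 1e9            # ~52.5 GB of frozen weights
weights_8bpw = params * 8 / 8 / 1e9            # 70.0 GB of frozen weights
grads_48bit = params * 48 / 8 / 1e9            # ~420 GB if every weight were trained

print(f"frozen weights @ 6 bpw: {weights_6bpw:.1f} GB")
print(f"frozen weights @ 8 bpw: {weights_8bpw:.1f} GB")
print(f"~48 bits/param of gradient state (full finetune): {grads_48bit:.0f} GB")
```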

laoda513 commented 5 months ago

Indeed. That's why I'm puzzled by the lack of attention to training based on the EXL2 format. During training, the required memory can increase significantly. For instance, with a 4-bit quantized training method like GPTQ, plus gradient checkpointing and LoRA, it's possible to train the base LLaMA-2 70B model on 8x 3090 GPUs (I haven't tried long contexts yet). Training on top of BnB 8-bit, however, runs out of memory. That's why I've always trained with GPTQ models.

When I saw that EXL2 supports LoRA but there isn't a training method specific to the EXL2 format, I was curious: what do people actually use to train LoRAs? Transformers + BnB, GPTQ, training directly on the original weights, or some other, better method?

I'm not an algorithm engineer, but from my understanding of LoRA, its basic formulation is W + ΔW. Even if ΔW is trained very well, different quantizations will inevitably cause some variation in W, and thus in W + ΔW. Could it be that these differences are small enough that they don't noticeably degrade the combined model, and that's why there hasn't been a specific focus on developing training methods based on the EXL2 format?
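A toy numerical illustration of that concern, with crude uniform rounding standing in for the real quantizers (BnB / GPTQ / EXL2): the output discrepancy is exactly the difference between the two quantized versions of W applied to the input, no matter how good ΔW itself is.

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 64)
x = torch.randn(8, 64)
delta_W = 0.01 * torch.randn(64, 64)          # stands in for the trained B @ A

def fake_quant(w, step):
    # Crude uniform rounding, standing in for NF4 / GPTQ / EXL2 quantization.
    return torch.round(w / step) * step

W_train = fake_quant(W, 0.05)                  # what the LoRA was trained against
W_infer = fake_quant(W, 0.02)                  # what is served at inference time

y_train = x @ (W_train + delta_W).t()
y_infer = x @ (W_infer + delta_W).t()

print("mean abs output discrepancy:", (y_infer - y_train).abs().mean().item())
# The gap is exactly (W_infer - W_train) applied to x, up to float32 rounding:
print(torch.allclose(y_infer - y_train, x @ (W_infer - W_train).t(), atol=1e-4))
```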

Additionally, if I want to run the trained model with EXL2 for inference, which training approach would you recommend, given that I don't have enough memory to train with the original weights?

grimulkan commented 4 months ago

I do think the 4-8 bit range can have a big impact on training VRAM and therefore on what is even feasible, so it would be worth it IMO. QLoRA doesn't really offer variable precision. It's not the overall size so much as hitting the right breakpoints, and a few GB per card can help there.

@turboderp I'm also curious how hard it would be to implement native auto-differentiation and backpropagation in ExLlama, at least for LoRA layers (with the EXL2-quantized model weights frozen) and for decoder-only models like Llama. The math doesn't sound more complicated than what you're already doing, I think (heck, the quantization code with Cholesky factorizations and inverses sounds far more numerically nasty than autodiff). If the focus is on Llama-style architectures, you wouldn't need to support arbitrary compute graphs like PyTorch's autograd does, and it could be specialized to LoRA layers. I guess you'd also need a simple optimizer (or plain SGD) to get a training loop going.
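For what it's worth, the backward pass for a single LoRA layer over a frozen base weight does come down to a few matmuls. A rough sketch of the math with plain tensors and no autograd (the shapes and the y = x W^T + s * (x A^T) B^T convention are assumptions, and W here is just a dense stand-in for a dequantized EXL2 tensor):

```python
import torch

def lora_forward(x, W, A, B, s):
    h = x @ A.t()                      # (n, r) low-rank activation
    return x @ W.t() + s * (h @ B.t()), h

def lora_backward(grad_y, x, h, W, A, B, s):
    # dL/dB = s * g^T h,  dL/dA = s * (g B)^T x,  dL/dx = g W + s * (g B) A
    gB = grad_y @ B                    # (n, r)
    grad_B = s * grad_y.t() @ h        # (out, r)
    grad_A = s * gB.t() @ x            # (r, in)
    grad_x = grad_y @ W + s * gB @ A   # (n, in), needed by earlier layers
    return grad_A, grad_B, grad_x

# One plain-SGD step on random data.
n, d_in, d_out, r, s, lr = 4, 32, 32, 8, 2.0, 1e-3
x = torch.randn(n, d_in)
W = torch.randn(d_out, d_in)           # frozen base weight (stand-in)
A = 0.01 * torch.randn(r, d_in)        # standard LoRA init: A small, B zero
B = torch.zeros(d_out, r)

y, h = lora_forward(x, W, A, B, s)
grad_y = y - torch.randn(n, d_out)     # gradient of a squared-error-style loss
grad_A, grad_B, _ = lora_backward(grad_y, x, h, W, A, B, s)
A -= lr * grad_A
B -= lr * grad_B
```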

It'd be nice to have a training loop independent of all the crap and bloat that Transformers + PEFT + QLoRA/bnb etc. introduce, not to mention wrapper packages like DeepSpeed. For instance, there is no pipeline-parallel training in combination with quantization anywhere, but it could be implemented more easily if the building blocks existed in a lower-level implementation like ExLlama, without wrestling with a nightmare of dependencies. I think a lot of the bloat exists because those packages are highly generalized, and Llama (and Yi, Mistral, etc.) training options suffer for it.

EDIT: What you mentioned earlier:

> It would be possible to create a drop-in replacement linear layer module to let you use EXL2 tensors in PyTorch and HF Transformers. If you added support for back propagation as well, you'd be able to train with the quantized weights frozen (LoRAs, soft prompts etc.).

Maybe that would also help here if it existed, and one could write a lower-level PyTorch training loop as long as the frozen EXL2 portion propagated gradients. Or... maybe that's just a totally different project. It's just nice that ExLlama already does a number of things independently of Transformers.