turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

How to not quantize lm_head? #269

Closed fahadh4ilyas closed 6 months ago

fahadh4ilyas commented 6 months ago

Is there a way to add a flag to convert.py so that lm_head is not quantized? The reason is that I might want to retrain the embedding and lm_head when I train with LoRA, and then I could simply swap in the new lm_head and embedding weights when applying the LoRA.

turboderp commented 6 months ago

There's no option that currently supports that, but the implementation does support mixing quantized and unquantized tensors, and the decision on which to use is based on which tensors are present when loading the model. So if you replace the lm_head.* tensors in the final .safetensors file(s) with the single lm_head.weight tensor from the original model, everything should still work.
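As a rough sketch of that file-level swap (using the safetensors library; the shard filenames below are placeholders and depend on how the quantized and original models were saved):

# Hedged sketch: replace the quantized lm_head.* tensors in a converted shard
# with the single FP16 lm_head.weight from the original model. Both paths are
# assumptions; use the shards that actually contain the head tensors.
from safetensors.torch import load_file, save_file

quantized_shard = "quant/output.safetensors"
original_shard = "original/model-00003-of-00003.safetensors"

tensors = load_file(quantized_shard)

# Drop every quantized lm_head.* tensor from the converted model...
tensors = {k: v for k, v in tensors.items() if not k.startswith("lm_head.")}

# ...and put back the unquantized lm_head.weight from the original model.
tensors["lm_head.weight"] = load_file(original_shard)["lm_head.weight"]

save_file(tensors, quantized_shard)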

I've recently become aware of LoRAs that include full replacement embedding and head layers along with the low-rank adapter layers, and I'm considering ways to add support for that. The thing is, it conflicts with the existing option to stack multiple adapters and/or switch between them at inference time.

It looks like at one point I had a -h 16 option in the works (which would just store the FP16 head layer instead of quantizing it), and I suppose I could finish that, but it wouldn't really help with the LoRA issue.

Now, after loading a model with ExLlamaV2.load(), you should also be able to replace the head layer (whether the copy currently in memory is quantized or not) with:

import torch.nn as nn

head_layer = model.modules_dict["lm_head"]              # head module of the loaded model
head_layer.unload()                                      # drop the current (quantized or FP16) weights
head_layer.load(w = nn.Parameter(new_lm_head_tensor))    # reload with the replacement tensor

I haven't tested it for this use case, and it looks like I overlooked adding the same override for the embedding module, but I guess I'll add that when I have a chance later.
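For completeness, a hedged sketch of producing new_lm_head_tensor for the snippet above; the shard filename and the FP16 cast are assumptions, not something the library mandates:

import torch
from safetensors.torch import load_file

# Shard of the original (unquantized) model that contains lm_head.weight;
# the filename is a placeholder and depends on how that model is sharded.
original = load_file("original/model-00003-of-00003.safetensors")
new_lm_head_tensor = original["lm_head.weight"].to(torch.float16)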

fahadh4ilyas commented 6 months ago

So, I just need to replace the lm_head tensors inside the safetensors file with the non-quantized weight? Okay, I will try it.