NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

fp32 LoRA support for Llama #1910

Open pankajroark opened 2 months ago

pankajroark commented 2 months ago

Currently, TensorRT-LLM requires that LoRA weights dtype match the base model dtype. The check is here: https://github.com/NVIDIA/TensorRT-LLM/blob/9dbc5b38baba399c5517685ecc5b66f57a177a4c/cpp/tensorrt_llm/runtime/loraUtils.cpp#L66
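
Satisfying that check today means casting the adapter to fp16 up front. A minimal sketch of what that looks like, assuming the adapter is stored in the usual PEFT safetensors layout (the paths here are placeholders):

import torch
from safetensors.torch import load_file, save_file

# Placeholder paths; the file layout follows the usual PEFT adapter convention.
lora = load_file("./fp32_lora/adapter_model.safetensors")
# Cast every LoRA tensor to the base model's fp16 so the runtime dtype check passes.
lora_fp16 = {name: w.to(torch.float16) for name, w in lora.items()}
save_file(lora_fp16, "./fp16_lora/adapter_model.safetensors")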

One way around this is to quantize the LoRA weights before passing them to TensorRT-LLM, but that results in unacceptably lower quality. The LoRA matrices get multiplied together, so quantizing them from fp32 to fp16 beforehand compounds the quality loss, whereas quantizing after the multiplication is much better. We experimented with merging the LoRA weights into the base model and saw no quality degradation there, because the LoRA merge happens in fp32:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model (fp32 by default) and attach the fp32 LoRA adapter.
bm = AutoModelForCausalLM.from_pretrained("./llama3")
pm = PeftModel.from_pretrained(bm, "./fp32_lora")
# Merge the LoRA weights into the base weights; the merge happens in fp32.
pm.merge_adapter()

It would be best if TensorRT-LLM could accept fp32 LoRA, multiply the low-rank LoRA matrices in fp32, and quantize the resulting matrix to the base model's fp16. This way the quantization loss would be much lower.
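
To make the quantization-order point concrete, here is a toy sketch (not TensorRT-LLM code; the sizes, scales, and simulated fp16 rounding are made up for illustration) comparing casting the factors before the matmul versus casting the product afterwards:

import torch

torch.manual_seed(0)
d, r = 4096, 16
A = torch.randn(d, r) * 0.01   # fp32 low-rank factor
B = torch.randn(r, d) * 0.01   # fp32 low-rank factor

ref = A @ B                    # fp32 reference for the LoRA delta

# Order 1: round the factors to fp16 first (the matmul is kept in fp32 here
# just to isolate the input-rounding error).
before = A.half().float() @ B.half().float()
# Order 2: multiply in fp32, then round the product to fp16.
after = (A @ B).half().float()

print("rel. error, quantize before matmul:", ((before - ref).norm() / ref.norm()).item())
print("rel. error, quantize after matmul: ", ((after - ref).norm() / ref.norm()).item())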

yuxianq commented 2 months ago

@pankajroark The engine itself supports fp32 LoRA, so this runtime limitation is unnecessary. I can help add support for fp32 LoRA. Can you provide your model + LoRA checkpoint + commands (use a similar open-source alternative if your model is private) so I can validate?

pankajroark commented 2 months ago

Thanks, great to know that the engine supports fp32 LoRA. The model is indeed private; let me provide details shortly using an open-source alternative.
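
In the meantime, a rough sketch of how an open-source stand-in could be produced with PEFT, for anyone who wants to reproduce the setup (the base checkpoint, rank, and target modules below are placeholders, not the actual private adapter):

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Any open Llama-style checkpoint works as a stand-in for the private base model.
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float32
)
# Rank and target modules are placeholders; the point is that the adapter stays fp32.
cfg = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
lora_model = get_peft_model(base, cfg)
# Saves an (untrained) adapter with fp32 weights under ./fp32_lora.
lora_model.save_pretrained("./fp32_lora")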

pankajroark commented 1 month ago

Would appreciate any updates on this issue. thx

yuxianq commented 1 month ago

@pankajroark I cannot access the fp32 LoRA link you provided; it may be a private repo. After some investigation, I found that the LoRA plugin currently only supports fp32 base model + fp32 LoRA, so simply removing the runtime limitation is not enough to run an fp16 base model + fp32 LoRA. We have to update the LoRA plugin to support it, which makes this a feature request rather than a quick bugfix. We will try to allocate engineering bandwidth for it, but cannot promise to finish it in v0.12.
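
Conceptually, supporting fp16 base + fp32 LoRA would mean applying the low-rank update with fp32 factors and fp32 accumulation, and only casting the LoRA contribution back to fp16 at the end. A rough Python sketch of how the numbers could flow (an illustration of the idea, not the plugin code):

import torch

def lora_forward_mixed(x_fp16, W_fp16, A_fp32, B_fp32, scaling=1.0):
    # Base projection stays in the engine's fp16, as today.
    base_out = x_fp16 @ W_fp16
    # LoRA path: upcast the activations, apply the fp32 low-rank factors,
    # and cast the contribution back to fp16 only at the very end.
    lora_out = (x_fp16.float() @ A_fp32 @ B_fp32) * scaling
    return base_out + lora_out.to(torch.float16)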

pankajroark commented 1 month ago

Thanks for the update.

github-actions[bot] commented 3 weeks ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 15 days.