NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Do you have plans to support Qwen LoRA? #1262

Closed: h6353115 closed this 3 months ago

h6353115 commented 6 months ago

Do you have plans to support Qwen LoRA?

byshiue commented 6 months ago

We don't have such a plan so far. If you are interested, you could create a feature request to help track it, and we will consider its priority in our roadmap.

h6353115 commented 6 months ago

Thanks

h6353115 commented 5 months ago

> We don't have such a plan so far. If you are interested, you could create a feature request to help track it, and we will consider its priority in our roadmap.

I have some additional questions here; I'd really appreciate it if you could help me resolve them.

If one were to implement Qwen LoRA support themselves, what tasks would be involved, and how challenging would it be? Which is generally easier to support, Qwen1.0 LoRA or Qwen1.5 LoRA?

yuxianq commented 5 months ago

It is easy to add LoRA support to Qwen. You can apply the following patch to tensorrt_llm/models/qwen/model.py:

diff --git a/tensorrt_llm/models/qwen/model.py b/tensorrt_llm/models/qwen/model.py
index 4fa236443..ff6b137f4 100644
--- a/tensorrt_llm/models/qwen/model.py
+++ b/tensorrt_llm/models/qwen/model.py
@@ -15,6 +15,8 @@

 from typing import Optional

+from tensorrt_llm.lora_manager import LoraBuildConfig, use_lora
+
 from ..._utils import pad_vocab_size
 from ...functional import Tensor, recv, send
 from ...layers import (Attention, AttentionMaskType, ColumnLinear, Embedding,
@@ -76,6 +78,7 @@ class QWenDecoderLayer(Module):
         use_cache=False,
         kv_cache_params=None,
         attention_params=None,
+        lora_layer_params=None,
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -85,6 +88,7 @@ class QWenDecoderLayer(Module):
             use_cache=use_cache,
             kv_cache_params=kv_cache_params,
             attention_params=attention_params,
+            lora_layer_params=lora_layer_params,
         )
         if use_cache:
             attention_output, presents = attention_output
@@ -95,7 +99,7 @@ class QWenDecoderLayer(Module):

         hidden_states = self.post_layernorm(hidden_states)

-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, lora_layer_params=lora_layer_params)

         hidden_states = residual + hidden_states
         if use_cache:
@@ -130,7 +134,8 @@ class QWenModel(Module):
                 hidden_states=None,
                 prompt_embedding_table: Optional[Tensor] = None,
                 prompt_tasks: Optional[Tensor] = None,
-                prompt_vocab_size: Optional[Tensor] = None):
+                prompt_vocab_size: Optional[Tensor] = None,
+                lora_params=None):

         ptuning_args = [
             prompt_embedding_table, prompt_tasks, prompt_vocab_size
@@ -145,7 +150,8 @@ class QWenModel(Module):
                                             use_cache=use_cache,
                                             attention_mask=attention_mask,
                                             kv_cache_params=kv_cache_params,
-                                            attention_params=attention_params)
+                                            attention_params=attention_params,
+                                            lora_params=lora_params)

         if use_cache:
             hidden_states, presents = hidden_states
@@ -185,3 +191,6 @@ class QWenForCausalLM(DecoderModelForCausalLM):
     def check_config(self, config):
         config.set_if_not_exist('rotary_base', 10000.0)
         config.set_if_not_exist('rotary_scaling', None)
+
+    def use_lora(self, lora_config: LoraBuildConfig):
+        use_lora(self, lora_config)
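
For context, the lora_layer_params threaded through this patch feed TensorRT-LLM's LoRA plugin, which applies the standard low-rank update h = Wx + (alpha / r) * B(Ax) alongside each frozen linear layer. Below is a minimal PyTorch sketch of that computation only, not the plugin's actual implementation:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen base linear layer plus a low-rank LoRA update.

    Sketch of the standard LoRA formulation; the lora_plugin computes
    the same update inside the TensorRT engine at runtime.
    """

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)
        # delta W = B @ A, with A: (r, in_features) and B: (out_features, r).
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = W x + (alpha / r) * B (A x)
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

The patch itself only has to route lora_layer_params into the attention and MLP blocks; the plugin applies the update at runtime for the adapter selected by each request's lora_task_uid.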

Assuming you have a LoRA checkpoint in the same format as https://huggingface.co/hfl/chinese-llama-2-lora-13b at tmp/qwen-lora-14b, you can use LoRA with Qwen as follows (see the LoRA documentation in examples/llama/README.md and the Qwen documentation in examples/qwen/README.md):

cd examples/qwen
pip install -r requirements.txt
git clone https://huggingface.co/Qwen/Qwen-14B-Chat tmp/Qwen/14B

python convert_checkpoint.py --model_dir tmp/Qwen/14B \
    --output_dir ./tmp/Qwen/14B/trt_ckpt/lora/fp16/1-gpu \
    --dtype float16

trtllm-build --checkpoint_dir ./tmp/Qwen/14B/trt_ckpt/lora/fp16/1-gpu \
    --output_dir ./tmp/Qwen/14B/trt_engines/lora/fp16/1-gpu \
    --lora_plugin float16 \
    --lora_dir tmp/qwen-lora-14b

python ../run.py --engine_dir ./tmp/Qwen/14B/trt_engines/lora/fp16/1-gpu \
    --max_output_len 50 \
    --tokenizer_dir tmp/Qwen/14B \
    --input_text "test" \
    --lora_task_uids 0 \
    --use_py_session
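
Before building, it is worth confirming that the adapter checkpoint has the expected Hugging Face PEFT layout (an adapter_config.json next to the adapter weights), since the LoRA module mapping depends on the target module names. A minimal sketch, with the checkpoint path taken from the example above:

import json
from pathlib import Path

# Checkpoint path from the example above.
adapter_dir = Path("tmp/qwen-lora-14b")

# Hugging Face PEFT LoRA checkpoints keep their metadata in adapter_config.json.
config = json.loads((adapter_dir / "adapter_config.json").read_text())

# For this patch to work, target_modules should name separate projections
# (e.g. q_proj/k_proj/v_proj) rather than a fused module such as c_attn.
print("rank r:        ", config.get("r"))
print("lora_alpha:    ", config.get("lora_alpha"))
print("target_modules:", config.get("target_modules"))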

This patch is just for illustration; you could create a feature request for official support.

h6353115 commented 5 months ago

Thank you for your response; it has been very helpful. I previously tried this method with Qwen1.0 LoRA, and it failed because the fused 'c_attn' module is not supported. Next, I plan to verify this approach with Qwen1.5 LoRA, which I believe should work without issues.
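
For reference, the 'c_attn' error arises because Qwen1.0 stores query, key, and value as one fused projection, so there is no separate q/k/v target for the LoRA module mapping. In principle such an adapter can be split offline: the A matrix is shared, while the B matrix partitions along its output dimension into Q, K, and V blocks (which have equal widths in Qwen1.0's multi-head attention). A rough sketch, assuming a PyTorch adapter_model.bin and hypothetical key names:

import torch

# Hypothetical key names; real checkpoints prefix these with the full layer path.
state = torch.load("tmp/qwen1.0-lora/adapter_model.bin", map_location="cpu")

lora_a = state["c_attn.lora_A.weight"]  # shape (r, hidden)
lora_b = state["c_attn.lora_B.weight"]  # shape (3 * hidden, r)

hidden = lora_b.shape[0] // 3
b_q, b_k, b_v = torch.split(lora_b, hidden, dim=0)

# Since delta W_qkv = B @ A, splitting B along dim 0 yields per-projection
# updates that share the same A: delta W_q = B_q @ A, and so on.
split_state = {
    "q_proj.lora_A.weight": lora_a, "q_proj.lora_B.weight": b_q,
    "k_proj.lora_A.weight": lora_a, "k_proj.lora_B.weight": b_k,
    "v_proj.lora_A.weight": lora_a, "v_proj.lora_B.weight": b_v,
}
torch.save(split_state, "tmp/qwen1.0-lora/adapter_model_split.bin")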

h6353115 commented 5 months ago

Feature request created: Add lora support for Qwen or Qwen2 #1432

h6353115 commented 5 months ago

> Thank you for your response; it has been very helpful. I previously tried this method with Qwen1.0 LoRA, and it failed because the fused 'c_attn' module is not supported. Next, I plan to verify this approach with Qwen1.5 LoRA, which I believe should work without issues.

Update: Qwen1.5 LoRA does not work either.

yuxianq commented 5 months ago

@h6353115 Can you provide a Qwen1.0/1.5 LoRA checkpoint to compare with the LLaMA LoRA checkpoint?
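
One quick way to make that comparison is to dump the tensor names and shapes from each adapter file. A small sketch, assuming PyTorch-format adapter_model.bin files (paths hypothetical):

import torch

def dump_adapter(path: str) -> None:
    """Print every LoRA tensor name and shape in a PyTorch adapter file."""
    state = torch.load(path, map_location="cpu")
    for name, tensor in sorted(state.items()):
        print(f"{name}: {tuple(tensor.shape)}")

# Paths are hypothetical; point these at the unzipped adapter checkpoints.
dump_adapter("qwen1.5_14b_lora/adapter_model.bin")
dump_adapter("chinese-llama-2-lora-13b/adapter_model.bin")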

h6353115 commented 5 months ago

> @h6353115 Can you provide a Qwen1.0/1.5 LoRA checkpoint to compare with the LLaMA LoRA checkpoint?

qwen1.5_14b_lora.zip
qwen1.0_14b_lora.zip

github-actions[bot] commented 3 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 15 days.

yuxianq commented 3 months ago

We now support LoRA for Qwen (see https://github.com/NVIDIA/TensorRT-LLM/discussions/1735); LoRA support for Qwen2 is planned.

byshiue commented 3 months ago

Since LoRA for Qwen is now supported, closing this issue.