NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Medusa with Mixtral 8x7B #1798

Open v-dicicco opened 1 week ago

v-dicicco commented 1 week ago

Hello! Does TensorRT-LLM support Medusa with Mixtral 8x7B?

My understanding is that right now Medusa's convert_checkpoint.py doesn't support Mixtral (e.g. it lacks the MoE config and the other MoE-related arguments present in the LLaMA conversion script), but I have the feeling it should work in theory, since MedusaForCausalLm is based on LLaMAForCausalLM, and Medusa's convert_checkpoint.py could be aligned with the LLaMA one (at least for some specific configurations).
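
For context, the kind of MoE arguments I mean look roughly like this in the LLaMA conversion script (a hedged sketch; the exact flag names and defaults here are assumptions on my part and may differ from the actual examples/llama/convert_checkpoint.py):

```python
# Hedged sketch: MoE-related CLI options roughly as they appear in the LLaMA
# conversion script; exact flag names/defaults are assumptions, not verified.
import argparse

def add_moe_args(parser: argparse.ArgumentParser) -> None:
    # Number of experts per MoE layer (8 for Mixtral 8x7B); 0 disables MoE.
    parser.add_argument('--moe_num_experts', type=int, default=0)
    # Number of experts each token is routed to (2 for Mixtral 8x7B).
    parser.add_argument('--moe_top_k', type=int, default=0)
```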

Any hints in this direction would be helpful :)

Thanks!

nv-guomingz commented 1 week ago

If Mixtral 8x7B has its own Medusa model, like medusa-vicuna-7b-v1.3 for vicuna-7b-v1.3, then we can try enabling Medusa for the MoE model.

v-dicicco commented 1 week ago

Thanks for the answer! There is a Medusa model already trained for Mixtral-Instruct v0.1 for TGI, available here: https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa

The format is the same as the original Vicuna heads: I was able to use the Mistral heads from the same project with TRT-LLM (https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa).

Do you already see challenges beyond fixing Medusa's convert_checkpoint.py script? I'm working on it right now.

nv-guomingz commented 1 week ago

> Thanks for the answer! There is a Medusa model already trained for Mixtral-Instruct v0.1 for TGI, available here: https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa
>
> The format is the same as the original Vicuna heads: I was able to use the Mistral heads from the same project with TRT-LLM (https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa).
>
> Do you already see challenges beyond fixing Medusa's convert_checkpoint.py script? I'm working on it right now.

No, we haven't investigated MoE + Medusa yet. But my gut tells me it could be done once we add the missing MoE support to convert_checkpoint.py.

skyCreateXian commented 1 week ago

I have finished building a Baichuan2-7B Medusa engine. Based on that experience, here are some suggestions:

  1. You can refer to the Mixtral example (https://github.com/NVIDIA/TensorRT-LLM/tree/2a115dae84f13daaa54727534daa837c534eceb4/examples/mixtral), which reuses the LLaMA checkpoint conversion.
  2. Following the Medusa checkpoint example:
     2.1 Add functions such as the Medusa head loading.
     2.2 Modify the Medusa config in the code. Required option: 'architecture': 'MedusaForCausalLM'; modify the other configurations as needed (a sketch of such a config patch follows).
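
For example, point 2.2 could look roughly like this (a minimal sketch; the field names other than 'architecture' and their values are assumptions based on the Medusa example and may need adjusting):

```python
# Minimal sketch of point 2.2: patch the converted checkpoint's config.json so
# the Medusa model class is used at build time. Field names other than
# 'architecture' are assumptions based on the Medusa example.
import json

def patch_medusa_config(config_path, num_medusa_heads=3, num_medusa_layers=1):
    with open(config_path) as f:
        config = json.load(f)
    config['architecture'] = 'MedusaForCausalLM'    # required option
    config['num_medusa_heads'] = num_medusa_heads   # must match the trained heads
    config['num_medusa_layers'] = num_medusa_layers # ResBlock depth per head
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
```
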
v-dicicco commented 5 days ago

Thanks for the reply @nv-guomingz @skyCreateXian.

I was able to modify LLaMA's convert_checkpoint.py to also add the Medusa weights, but I'm getting very poor inference performance, in stark contrast to my experience with Mistral 7B.

I proceeded in this way:

  1. Created a custom MedusaForCausalLm.from_hugging_face(), very similar to the LLaMAForCausalLM one, that also updates the base model's (Mixtral 8x7B) weight dict with the Medusa weights provided by load_medusa_hf(), then instantiates a MedusaForCausalLm object and loads the weights into it (also taking care to create the right MedusaConfig with the MedusaForCausalLM architecture field). A sketch of this flow follows below.
  2. During engine creation with trtllm-build, I needed to force silu as the hidden_act of the Medusa heads, rather than taking the value from the model config. This was necessary because, during the Mixtral conversion in the previous step, TRT-LLM set hidden_act to swiglu instead of keeping the original silu, which is incompatible with the Medusa heads.
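
A rough sketch of the flow in step 1 (load_medusa_hf() and MedusaForCausalLm come from the existing examples; the helper name and exact signatures below are assumptions about my local changes, not upstream API):

```python
# Hedged sketch of the custom conversion flow in step 1; convert_base_mixtral()
# is a hypothetical stand-in for the LLaMA/Mixtral weight conversion, and the
# signatures are assumptions about my local changes rather than upstream API.
def medusa_mixtral_from_hugging_face(hf_model_dir, medusa_heads_dir, medusa_config):
    # Convert the Mixtral 8x7B base weights the same way LLaMAForCausalLM does,
    # including the MoE router/expert tensors.
    weights = convert_base_mixtral(hf_model_dir, medusa_config)

    # Load the Medusa head weights from the HF checkpoint and merge them into
    # the same weight dict under the names MedusaForCausalLm expects.
    weights.update(load_medusa_hf(medusa_heads_dir, medusa_config))

    # medusa_config must carry architecture == 'MedusaForCausalLM'; instantiate
    # the Medusa wrapper and load the merged weights into it (load() stands for
    # whatever weight-loading step the local script actually performs).
    model = MedusaForCausalLm(medusa_config)
    model.load(weights)
    return model
```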

To make sure the model is behaving as expected, I followed the Debug on E2E Models doc, marking the Medusa logits as debug outputs and printing their predictions at each step, e.g. running:

```
mpirun -np 2 --allow-run-as-root --oversubscribe \
    python run.py --engine_dir mixtral_instruct_v1_trt11_medusa \
                     --tokenizer_dir mixtral_tokenizer \
                     --max_output_len=14 \
                     --temperature 1.0 \
                     --input_text "[INST] Hello! [/INST]" \
                     --medusa_choices="[[0], [0, 0], [0, 0, 0]]" \
                     --use_py_session \
                     --debug_mode
```
Output

``` [TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100 [TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100 [06/23/2024-17:11:58] [TRT-LLM] [W] Implicitly setting MedusaConfig.skip_loading_weights = True [06/23/2024-17:11:58] [TRT-LLM] [W] Implicitly setting MedusaConfig.mup_width_multiplier = 1.0 [06/23/2024-17:11:58] [TRT-LLM] [I] Set dtype to float16. [06/23/2024-17:11:58] [TRT-LLM] [I] Set bert_attention_plugin to auto. [06/23/2024-17:11:58] [TRT-LLM] [I] Set gpt_attention_plugin to auto. [06/23/2024-17:11:58] [TRT-LLM] [I] Set gemm_plugin to float16. [06/23/2024-17:11:58] [TRT-LLM] [I] Set gemm_swiglu_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set smooth_quant_gemm_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set identity_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set layernorm_quantization_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set rmsnorm_quantization_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set nccl_plugin to float16. [06/23/2024-17:11:58] [TRT-LLM] [I] Set lookup_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set lora_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set weight_only_groupwise_quant_matmul_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set weight_only_quant_matmul_plugin to float16. [06/23/2024-17:11:58] [TRT-LLM] [I] Set quantize_per_token_plugin to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set quantize_tensor_plugin to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set moe_plugin to auto. [06/23/2024-17:11:58] [TRT-LLM] [I] Set mamba_conv1d_plugin to auto. [06/23/2024-17:11:58] [TRT-LLM] [I] Set context_fmha to True. [71/1686] [06/23/2024-17:11:58] [TRT-LLM] [I] Set context_fmha_fp32_acc to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set paged_kv_cache to True. [06/23/2024-17:11:58] [TRT-LLM] [I] Set remove_input_padding to True. [06/23/2024-17:11:58] [TRT-LLM] [I] Set use_custom_all_reduce to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set reduce_fusion to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set multi_block_mode to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set enable_xqa to True. [06/23/2024-17:11:58] [TRT-LLM] [I] Set attention_qk_half_accumulation to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set tokens_per_block to 64. [06/23/2024-17:11:58] [TRT-LLM] [I] Set use_paged_context_fmha to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set use_fp8_context_fmha to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set multiple_profiles to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set paged_state to True. [06/23/2024-17:11:58] [TRT-LLM] [I] Set streamingllm to False. [06/23/2024-17:11:58] [TRT-LLM] [W] Implicitly setting MedusaConfig.skip_loading_weights = True [06/23/2024-17:11:58] [TRT-LLM] [W] Implicitly setting MedusaConfig.mup_width_multiplier = 1.0 [06/23/2024-17:11:58] [TRT-LLM] [I] Set dtype to float16. [06/23/2024-17:11:58] [TRT-LLM] [I] Set bert_attention_plugin to auto. [06/23/2024-17:11:58] [TRT-LLM] [I] Set gpt_attention_plugin to auto. [06/23/2024-17:11:58] [TRT-LLM] [I] Set gemm_plugin to float16. [06/23/2024-17:11:58] [TRT-LLM] [I] Set gemm_swiglu_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set smooth_quant_gemm_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set identity_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set layernorm_quantization_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set rmsnorm_quantization_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set nccl_plugin to float16. [06/23/2024-17:11:58] [TRT-LLM] [I] Set lookup_plugin to None. 
[06/23/2024-17:11:58] [TRT-LLM] [I] Set lora_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set weight_only_groupwise_quant_matmul_plugin to None. [06/23/2024-17:11:58] [TRT-LLM] [I] Set weight_only_quant_matmul_plugin to float16. [06/23/2024-17:11:58] [TRT-LLM] [I] Set quantize_per_token_plugin to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set quantize_tensor_plugin to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set moe_plugin to auto. [06/23/2024-17:11:58] [TRT-LLM] [I] Set mamba_conv1d_plugin to auto. [06/23/2024-17:11:58] [TRT-LLM] [I] Set context_fmha to True. [06/23/2024-17:11:58] [TRT-LLM] [I] Set context_fmha_fp32_acc to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set paged_kv_cache to True. [06/23/2024-17:11:58] [TRT-LLM] [I] Set remove_input_padding to True. [06/23/2024-17:11:58] [TRT-LLM] [I] Set use_custom_all_reduce to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set reduce_fusion to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set multi_block_mode to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set enable_xqa to True. [06/23/2024-17:11:58] [TRT-LLM] [I] Set attention_qk_half_accumulation to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set tokens_per_block to 64. [06/23/2024-17:11:58] [TRT-LLM] [I] Set use_paged_context_fmha to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set use_fp8_context_fmha to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set multiple_profiles to False. [06/23/2024-17:11:58] [TRT-LLM] [I] Set paged_state to True. [06/23/2024-17:11:58] [TRT-LLM] [I] Set streamingllm to False. [06/23/2024-17:12:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 22948 (MiB) [06/23/2024-17:12:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 22948 (MiB) [06/23/2024-17:12:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 22948 (MiB) [06/23/2024-17:12:03] [TRT-LLM] [W] The paged KV cache in Python runtime is experimental. For performance and correctness, please, use C++ runtime. [06/23/2024-17:12:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 22948 (MiB) [06/23/2024-17:12:03] [TRT-LLM] [W] The paged KV cache in Python runtime is experimental. For performance and correctness, please, use C++ runtime. 
[06/23/2024-17:12:03] [TRT-LLM] [I] Load engine takes: 15.598329305648804 sec [06/23/2024-17:12:03] [TRT-LLM] [I] Load engine takes: 15.83333158493042 sec Step: 0 ==================== logits: 0 [22557, 0, 0, 0] - ['▁Hello', '', '', ''] medusa_heads: 0 [28808, 661, 28742, 0] - ['!', '▁It', "'", ''] 1 [0, 0, 0, 0] - ['', '', '', ''] 2 [0, 0, 0, 0] - ['', '', '', ''] Step: 1 ==================== logits: 0 [28808, 661, 28742, 28713] - ['!', '▁It', "'", 's'] medusa_heads: 0 [661, 28742, 28713, 5171] - ['▁It', "'", 's', '▁nice'] 1 [28742, 315, 5171, 298] - ["'", '▁I', '▁nice', '▁to'] 2 [28713, 5171, 298, 2647] - ['s', '▁nice', '▁to', '▁meet'] new_tokens: tensor([[28808, 661, 28742, 28713]], device='cuda:0', dtype=torch.int32) - ['!', '▁It', "'", 's'] Step: 2 ==================== logits: 0 [5171, 298, 2647, 368] - ['▁nice', '▁to', '▁meet', '▁you'] medusa_heads: 0 [298, 2647, 368, 28723] - ['▁to', '▁meet', '▁you', '.'] 1 [2647, 368, 28723, 1602] - ['▁meet', '▁you', '.', '▁How'] 2 [368, 28723, 1602, 736] - ['▁you', '.', '▁How', '▁there'] new_tokens: tensor([[5171, 298, 2647, 368]], device='cuda:0', dtype=torch.int32) - ['▁nice', '▁to', '▁meet', '▁you'] Step: 3 ==================== logits: 0 [28723, 1602, 541, 541] - ['.', '▁How', '▁can', '▁can'] medusa_heads: 0 [1602, 736, 315, 541] - ['▁How', '▁there', '▁I', '▁can'] 1 [736, 315, 1316, 541] - ['▁there', '▁I', '▁help', '▁can'] 2 [1545, 1316, 368, 6926] - ['▁something', '▁help', '▁you', '▁Where'] new_tokens: tensor([[28723, 1602, 541]], device='cuda:0', dtype=torch.int32) - ['.', '▁How', '▁can'] Step: 4 ==================== logits: 0 [315, 1316, 368, 3154] - ['▁I', '▁help', '▁you', '▁today'] medusa_heads: 0 [1316, 368, 3154, 28804] - ['▁help', '▁you', '▁today', '?'] 1 [368, 3154, 28804, 1691] - ['▁you', '▁today', '?', '▁Is'] 2 [3154, 28804, 1691, 736] - ['▁today', '?', '▁Is', '▁there'] new_tokens: tensor([[ 315, 1316, 368, 3154]], device='cuda:0', dtype=torch.int32) - ['▁I', '▁help', '▁you', '▁today'] Input [Text 0]: " [INST] Hello! [/INST]" Output [Text 0 Beam 0]: "Hello! It's nice to meet you. How can I help" ```

Note:

  1. At each generation step, I'm printing the argmax of both the Mixtral logits and the 3 medusa_heads, along with the new_tokens produced in that step (see the sketch after this list for how the dump is produced).
  2. We requested max_output_len=14 tokens, and the request was fulfilled in "just" 5 generation steps (so the heads are working correctly in this example).
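
For reference, the per-step dump in the "Output" above was produced with roughly this kind of helper (a sketch: the debug-tensor names and shapes are assumptions, PyTorch tensors are assumed, and only the tokenizer call is standard transformers API):

```python
# Hedged sketch of the per-step dump: take the argmax of the marked debug
# tensors and map token ids back to strings with the HF tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mixtral_tokenizer")  # local tokenizer dir

def dump_step(step, base_logits, medusa_logits):
    # base_logits: [num_positions, vocab]; medusa_logits: [num_heads, num_positions, vocab]
    print(f"Step: {step} " + "=" * 20)
    base_ids = base_logits.argmax(dim=-1).tolist()
    print("logits:", base_ids, "-", tokenizer.convert_ids_to_tokens(base_ids))
    for h, head_logits in enumerate(medusa_logits):
        head_ids = head_logits.argmax(dim=-1).tolist()
        print(f"medusa_heads {h}:", head_ids, "-", tokenizer.convert_ids_to_tokens(head_ids))
```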

Despite this, when I benchmark the model with and without Medusa, I'm getting very poor performance with Medusa:

Benchmark C++ session w/ Medusa: latency 0.245 sec

```
mpirun -np 2 --allow-run-as-root --oversubscribe \
    python run.py --engine_dir mixtral_instruct_v1_trt11_medusa \
                     --tokenizer_dir mixtral_tokenizer \
                     --max_output_len=14 \
                     --temperature 1.0 \
                     --input_text "[INST] Hello! [/INST]" \
                     --medusa_choices="[[0], [0, 0], [0, 0, 0]]" \
                     --run_profiling

[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100
[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100
[06/23/2024-17:35:22] [TRT-LLM] [I] Load engine takes: 19.492470264434814 sec
[06/23/2024-17:35:22] [TRT-LLM] [I] Load engine takes: 19.49226713180542 sec
batch_size: 1, avg latency of 10 iterations: : 1.5497207641601562e-05 sec
Input [Text 0]: " [INST] Hello! [/INST]"
Output [Text 0 Beam 0]: "Hello! It's nice to meet you. How can How can"
batch_size: 1, avg latency of 10 iterations: : 0.2458343505859375 sec
```
Benchmark C++ session w/o Medusa: latency 0.226 sec

```
mpirun -np 2 --allow-run-as-root --oversubscribe \
    python run.py --engine_dir mixtral_instruct_v1_trt11 \
                     --tokenizer_dir mixtral_tokenizer \
                     --max_output_len=14 \
                     --temperature 1.0 \
                     --input_text "[INST] Hello! [/INST]" \
                     --run_profiling

[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100
[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100
[06/23/2024-17:36:32] [TRT-LLM] [I] Load engine takes: 18.678980827331543 sec
[06/23/2024-17:36:32] [TRT-LLM] [I] Load engine takes: 18.678552389144897 sec
batch_size: 1, avg latency of 10 iterations: : 1.5282630920410155e-05 sec
Input [Text 0]: " [INST] Hello! [/INST]"
Output [Text 0 Beam 0]: "Hello! It's nice to meet you. How can I help"
batch_size: 1, avg latency of 10 iterations: : 0.22646000385284423 sec
```

The latency of the model without Medusa is lower, even though the Medusa engine needs only 5 generation steps instead of 14.

What could be the cause? I'm using 2xA40 (48GB) with Mixtral int8 and TP=2.
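
Just to put those numbers in perspective, a quick back-of-envelope on the measured latencies (plain arithmetic on the values above, assuming roughly one decode step per output token in the non-Medusa run):

```python
# Back-of-envelope check on the measured latencies above (rounded values).
base_latency, base_steps = 0.2265, 14      # sec, 14 steps without Medusa
medusa_latency, medusa_steps = 0.2458, 5   # sec, 5 steps with Medusa

base_step = base_latency / base_steps        # ~16 ms per decode step
medusa_step = medusa_latency / medusa_steps  # ~49 ms per decode step

# Medusa accepts ~14/5 = 2.8 tokens per step, but each step costs ~3x more,
# which is why the end-to-end latency ends up slightly worse than the baseline.
print(f"per-step cost ratio: {medusa_step / base_step:.1f}x, "
      f"tokens gained per step: {base_steps / medusa_steps:.1f}x")
```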

nv-guomingz commented 4 days ago

Hi @v-dicicco, I read the step-by-step above for enabling Medusa for Mixtral 8x7B and it looks correct to me. For the perf issue, would it be possible to capture perf data via nsys and share it with us for further analysis?
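
Something along these lines should work for the capture (an illustrative invocation only; please adjust the trace options, ranks and paths to your setup):

```
mpirun -np 2 --allow-run-as-root --oversubscribe \
    nsys profile -t cuda,nvtx -o medusa_rank%q{OMPI_COMM_WORLD_RANK} --force-overwrite true \
        python run.py --engine_dir mixtral_instruct_v1_trt11_medusa \
                      --tokenizer_dir mixtral_tokenizer \
                      --max_output_len=14 \
                      --input_text "[INST] Hello! [/INST]" \
                      --medusa_choices="[[0], [0, 0], [0, 0, 0]]" \
                      --run_profiling
```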

v-dicicco commented 2 days ago

@nv-guomingz I've captured the perf data and run some additional tests.

I've also tried on 1xH100, to factor out the multi-GPU overhead (compared to the previous test on 2xA40) and to use a more powerful GPU. Unfortunately, I'm getting broadly similar results in this setting as well.

Here I'm using [INST] Hello [/INST]\n as the prompt and max_output_len=35, with Mixtral int8 weight-only quantized and batch_size=1. Also in this case, if I print the details of each generation step using the Python session, I can see the request with Medusa is fulfilled in just 15 steps (rather than 35); this is also visible in the perf charts.

However if we benchmark both models:

Benchmark C++ session w/ Medusa: latency 0.462 sec

```
python3.10 TensorRT-LLM/examples/run.py --engine_dir medusa_engine/engine \
    --tokenizer_dir mixtral-instruct \
    --max_output_len=35 \
    --temperature 1.0 \
    --input_text "$PROMPT" \
    --medusa_choices "[[0], [0, 0], [0, 0, 0]]"

[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100
[06/26/2024-19:14:18] [TRT-LLM] [I] Load engine takes: 36.19824481010437 sec
Input [Text 0]: " [INST] Hello [/INST]\n"
Output [Text 0 Beam 0]: "Hello! How can I help you today? If you have any questions about a specific topic, feel free to ask. I'm here to provide information and answer your questions"
batch_size: 1, avg latency of 10 iterations: : 0.46210289001464844 sec
```
Benchmark C++ session w/o Medusa: latency 0.383 sec

```
python3.10 TensorRT-LLM/examples/run.py --engine_dir standard_engine/engine \
    --tokenizer_dir mixtral-instruct \
    --max_output_len=35 \
    --temperature 1.0 \
    --input_text "[INST] Hello [/INST]\n" \
    --run_profiling

[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100
[06/26/2024-19:15:58] [TRT-LLM] [I] Load engine takes: 32.397374868392944 sec
Input [Text 0]: " [INST] Hello [/INST]\n"
Output [Text 0 Beam 0]: "Hello! How can I help you today? If you have any questions about a specific topic, feel free to ask. I'm here to provide information and answer your questions"
batch_size: 1, avg latency of 10 iterations: : 0.38388829231262206 sec
```

Here is a zip containing the perf data acquired with nsys for the previous two runs. Note that the captures contain the full run, i.e. the "warmup" plus 10 generations of the same prompt. They should look like this:

  • A single Mixtral+Medusa generation: [nsys timeline screenshot]

  • A single Mixtral generation: [nsys timeline screenshot]

I'm trying to analyse the charts, but I'm really looking forward to your feedback... let me know if there is anything else I can do to help!

skyCreateXian commented 1 day ago

@v-dicicco Did you train your own Medusa heads? If the Medusa heads don't guess correctly, Medusa will be slower than the base model, because the speculative-decoding verification also has a cost.
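
Roughly, the break-even condition looks like this (a simplified model I find useful, not an exact formula; it ignores variance in the accepted draft length):

```python
# Rough model of the trade-off: Medusa only pays off when the extra tokens
# gained per step outweigh the extra per-step cost of running the heads and
# verifying the draft tokens.
def medusa_speedup(avg_accepted_tokens_per_step, medusa_step_time, base_step_time):
    return avg_accepted_tokens_per_step * base_step_time / medusa_step_time

# With the Mixtral numbers reported earlier in this thread (~2.8 tokens/step
# accepted, but each step ~3x slower), the speedup drops below 1:
print(medusa_speedup(2.8, medusa_step_time=0.049, base_step_time=0.016))  # ~0.91
```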

v-dicicco commented 18 hours ago

@skyCreateXian for Mixtral 8x7B I'm not using my own Medusa heads but these: Mixtral-8x7B-Instruct-v0.1-medusa

I thought the problem could be related to the quality of the heads, but according to my tests they are working correctly (e.g. in my previous message, in the "Output" details, you can see the heads predicting tokens that are then accepted by the model). Also, with Mistral I used public Medusa heads from the same project as the Mixtral ones, without issues.

I'm starting to think it could be related to quantization (my tests with Mistral were in fp16, while here Mixtral is int8) or to the MoE architecture. I hope to get some feedback from @nv-guomingz and the team as well.

nv-guomingz commented 16 hours ago

Thanks @v-dicicco for providing the nsys files, we'll take a look at the details.