v-dicicco opened 1 week ago
If Mixtral 8x7B had its own Medusa model, like `medusa-vicuna-7b-v1.3` for `vicuna-7b-v1.3`, then we could try enabling Medusa for the MoE model.
Thanks for the answer! There is a Medusa model already trained for Mixtral-Instruct v0.1 here https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa for TGI.
The format is the same as the original vicuna heads, since I was able to use the heads for Mistral from the same project with TRT-LLM (https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa).
Do you already see challenges other than fixing Medusa's `convert_checkpoint.py` script? I'm trying to work on it right now.
No, we haven't investigated MoE + Medusa yet. But my gut tells me it could be done once we add the missing MoE support in `convert_checkpoint.py`.
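As a rough illustration of what "adding the missing MoE support" could look like, here is a sketch of the MoE-related arguments the Medusa `convert_checkpoint.py` would need to accept, mirroring the LLaMA converter. The argument names below are assumptions based on the LLaMA script and may differ in your TRT-LLM version:

```python
# Hypothetical sketch: MoE arguments the Medusa convert_checkpoint.py would
# need, mirroring the LLaMA converter. Names are assumptions and may differ
# across TRT-LLM versions.
import argparse

def add_moe_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument("--moe_num_experts", type=int, default=0,
                        help="Number of experts (8 for Mixtral 8x7B); 0 disables MoE.")
    parser.add_argument("--moe_top_k", type=int, default=0,
                        help="Experts routed per token (2 for Mixtral 8x7B).")
    return parser

parser = add_moe_args(argparse.ArgumentParser())
args = parser.parse_args(["--moe_num_experts", "8", "--moe_top_k", "2"])
```

These values would then have to be threaded into the MoE config that the weight-conversion code reads, the same way the LLaMA script does.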
I have completed building a Medusa engine for baichuan2-7b. Based on that experience, I have the following suggestions:
Thanks for the reply @nv-guomingz @skyCreateXian.
I was able to modify LLaMA's `convert_checkpoint.py` to add the Medusa weights, but I'm obtaining very poor inference performance, in stark contrast to my experience with Mistral 7B.
I proceeded in this way:
1. I implemented a `MedusaForCausalLm.from_hugging_face()` really similar to the one of `LLaMAForCausalLM`, which also updates the weight dict of the base model (Mixtral 8x7B) with the Medusa weights provided by `load_medusa_hf()`, and instantiates and loads the weights into a `MedusaForCausalLm` object (taking care to create the right `MedusaConfig` with the `MedusaForCausalLM` architecture field).
2. During `trtllm-build` I needed to force `silu` as the `hidden_act` of the Medusa heads here, rather than taking its value from the model config. This was needed because during the conversion of Mixtral in the previous step, TRT-LLM set `swiglu` as `hidden_act` rather than keeping the original `silu` value, which is incompatible with the Medusa heads.

To be sure the model is behaving in the expected way, I followed the Debug on E2E Models doc, marking the Medusa logits as debug outputs and printing their predictions at each step, e.g. running:
```shell
mpirun -np 2 --allow-run-as-root --oversubscribe \
    python run.py --engine_dir mixtral_instruct_v1_trt11_medusa \
    --tokenizer_dir mixtral_tokenizer \
    --max_output_len=14 \
    --temperature 1.0 \
    --input_text "[INST] Hello! [/INST]" \
    --medusa_choices="[[0], [0, 0], [0, 0, 0]]" \
    --use_py_session \
    --debug_mode
```
```
[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100
[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024061100
[06/23/2024-17:11:58] [TRT-LLM] [W] Implicitly setting MedusaConfig.skip_loading_weights = True
[06/23/2024-17:11:58] [TRT-LLM] [W] Implicitly setting MedusaConfig.mup_width_multiplier = 1.0
[06/23/2024-17:11:58] [TRT-LLM] [I] Set dtype to float16.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set bert_attention_plugin to auto.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set gpt_attention_plugin to auto.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set gemm_plugin to float16.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set gemm_swiglu_plugin to None.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set smooth_quant_gemm_plugin to None.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set identity_plugin to None.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set layernorm_quantization_plugin to None.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set rmsnorm_quantization_plugin to None.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set nccl_plugin to float16.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set lookup_plugin to None.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set lora_plugin to None.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set weight_only_groupwise_quant_matmul_plugin to None.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set weight_only_quant_matmul_plugin to float16.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set quantize_per_token_plugin to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set quantize_tensor_plugin to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set moe_plugin to auto.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set mamba_conv1d_plugin to auto.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set context_fmha to True.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set context_fmha_fp32_acc to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set paged_kv_cache to True.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set remove_input_padding to True.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set use_custom_all_reduce to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set reduce_fusion to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set multi_block_mode to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set enable_xqa to True.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set attention_qk_half_accumulation to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set tokens_per_block to 64.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set use_paged_context_fmha to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set use_fp8_context_fmha to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set multiple_profiles to False.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set paged_state to True.
[06/23/2024-17:11:58] [TRT-LLM] [I] Set streamingllm to False.
[06/23/2024-17:12:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 22948 (MiB)
[06/23/2024-17:12:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 22948 (MiB)
[06/23/2024-17:12:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 22948 (MiB)
[06/23/2024-17:12:03] [TRT-LLM] [W] The paged KV cache in Python runtime is experimental. For performance and correctness, please, use C++ runtime.
[06/23/2024-17:12:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 22948 (MiB)
[06/23/2024-17:12:03] [TRT-LLM] [W] The paged KV cache in Python runtime is experimental. For performance and correctness, please, use C++ runtime.
[06/23/2024-17:12:03] [TRT-LLM] [I] Load engine takes: 15.598329305648804 sec
[06/23/2024-17:12:03] [TRT-LLM] [I] Load engine takes: 15.83333158493042 sec
Step: 0 ====================
logits:
0 [22557, 0, 0, 0] - ['▁Hello', ' [INST] Hello! [/INST]"
Output [Text 0 Beam 0]: "Hello! It's nice to meet you. How can I help"
```
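For context on the `--medusa_choices` flag used above: each entry is a path in the draft token tree, so `[[0], [0, 0], [0, 0, 0]]` is a single linear chain of 3 draft tokens (each head speculating the top-1 continuation of the previous one). A small helper, purely illustrative, shows how the paths relate to node count and depth:

```python
# --medusa_choices is a list of paths in the draft tree; [[0], [0, 0],
# [0, 0, 0]] is one linear chain of 3 draft tokens. This helper (illustrative
# only, not part of TRT-LLM) derives the node count and tree depth.
def tree_stats(choices):
    """Return (number of draft nodes, tree depth) for a medusa_choices list."""
    nodes = {tuple(path) for path in choices}
    depth = max(len(path) for path in choices)
    return len(nodes), depth

nodes, depth = tree_stats([[0], [0, 0], [0, 0, 0]])  # 3 nodes, depth 3
```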
Note: the debug output shows the `logits` and the 3 `medusa_heads`, along with the `new_tokens` produced in each step. I asked for `max_output_len=14` tokens, and the request was fulfilled in "just" 5 generation steps (so the heads are working correctly in this example).

Despite this, if I benchmark the model with and without Medusa, I'm obtaining very poor performance:
The latency of the model without Medusa is lower, even though the one with Medusa requires just 5 generation steps rather than 14.
What can be the cause? I'm using 2xA40 (48GB) with Mixtral int8 and TP 2.
Hi @v-dicicco, I read the above step-by-step of enabling Medusa for Mixtral 8x7B and I think it's correct IMO. For the perf issue, is it possible to sample the perf data via nsys and share it with us for further analysis?
@nv-guomingz I've computed the perf data and done additional tests.
I've tried on 1xH100, to factor out multi-GPU overhead (w.r.t. the previous test with 2xA40) and to use a more powerful GPU. Unfortunately I'm obtaining overall similar results in this setting too.
Here I'm using `[INST] Hello [/INST]\n` as the prompt and `max_output_len=35`, with Mixtral int8 weight-only quantized and `batch_size=1`. Also in this case, if I print details for each generation step using the Python session, I can see the request with Medusa is fulfilled in just 15 steps (rather than 35); this is also visible in the perf charts.
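As a sanity check on whether fewer steps should translate into lower latency: 35 tokens in 15 Medusa steps means roughly 2.3 tokens accepted per step, so Medusa only wins if one of its steps costs less than ~2.3x a plain decoding step (the draft heads and tree verification make each step more expensive). A quick back-of-the-envelope using the numbers from this run:

```python
# Back-of-the-envelope break-even for this run (35 tokens, 15 Medusa steps).
tokens = 35
base_steps = 35        # one token per step without speculation
medusa_steps = 15      # observed with the linear 3-token draft chain
tokens_per_step = tokens / medusa_steps   # ~2.33 tokens accepted per step
# Medusa is faster only while its per-step latency stays below this multiple
# of the base model's per-step latency:
break_even = base_steps / medusa_steps    # ~2.33x
```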
However if we benchmark both models:
Here is a zip containing perf data acquired with `nsys` for the previous two runs. Note that they contain the full run, i.e. the "warmup" plus 10 generations of the same prompt. They should look like this:
A single Mixtral+Medusa generation:
A single Mixtral generation:
I'm trying to analyse the charts, but I'm really looking forward to your feedback... let me know if there is anything else I can do to help!
@v-dicicco Did you train your own Medusa heads? If the Medusa heads don't guess correctly, Medusa will be slower than the base model, because speculative decoding verification also has a cost.
@skyCreateXian for Mixtral 8x7B I'm not using my own Medusa heads but these: Mixtral-8x7B-Instruct-v0.1-medusa.
I thought the problem could be related to the quality of the heads, but according to my tests they are working correctly (e.g. in my previous message, in the "Output" details, you can see the heads are predicting tokens that are then accepted by the model). Also, with Mistral I've used public Medusa heads from the same project as the Mixtral ones, without issues.
I'm starting to think it could be related to quantization (my tests with Mistral were in fp16, while here Mixtral is int8), or to issues with the MoE architecture. I hope to get some feedback from @nv-guomingz and the team too.
Thanks @v-dicicco for providing nsys files, we'll take a look into the details.
Hello! Does TensorRT-LLM support Medusa with Mixtral 8x7B?
My understanding is that right now the Medusa `convert_checkpoint.py` doesn't support Mixtral (e.g. it lacks the `moe` config and other MoE-related arguments contained in the LLaMA conversion script), but I have the feeling it should (in theory) work, since `MedusaForCausalLm` is based on `LLaMAForCausalLM`, and the `convert_checkpoint.py` of Medusa can be aligned with the one used by LLaMA (for some specific configurations).

Any hints in this direction would be helpful :)
Thanks!