Hi @ValeGian ,
Thank you for reporting this. This is expected behavior: at present we only support `--quant_medusa_head=False`, i.e. running `quantize.py` without the `--quant_medusa_head` flag.
The current quantization library does not support quantizing the Medusa heads. We are working on this and will add support as soon as possible.
Hi @yweng0828 ,
Thank you for the clarification!
I appreciate the work you're doing on this. Just to confirm: will support for Medusa head quantization be included in TensorRT-LLM v0.12.0, or should we expect it in a subsequent release?
Thanks again for your help!
Hi @ValeGian , sorry, this feature will not be included in TensorRT-LLM v0.12.0; the expected completion time is September or October.
In the meantime there is a "Ninja" workaround: quantize the Medusa head weights with another method of your choosing, then hack them into the final checkpoint file by aligning the key names with what the checkpoint expects. Building and running an engine with a quantized Medusa head are already supported.
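For illustration, a minimal sketch of what that splice might look like, assuming a single safetensors checkpoint file, simple per-tensor E4M3 scaling, and TensorRT-LLM-style `weights_scaling_factor` companion keys; the key filter, suffixes, and file layout are assumptions to verify against your own checkpoint:

```python
# Illustrative only: splice self-quantized Medusa head weights into the
# exported checkpoint by matching its key names. The "medusa" key filter,
# the per-tensor E4M3 scaling, and the "weights_scaling_factor" suffix are
# assumptions -- check them against your actual checkpoint contents.
import torch
from safetensors.torch import load_file, save_file

ckpt_path = "<CHECKPOINT DIR>/rank0.safetensors"  # placeholder path
ckpt = load_file(ckpt_path)

medusa_weight_keys = [k for k in ckpt if "medusa" in k and k.endswith(".weight")]
for key in medusa_weight_keys:
    w = ckpt[key].float()
    # Per-tensor FP8 (E4M3) quantization: 448 is the largest finite E4M3 value.
    scale = (w.abs().max() / 448.0).clamp(min=1e-12)
    ckpt[key] = (w / scale).to(torch.float8_e4m3fn)
    ckpt[key.replace(".weight", ".weights_scaling_factor")] = scale.reshape(1)

save_file(ckpt, ckpt_path)
```

If the key names line up with what the FP8 Medusa checkpoint format expects, `trtllm-build` should then consume the rewritten file as-is.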
Thanks @yweng0828 for the heads-up!
I'll give the Ninja method a try for now and keep an eye out for the official support later this year.
Appreciate the help!
System Info

- CPU architecture: x86_64
- GPU: NVIDIA H100
- Libraries:
  - TensorRT-LLM: v0.11.0
  - TensorRT: 10.1.0
  - ModelOpt: 0.13.1
- CUDA: 12.3
- NVIDIA driver version: 535.129.03
Issue
Hello, I'm experiencing a failure when building an FP8 (quantized weights and KV cache) engine for Mixtral + Medusa heads.
Reproduction
Steps to reproduce the behavior:
1. Quantize the model and export the FP8 checkpoint:

```bash
quantize.py --model_dir=<MODEL DIR> --dtype=float16 --tp_size=1 \
    --output_dir=<CHECKPOINT DIR> --qformat=fp8 --kv_cache_dtype=fp8 \
    --calib_dataset=<CALIB DATASET> --calib_size=512 --batch_size=8 \
    --calib_max_seq_length=1024 --num_medusa_heads=2 --num_medusa_layers=1 \
    --max_draft_len=2 --medusa_model_dir=<MEDUSA MODEL DIR> --quant_medusa_head
```

2. Build the engine:

```bash
trtllm-build --checkpoint_dir=<CHECKPOINT DIR> --max_beam_width=1 \
    --max_seq_len=32768 --max_input_len=32368 --max_num_tokens=32768 \
    --max_batch_size=4 --context_fmha=enable --use_custom_all_reduce=disable \
    --output_dir=<OUT DIR> --use_fp8_context_fmha=disable \
    --speculative_decoding_mode=medusa
```
Expected behavior
The expected output would be a correctly built Mixtral + Medusa heads FP8 engine.
Actual behavior
The checkpoint generation works and the quantized model is exported, but when running trtllm-build, a crash occurs with the following message:
The exact same `trtllm-build` command works correctly when run on a checkpoint produced by the identical `quantize.py` command with the `--quant_medusa_head` flag removed; in that case the engine builds successfully.
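To pin down what the flag actually changes in the exported checkpoint, one option is to diff the tensor names and dtypes of the two exports. A minimal sketch, assuming both are single-rank safetensors files (the paths and the `rank0.safetensors` file name are assumptions):

```python
# Hypothetical diagnostic (not from the thread): diff the tensor dtypes of
# the two exported checkpoints to see which entries --quant_medusa_head
# adds or changes. Paths are placeholders.
from safetensors.torch import load_file

without_flag = load_file("<CHECKPOINT DIR WITHOUT FLAG>/rank0.safetensors")
with_flag = load_file("<CHECKPOINT DIR WITH FLAG>/rank0.safetensors")

for key in sorted(set(without_flag) | set(with_flag)):
    a = without_flag.get(key)
    b = with_flag.get(key)
    da = a.dtype if a is not None else "missing"
    db = b.dtype if b is not None else "missing"
    if str(da) != str(db):
        print(f"{key}: {da} -> {db}")
```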