NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Multi-block mode performance issue #1548

littletomatodonkey opened this issue 1 month ago

littletomatodonkey commented 1 month ago

Hi, I tested multi-block mode for TRT-LLM on the Yi-6B model (LLaMA-style architecture); the results are below. It seems that:

  1. Multi-block mode (MBM) only helps with long inputs and long outputs; if max_new_tokens is 1, MBM has no effect.
  2. MBM has no effect when the input token length is <= 4096, which seems different from flash decoding?

Could you please tell me whether these conclusions match the expected behavior of TRT-LLM multi-block mode? Thanks!

input_len=4096, output_len=512
wo   mbm: latency avg = 3.97, tokens per second avg = 128.85
with mbm: latency avg = 3.99, tokens per second avg = 128.35

input_len=8192, output_len=512
wo   mbm: latency avg = 5.03, tokens per second avg = 101.80
with mbm: latency avg = 4.77, tokens per second avg = 107.29

input_len=16384, output_len=512
wo   mbm: latency avg = 7.46, tokens per second avg = 68.61
with mbm: latency avg = 6.60, tokens per second avg = 77.59

input_len=4096, output_len=2048
wo   mbm: latency avg = 15.67, tokens per second avg = 130.70
with mbm: latency avg = 15.35, tokens per second avg = 133.41

input_len=16384, output_len=2048
wo   mbm: latency avg = 25.80, tokens per second avg = 79.37
with mbm: latency avg = 22.00, tokens per second avg = 93.07

input_len=8192, output_len=1
wo   mbm: latency avg = 0.59, tokens per second avg = 1.70
with mbm: latency avg = 0.59, tokens per second avg = 1.70

input_len=16384, output_len=1
wo   mbm: latency avg = 1.54, tokens per second avg = 0.65
with mbm: latency avg = 1.54, tokens per second avg = 0.65
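
As a quick sanity check on the numbers above, the tokens-per-second figures are roughly output_len divided by the average latency, e.g. for the first configuration (awk used only for floating-point arithmetic):

awk 'BEGIN { printf "tokens/s ~= %.2f\n", 512 / 3.97 }'   # prints ~128.97, close to the reported 128.85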

The convert and build script is as follows.


tmp_dir=$(mktemp -d)

echo "tmp_dir: ${tmp_dir}"

python convert_checkpoint.py \
--model_dir ${hf_model_dir} \
--output_dir ${tmp_dir} \
--dtype float16 \
--use_weight_only

trtllm-build \
--checkpoint_dir ${tmp_dir} \
--output_dir ${trt_model_dir} \
--remove_input_padding enable \
--context_fmha enable \
--gemm_plugin float16 \
--gpt_attention_plugin float16 \
--max_batch_size 16 \
--max_input_len 16384 \
--max_output_len 2048 \
--paged_kv_cache enable \
--use_paged_context_fmha enable
# to enable multi-block mode, also pass `--multi_block_mode enable`
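
For reference, a sketch of the same build command with multi-block mode turned on; every flag is copied from the script above, and the only addition is the last line, as the comment suggests:

trtllm-build \
--checkpoint_dir ${tmp_dir} \
--output_dir ${trt_model_dir} \
--remove_input_padding enable \
--context_fmha enable \
--gemm_plugin float16 \
--gpt_attention_plugin float16 \
--max_batch_size 16 \
--max_input_len 16384 \
--max_output_len 2048 \
--paged_kv_cache enable \
--use_paged_context_fmha enable \
--multi_block_mode enable
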
byshiue commented 1 month ago

Multi-block-mode (MBM) works only on long input and long output, if max_new_tokens is 1, MBM does not work.

That is expected, because we only use multi-block in the generation phase (when generating new tokens). In the context phase there are already enough blocks to run in parallel, so multi-block is not needed.
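
A back-of-the-envelope sketch of why that is, under the simplifying assumption that the generation-phase attention kernel launches roughly one CUDA block per (sequence, head), while the context-phase kernel also parallelizes over query tiles (the head count and tile size below are illustrative, not taken from this issue):

batch=1; heads=32; input_len=8192; q_tile=64   # illustrative numbers only
echo "context-phase blocks    ~ $(( batch * heads * input_len / q_tile ))"   # thousands of blocks: SMs stay busy
echo "generation-phase blocks ~ $(( batch * heads ))"                        # a few dozen: SMs idle unless the KV sequence is split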

MBM does not work if input token length <= 4096, which is different from flash decoding?

I'll leave this one for others to answer.

PerkzZheng commented 1 month ago

MBM does not work if input token length <= 4096, which is different from flash decoding?

The idea is that we first try to fully utilize one SM before using more blocks per sequence, so there is a threshold that determines whether multi-block mode kicks in. You can always tune the performance by setting TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG=1 TRTLLM_MMHA_BLOCKS_PER_SEQUENCE=4; TRTLLM_MMHA_BLOCKS_PER_SEQUENCE is the variable to tune. The workflow would look like this (see the shell sketch after the list):

  1. Building engines: set the maximum number of blocks per sequence (32) with TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG=1 TRTLLM_MMHA_BLOCKS_PER_SEQUENCE=32.
  2. Inference: tune the number of blocks per sequence (e.g. 4) with TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG=1 TRTLLM_MMHA_BLOCKS_PER_SEQUENCE=4; it can be any integer <= 32.
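
A sketch of that two-step workflow as shell commands. The environment-variable names come from the reply above; the build and run commands are placeholders standing in for whatever you already use (run.py is assumed here to be the usual TensorRT-LLM example runner):

# 1. Build the engine with the maximum number of blocks per sequence (32).
TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG=1 TRTLLM_MMHA_BLOCKS_PER_SEQUENCE=32 \
trtllm-build --checkpoint_dir ${tmp_dir} --output_dir ${trt_model_dir} --multi_block_mode enable

# 2. At inference time, tune the actual number of blocks per sequence (any integer <= 32).
TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG=1 TRTLLM_MMHA_BLOCKS_PER_SEQUENCE=4 \
python run.py --engine_dir ${trt_model_dir}
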
littletomatodonkey commented 1 month ago

Thanks for your reply, I'll try it now and report back.

littletomatodonkey commented 1 month ago

Hi @PerkzZheng, I tested the model with those environment variables, but the inference cost does not seem to vary consistently with TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Does that make sense?

PerkzZheng commented 1 month ago

@littletomatodonkey splitting one sequence into more blocks doesn't necessarily give more speedup. More blocks mean more reduction overhead, and more waves if the SMs are already fully utilized.
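
To make the wave argument concrete, a small illustrative loop (the SM and head counts are borrowed from the H100 example later in this thread; one block per (sequence, head) is a simplification):

num_sms=132; batch=1; heads=16
for blocks_per_seq in 1 4 8 16 32; do
  total=$(( batch * heads * blocks_per_seq ))
  waves=$(( (total + num_sms - 1) / num_sms ))   # ceil(total / num_sms)
  echo "blocks_per_seq=${blocks_per_seq} -> total blocks=${total}, waves=${waves}"
done
# Beyond roughly 8 blocks per sequence the grid exceeds the SM count, so extra blocks
# only add waves plus cross-block reduction overhead.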

littletomatodonkey commented 1 month ago

Then how can I tell whether the SMs are fully utilized, and what is the best practice for multi_block_mode in TensorRT-LLM? Thanks!

PerkzZheng commented 1 month ago

Take H100-SXM as an example: it has 132 SMs. Say the batch size is 1 and the number of heads is 16; then we can normally split each sequence into 132 / 16 ≈ 8 blocks to fully utilize all SMs. But if the sequence length is quite small, e.g. around 1K tokens, it may not be worth 8 blocks per sequence (fewer may be better).
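
The sizing heuristic above, written out as a quick calculation (treat it as a rule of thumb, not an exact model):

num_sms=132; batch=1; heads=16                    # H100-SXM, batch 1, 16 heads as in the example
blocks_per_seq=$(( num_sms / (batch * heads) ))   # -> 8 blocks roughly fill all SMs
echo "suggested TRTLLM_MMHA_BLOCKS_PER_SEQUENCE <= ${blocks_per_seq}"
# For short sequences (around 1K tokens) each block would only see a small KV chunk,
# so a smaller value may perform better.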

nv-guomingz commented 3 weeks ago

Please reopen this ticket if there's further discussion.