NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

In-flight batching and mixed batch #2311


huijjj commented 6 days ago

[screenshot of the documentation] According to the documentation, packed & mixed batching seems to be the default behavior of TensorRT-LLM.

So I've conducted an experiment to see the effect of mixed batching by changing max_num_tokens to appropriate sizes. The rest of the setup is as follows:

Here are the results with max_num_tokens set to 2048, 2050, 2064, and 2176, respectively:

| max_num_tokens | avg TTFT (ms) | avg TPOT (ms) | avg token throughput (tokens/sec) |
|---|---|---|---|
| 2048 | 265435.6 | 67.22 | 3323.19 |
| 2050 | 265873.7 | 67.00 | 3319.35 |
| 2064 | 268012.8 | 67.49 | 3291.81 |
| 2176 | 265234.8 | 66.75 | 3325.43 |

I expected that max_num_tokens = 2048 would batch only 2 summarization-phase sequences, whereas 2050, 2064, and 2176 would additionally batch up to 2, 16, and 128 generation-phase sequences, respectively. The results show only marginal changes and no clear evidence of mixed batching. Note that the TTFT numbers are not very meaningful, as the benchmark was run with the request rate (queries per second) set to infinite.
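For concreteness, the token-budget arithmetic behind that expectation looks like this (a back-of-the-envelope sketch; the 1024-token prompt length and the one-token-per-iteration cost of a generation-phase sequence are my assumptions about the setup, not something stated in the TensorRT-LLM docs):

```python
# Back-of-the-envelope token budget per scheduler iteration.
# Assumption: each summarization/context request contributes a 1024-token prompt,
# and each generation-phase sequence contributes exactly 1 token per iteration.
CONTEXT_TOKENS_PER_REQUEST = 1024

for max_num_tokens in (2048, 2050, 2064, 2176):
    context_slots = 2  # two full prefills fit in every configuration tested
    remaining = max_num_tokens - context_slots * CONTEXT_TOKENS_PER_REQUEST
    print(f"max_num_tokens={max_num_tokens}: {context_slots} context requests "
          f"+ up to {remaining} generation-phase tokens per iteration")
```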

So:

- Q1: Is mixed batching (batching sequences in the summarization phase and the generation phase together in a single batch) still enabled by default?
- Q2: If so, why couldn't I see a difference, and how should I design an experiment that shows the impact of mixed batching?
- Q3: If not, how can I enable it, and why was it dropped from the defaults?
- Q4: Are there any materials, including code, for getting a better understanding of the batch manager (the request scheduler) in TensorRT-LLM? As far as I can tell, nothing seems to be open.

dcampora commented 19 hours ago

Hi there,

Yes, packed and mixed batching is still the default behaviour of TRT-LLM. Let's consider the first case with 2048 tokens:

In the other cases, the ramp-up happens slightly faster. Consider the case with 2050 tokens:

So from your experiments, one would expect the ramp-up to be faster as max_num_tokens gets higher. With many more samples, though, I would expect this effect to become negligible, which might explain why you don't see a difference.
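To make the ramp-up effect concrete, here is a toy simulation of a token-budget scheduler. It is purely illustrative and not the actual batch manager; the 1024-token prompts, 256 output tokens per request, and 1000 requests are assumed values chosen to match the scale of the experiment above.

```python
# Toy model of mixed batching with a per-iteration token budget (max_num_tokens).
# Each iteration: in-flight generation sequences take 1 token each, and as many
# new context requests as fit in the remaining budget are admitted.
PROMPT_LEN = 1024      # assumed summarization prompt length
OUTPUT_LEN = 256       # assumed tokens generated per request
NUM_REQUESTS = 1000    # assumed benchmark size

def simulate(max_num_tokens: int) -> int:
    waiting = NUM_REQUESTS  # requests still needing their context phase
    generating = []         # remaining output tokens for each in-flight request
    iterations = 0
    while waiting or generating:
        iterations += 1
        # Budget left after giving every generation-phase sequence one token.
        budget = max_num_tokens - len(generating)
        admitted = min(waiting, max(0, budget) // PROMPT_LEN)
        waiting -= admitted
        # Each in-flight request produces one token; drop the ones that finish.
        generating = [left - 1 for left in generating if left > 1]
        # Newly admitted requests start generating on the next iteration.
        generating.extend([OUTPUT_LEN] * admitted)
    return iterations

for budget in (2048, 2050, 2064, 2176):
    print(f"max_num_tokens={budget}: {simulate(budget)} iterations "
          f"to finish {NUM_REQUESTS} requests")
```

In this toy model, a larger max_num_tokens only lets extra context requests in during the ramp-up phase, so the total iteration counts differ by a small constant that shrinks relative to the total as the number of requests grows; that is consistent with the marginal differences in the table above.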

The batch manager is not open source at the moment. The resources you mentioned are the ones that are available currently.

huijjj commented 4 hours ago

I initially thought that TensorRT-LLM prioritized prefills over decode, similar to vLLM. So I expected the schedule for a max token count of 2050 to look like this:

But if TensorRT-LLM behaves as you explained, it makes sense that I didn’t observe a difference in my experiment. Thanks for the clarification! I’ll run another experiment to validate my understanding.