NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Medusa performance degrades with batch size larger than 1 #2482

Open SoundProvider opened 7 hours ago

SoundProvider commented 7 hours ago

I'm trying to use Medusa with TRT-LLM, referencing this page.

It works fine with Vicuna 7B and its Medusa heads, as referenced in the example page.

In the example it's stated: "Note: Increasing the batch size may have a negative impact on performance." My understanding is that when the batch size increases, each sequence has to wait for the other sequences to reach its position, resulting in performance degradation.

But when I tested with Vicuna 7B, performance still dropped at batch size 4 even though every sequence used the same input. This contradicts my understanding.
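To illustrate my understanding, here is a toy simulation (hypothetical, not TRT-LLM code) of the "waiting" effect: each step, every sequence would accept some number of draft tokens on its own, but synchronized batched decoding can only advance all sequences by the minimum accepted length. Note that with identical inputs the acceptances would be identical too, so this model predicts no slowdown for my test case, which is exactly what confuses me.

```python
import random

def simulate_medusa_throughput(batch_size, steps=10_000, max_accept=4, seed=0):
    """Toy model: per step, each sequence in the batch independently accepts
    1..max_accept draft tokens, but the batch advances at the minimum
    accepted length. Returns average tokens per step per sequence."""
    rng = random.Random(seed)
    total = 0
    for _ in range(steps):
        accepted = [rng.randint(1, max_accept) for _ in range(batch_size)]
        total += min(accepted)  # batch waits for the slowest sequence
    return total / steps

for bs in (1, 2, 4, 8):
    print(f"batch size {bs}: {simulate_medusa_throughput(bs):.2f} tokens/step")
```

Under this model, throughput per sequence shrinks monotonically as the batch grows, because the minimum over more random acceptance lengths is smaller.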

[Image] I tested batch size variation with the same inputs (batch size 4, identical inputs).

What could be the reason? It would be really nice if someone could explain.

Thank you

hello-11 commented 3 hours ago

@SoundProvider could you tell me how you evaluated the performance?

SoundProvider commented 25 minutes ago

@hello-11 Hello. I used the run script in the Medusa example folder:

python /app/tensorrt_llm/examples/run.py --engine_dir /app/models/medusa_test_3b/tensorrt_llm/4-gpu \
                                            --tokenizer_dir /app/models/vicuna-33b-v1.3 \
                                            --max_output_len=500 \
                                            --medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
                                            --temperature 1.0 \
                                            --input_text "Once upon" \
                                            --run_profiling