triton-inference-server / tensorrtllm_backend

The Triton TensorRT-LLM Backend

Performance Issue with return_context_logits Enabled in TensorRT-LLM #419

Open metterian opened 2 months ago

metterian commented 2 months ago

System Info

- CPU: Intel(R) Xeon(R) CPU @ 2.20GHz
- Architecture: x86_64
- GPU: NVIDIA A100-SXM4-40G
- OS: Ubuntu

Who can help?

No response

Information

Tasks

Reproduction

I followed the official example for the Llama model: https://github.com/NVIDIA/TensorRT-LLM/tree/v0.8.0/examples/llama

I've been experiencing significant slowdowns when the return_context_logits flag is turned on. For context, I am using the Llama example and enabled the gather_context_logits flag during the TensorRT-LLM engine build.

Additionally, I pass return_context_logits through the triton_client to retrieve the logits of the request sentences. To accommodate this, I set request_output_len (output_len) to 1.
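For reference, the request looks roughly like the sketch below. It uses tritonclient and assumes the default ensemble model and its tensor names (text_input, max_tokens, return_context_logits, context_logits) from the tensorrtllm_backend examples; the model name, URL, and tensor names may differ in your deployment.

```python
import numpy as np
import tritonclient.grpc as grpcclient

# Assumed endpoint and model name; adjust to your deployment.
client = grpcclient.InferenceServerClient(url="localhost:8001")

def build_input(name, data):
    # Wrap a numpy array into a Triton InferInput with a matching datatype.
    dtype = "BYTES" if data.dtype == object else \
            ("BOOL" if data.dtype == np.bool_ else "INT32")
    t = grpcclient.InferInput(name, list(data.shape), dtype)
    t.set_data_from_numpy(data)
    return t

inputs = [
    build_input("text_input", np.array([["The quick brown fox"]], dtype=object)),
    build_input("max_tokens", np.array([[1]], dtype=np.int32)),           # output_len = 1
    build_input("return_context_logits", np.array([[True]], dtype=np.bool_)),
]
outputs = [grpcclient.InferRequestedOutput("context_logits")]

result = client.infer(model_name="ensemble", inputs=inputs, outputs=outputs)
context_logits = result.as_numpy("context_logits")  # expected shape: [batch, prompt_len, vocab]
print(context_logits.shape)
```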

Expected behavior

Enabling return_context_logits should cause only a modest slowdown, with throughput not deviating significantly from runs with the flag off. Performance should ideally be on par with, or better than, the forward-pass speed of the equivalent HuggingFace implementation.

Actual behavior


In practice, I observe an almost 8-fold increase in execution time when retrieving context logits with a maximum generation length of 1. This is, surprisingly, slower than the forward-pass speed of comparable HuggingFace models.

Here's a comparative table of performance with and without the return_context_logits flag:

| Logit Status | max_gen_token | input_len | Execution Time | Average Time per Example |
|--------------|---------------|-----------|----------------|--------------------------|
| On           | 1             | 2000      | 0:49           | 0.98 s                   |
| Off          | 1             | 2000      | 0:06           | 0.12 s                   |

Additional notes

I have executed the trtllm-build with the following configuration:

```bash
trtllm-build --checkpoint_dir {model_dir}/tensorrt/{tp_size}-gpu \
             --remove_input_padding enable \
             --gpt_attention_plugin float16 \
             --context_fmha enable \
             --gemm_plugin float16 \
             --output_dir {model_dir}/tensorrt_llm/context_fmha \
             --paged_kv_cache disable \
             --enable_xqa disable \
             --multi_block_mode disable \
             --use_custom_all_reduce disable \
             --tp_size {tp_size} \
             --workers {tp_size} \
             --max_batch_size 1 \
             --max_input_len 8192 \
             --max_output_len 8192 \
             --max_num_tokens 8192 \
             --gather_context_logits
```

Any insights or assistance in addressing this unexpected slowdown would be greatly appreciated. If there are any further experiments or specific areas you would recommend investigating, please advise.

yweng0828 commented 2 months ago

Hi @metterian, thanks for your feedback. Are the performance numbers you show measured through Triton? If so, could you please try running TRT-LLM on its own (without Triton), preferably with warmup?

We expect the overhead caused by this feature in TRT-LLM itself to be limited and acceptable.
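For example, a standalone timing run with warmup could look like the minimal sketch below. It assumes the TensorRT-LLM Python runtime's ModelRunner API (as used in examples/run.py); exact argument names may differ between versions, and the engine/tokenizer paths and prompt are placeholders.

```python
import time
import torch
from transformers import AutoTokenizer
from tensorrt_llm.runtime import ModelRunner

ENGINE_DIR = "/path/to/tensorrt_llm/context_fmha"  # placeholder: built engine dir
TOKENIZER_DIR = "/path/to/hf/llama"                # placeholder: HF tokenizer dir

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
runner = ModelRunner.from_dir(engine_dir=ENGINE_DIR)

# One ~2000-token prompt with max_new_tokens=1, matching the Triton experiment above.
input_ids = tokenizer("Hello world! " * 700, return_tensors="pt").input_ids[:, :2000]

def run_once():
    with torch.no_grad():
        return runner.generate(
            batch_input_ids=[input_ids[0]],
            max_new_tokens=1,
            end_id=tokenizer.eos_token_id,
            pad_id=tokenizer.eos_token_id,
            return_dict=True,
        )

# Warmup so one-time engine/allocator costs are excluded from the measurement.
for _ in range(3):
    run_once()
torch.cuda.synchronize()

start = time.perf_counter()
out = run_once()
torch.cuda.synchronize()
print(f"latency: {time.perf_counter() - start:.3f} s")

# If the engine was built with --gather_context_logits, the returned dict
# should also expose the context logits alongside the output ids.
print(out.keys())
```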