NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

[bug] unnecessary batch logits post processor calls #2439

Open akhoroshev opened 1 week ago

akhoroshev commented 1 week ago

version

When I build a model with paged_context_fmha = true and max_num_tokens = 4096, chunked context is enabled. I see that the Executor calls the batch logits post processor more than once for the first token.

To demonstrate this, I print the number of tokens inside the callback (FusedLogitsProcessor::process is my callback implementation).
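
For reference, a minimal sketch of such a callback, using simplified stand-in types: the real tensorrt_llm::executor callback also receives the logits tensor and additional arguments, and FusedLogitsProcessor is my own class, not part of the library.

#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified stand-ins for the executor's token containers.
using TokenIdType = std::int32_t;
using VecTokens = std::vector<TokenIdType>; // tokens of one beam (prompt + generated)
using BeamTokens = std::vector<VecTokens>;  // one entry per beam

struct FusedLogitsProcessor
{
    // Invoked by the Executor for each decoding step of a request. Here it only
    // reports how many tokens each beam holds, which is how the duplicated
    // first-token calls below were observed.
    void process(std::uint64_t requestId, BeamTokens const& beamTokens) const
    {
        for (auto const& beamToken : beamTokens)
        {
            std::fprintf(stderr, "FusedLogitsProcessor::process, beamToken.size() %zu\n", beamToken.size());
        }
        (void) requestId;
    }
};

int main()
{
    FusedLogitsProcessor processor;
    BeamTokens beamTokens(1, VecTokens(18810)); // one beam holding the 18810 prompt tokens
    processor.process(0, beamTokens);           // prints 18810, matching the first log line below
    return 0;
}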

I send requests with different input sizes and set maxTokens to 3.

input_context_size: 18810

[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18811
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18812

input_context_size: 15014

[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15014
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15014
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15014
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15014
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15015
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15016

input_context_size: 12585

[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12585
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12585
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12585
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12585
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12586
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12587

input_context_size: 8176

[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 8176
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 8176
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 8177
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 8178

You can see that the first-token logits callback is repeated ceil(input_context_size / max_num_tokens) times. In fact, the logits from the first ceil(input_context_size / max_num_tokens) - 1 calls are ignored (the sampling layers are not run), and the Executor returns exactly 3 tokens, as expected. But it is very strange to run a logits processor on "garbage" logits.
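
A quick standalone check that the repeated call counts in the logs above match ceil(input_context_size / max_num_tokens) with max_num_tokens = 4096 (the build setting from this report):

#include <cstdio>

int main()
{
    constexpr int maxNumTokens = 4096;
    constexpr int inputSizes[] = {18810, 15014, 12585, 8176};
    for (int inputSize : inputSizes)
    {
        // Ceiling division: number of context chunks the prompt is split into.
        int contextChunks = (inputSize + maxNumTokens - 1) / maxNumTokens;
        std::printf("input_context_size %5d -> %d context chunks -> %d callback calls with the prompt length\n",
            inputSize, contextChunks, contextChunks);
    }
    return 0;
}

This prints 5, 4, 4, and 2 chunks for the four inputs, matching the 5, 4, 4, and 2 repeated first-token calls observed above.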

akhoroshev commented 1 week ago

It would be great if you called the logits post processor for a request only if isLastContextChunk() || isGenerationInProgressState().
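
A hypothetical sketch of that guard; LlmRequest, isLastContextChunk() and isGenerationInProgressState() mirror the names used in this thread, not the actual batch manager internals:

#include <cstdio>

// Minimal stand-in for a request; the real class tracks chunk state internally.
struct LlmRequest
{
    int contextChunkIdx;   // index of the context chunk processed in this step
    int numContextChunks;  // total chunks the prompt was split into
    bool inGeneration;     // true once the request is past the context phase

    bool isLastContextChunk() const { return contextChunkIdx == numContextChunks - 1; }
    bool isGenerationInProgressState() const { return inGeneration; }
};

void maybeInvokeLogitsPostProcessor(LlmRequest const& req)
{
    // Only run the user callback when the produced logits are actually sampled.
    if (req.isLastContextChunk() || req.isGenerationInProgressState())
    {
        std::printf("invoking logits post processor\n");
    }
    else
    {
        std::printf("skipping callback for an intermediate context chunk\n");
    }
}

int main()
{
    maybeInvokeLogitsPostProcessor({0, 5, false}); // chunks 1..4 of a 5-chunk prompt: skipped
    maybeInvokeLogitsPostProcessor({4, 5, false}); // last context chunk: invoked
    maybeInvokeLogitsPostProcessor({0, 5, true});  // generation steps: invoked
    return 0;
}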

amukkara commented 1 week ago

@akhoroshev Thanks for pointing this out.

We will make the change to invoke the logits post processor only for the last context chunk.