NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Top-P sampling occasionally produces invalid tokens #1590

Closed: AlessioNetti closed this issue 6 days ago

AlessioNetti commented 6 months ago

Who can help?

@byshiue

Reproduction

We noticed that TensorRT-LLM occasionally (in ~0.01% of requests) generates invalid tokens. The issue can be reproduced with a generic Falcon 7B model as follows:

python convert_checkpoint.py --model_dir ./falcon_7b_tp1_instruct/ --dtype bfloat16 --output_dir ./falcon_7b_tp1_instruct_trt_chkpt

trtllm-build --checkpoint_dir ./falcon_7b_tp1_instruct_trt_chkpt/ --gemm_plugin bfloat16 --remove_input_padding enable --gpt_attention_plugin bfloat16 --output_dir ./falcon_7b_tp1_instruct_p200_g200 --gather_all_token_logits --max_input_len 200 --max_output_len 200 --max_batch_size 64

python example_basic.py --model_path ./falcon_7b_tp1_instruct_p200_g200

The examples/bindings/executor/example_basic.py script was modified to repeatedly issue top-P sampling requests with random prompts (in batches of 16) until an invalid token appears in the output. The changes are as follows:

diff --git a/examples/bindings/executor/example_basic.py b/examples/bindings/executor/example_basic.py
index 2c7a3fc..65a9b57 100644
--- a/examples/bindings/executor/example_basic.py
+++ b/examples/bindings/executor/example_basic.py
@@ -1,4 +1,6 @@
 import argparse
+import torch
+import random

 import tensorrt_llm.bindings.executor as trtllm

@@ -20,16 +22,25 @@ if __name__ == "__main__":
                                trtllm.ExecutorConfig(1))

     if executor.can_enqueue_requests():
-        # Create the request.
-        request = trtllm.Request(input_token_ids=[1, 2, 3, 4],
-                                 max_new_tokens=10)
-
-        # Enqueue the request.
-        request_id = executor.enqueue_request(request)
-
-        # Wait for the new tokens.
-        responses = executor.await_responses(request_id)
-        output_tokens = responses[0].result.output_token_ids
-
-        # Print tokens.
-        print(output_tokens)
+        while True:
+            # Create the request.
+            requests = []
+            for _ in range(16):
+                input_token_ids = [random.randint(100, 10000) for _ in range(200)]
+                requests.append(trtllm.Request(input_token_ids=input_token_ids, max_new_tokens=200,
+                                               sampling_config=trtllm.SamplingConfig(top_p=0.5, top_k=None, temperature=20.0)))
+
+            # Enqueue the request.
+            request_ids = executor.enqueue_requests(requests)
+
+            # Wait for the new tokens.
+            responses = executor.await_responses(request_ids)
+            
+            for idx, re in enumerate(responses):
+                output_tokens = re[0].result.output_token_ids[0]
+                valid_output = all(el >= 0 and el < 200000 for el in output_tokens)
+                if not valid_output:
+                    print(f"Output tokens : {output_tokens[200:]}")
+                    exit(-1)
+                else:
+                    print(f"Valid output produced for request {request_ids[idx]}.")

Expected behavior

Requests should always generate valid tokens, i.e., tokens in the [0, vocabulary_size) range.
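The check in the modified scripts uses a loose `el < 200000` bound. A stricter check tests against the real vocabulary size and reports where the bad tokens sit. This is a minimal sketch: the helper name `find_invalid_tokens` is hypothetical, and the vocabulary size of 65024 for Falcon 7B is an assumption taken from its Hugging Face config that should be verified against your checkpoint.

```python
from typing import List, Sequence, Tuple

# Assumed vocabulary size for Falcon 7B (verify against your model's config.json).
FALCON_7B_VOCAB_SIZE = 65024


def find_invalid_tokens(tokens: Sequence[int], vocab_size: int) -> List[Tuple[int, int]]:
    """Return (position, token_id) pairs for every token outside [0, vocab_size)."""
    return [(pos, tok) for pos, tok in enumerate(tokens) if not 0 <= tok < vocab_size]


if __name__ == "__main__":
    # A short excerpt shaped like the failing output reported in this issue.
    sample = [40261, 2147483647, 18931]
    print(find_invalid_tokens(sample, FALCON_7B_VOCAB_SIZE))  # [(1, 2147483647)]
```

Note that the scripts slice off the first 200 entries (the prompt) before printing, so positions returned on the full `output_token_ids` list include the prompt tokens.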

Actual behavior

Occasionally, requests produce invalid tokens that fall outside the model's vocabulary range. Below is an example of the issue, as reported by our modified example_basic.py script:

Valid output produced for request 9534.
Valid output produced for request 9535.
Valid output produced for request 9536.
Output tokens : [47796, 54241, 47783, 58101, 6674, 23726, 23592, 42594, 6139, 25248, 52039, 47238, 46481, 59789, 36977, 9214, 30383, 31047, 19853, 59072, 25294, 63500, 59925, 44334, 38232, 28210, 38889, 26873, 35512, 48818, 38165, 14048, 49025, 30020, 59300, 49636, 5338, 63956, 4748, 22356, 26041, 19883, 22013, 32389, 24446, 36715, 11451, 13325, 58318, 29675, 12733, 15128, 323, 26868, 42477, 28018, 18622, 52692, 60096, 19486, 3727, 1427, 32693, 18763, 38281, 38747, 52358, 58497, 17945, 36842, 9453, 23113, 21691, 22407, 9894, 27278, 8361, 40261, 2147483647, 18931, 38614, 47912, 48115, 36611, 33955, 41329, 45530, 23243, 43669, 10268, 19238, 6055, 49515, 63961, 29434, 48151, 54508, 25936, 55805, 10214, 28366, 22400, 7200, 17613, 30007, 16812, 1529, 62540, 63633, 7331, 58970, 46938, 25656, 52488, 11953, 32571, 13142, 61313, 9385, 49280, 43718, 47734, 27930, 3368, 56759, 41270, 23886, 32473, 48038, 12786, 39043, 4837, 16915, 2584, 16430, 56707, 46255, 26404, 33055, 51739, 14011, 18179, 25129, 7630, 62620, 11823, 51429, 7700, 17108, 7422, 9389, 9999, 32405, 36641, 6937, 13023, 29698, 60332, 10098, 46336, 54260, 41558, 32326, 7579, 58826, 2443, 12843, 38563, 51635, 63544, 10124, 2484, 43080, 16858, 24803, 3017, 42640, 46269, 22102, 53352, 51123, 42491, 55109, 27590, 2322, 28774, 9365, 19873, 1538, 64635, 8407, 63458, 49056, 53777, 5887, 16413, 5956, 36375, 42348, 27573]

As can be seen, one of the tokens is 2147483647, which is INT32_MAX (2^31 - 1). In other instances we have also observed negative tokens, always with magnitudes in the billions; this suggests an integer overflow somewhere in the top-P sampling logic.
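The arithmetic behind the overflow hypothesis can be checked directly: 2147483647 is the largest signed 32-bit integer, and a signed 32-bit value that overflows under two's-complement arithmetic wraps to a large negative number, consistent with the negative token IDs in the billions. The sketch below only emulates int32 wrap-around in pure Python; the helper `wrap_int32` is hypothetical and it does not reproduce the actual kernel bug.

```python
INT32_MAX = 2**31 - 1
INT32_MIN = -(2**31)


def wrap_int32(value: int) -> int:
    """Emulate two's-complement int32 wrap-around for an arbitrary Python int."""
    return (value + 2**31) % 2**32 - 2**31


if __name__ == "__main__":
    print(INT32_MAX)                  # 2147483647, the token seen in the output above
    print(wrap_int32(INT32_MAX + 1))  # -2147483648
    print(wrap_int32(3_000_000_000))  # -1294967296
```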

byshiue commented 6 months ago

Thank you. I can reproduce the issue. I slightly modified the basic example to speed up reproduction.

import argparse
import torch
import random

import tensorrt_llm.bindings.executor as trtllm

# This example shows how to use the python bindings to create an executor, enqueue a
# request, and get the generated tokens.

# First, follow the steps in README.md to generate the engines.

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Executor Bindings Example")
    parser.add_argument("--model_path",
                        type=str,
                        required=True,
                        help="Directory containing model engine")
    args = parser.parse_args()

    # Create the executor.
    executor = trtllm.Executor(args.model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(1))

    random.seed(1234)
    if executor.can_enqueue_requests():
        ite_count = 0
        while True:
            # Create the request.
            requests = []
            ite_count += 16

            for _ in range(16):
                input_token_ids = [random.randint(100, 10000) for _ in range(200)]
                requests.append(trtllm.Request(input_token_ids=input_token_ids, max_new_tokens=105,
                                               sampling_config=trtllm.SamplingConfig(top_p=0.5, top_k=None, temperature=20.0)))
            if ite_count < 6616:
                continue

            # Enqueue the request.
            request_ids = executor.enqueue_requests(requests)

            # Wait for the new tokens.
            responses = executor.await_responses(request_ids)

            for idx, re in enumerate(responses):
                output_tokens = re[0].result.output_token_ids[0]
                valid_output = all(el >= 0 and el < 200000 for el in output_tokens)
                if not valid_output:
                    print(f"Invalid output produced for request {request_ids[idx]}.")
                    print(f"Output tokens : {output_tokens[200:]}")
                    exit(-1)
                else:
                    print(f"Valid output produced for request {request_ids[idx]}.")
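The speed-up in this version comes from seeding Python's `random` module and building, but not submitting, every batch until `ite_count` reaches 6616. Because the prompts come from a seeded RNG, the skipped batches only advance the RNG state, so the batches actually sent to the executor are identical from run to run. A minimal sketch of that fast-forward pattern without the executor; `make_batch` and `first_submitted_batch` are hypothetical helpers, and `random.Random(seed)` stands in for the script's global `random.seed(1234)`.

```python
import random
from typing import List

BATCH_SIZE = 16
PROMPT_LEN = 200


def make_batch(rng: random.Random) -> List[List[int]]:
    """Draw one batch of random prompts, mirroring the repro script's token range."""
    return [[rng.randint(100, 10000) for _ in range(PROMPT_LEN)] for _ in range(BATCH_SIZE)]


def first_submitted_batch(seed: int, skip_until: int) -> List[List[int]]:
    """Advance the RNG past skipped batches and return the first batch that would be submitted."""
    rng = random.Random(seed)
    ite_count = 0
    while True:
        batch = make_batch(rng)  # always drawn, so the RNG state advances
        ite_count += BATCH_SIZE
        if ite_count < skip_until:
            continue  # skipped: never sent to the executor
        return batch


if __name__ == "__main__":
    a = first_submitted_batch(seed=1234, skip_until=6616)
    b = first_submitted_batch(seed=1234, skip_until=6616)
    print(a == b)  # True: the same prompts are replayed on every run
```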

We are still investigating the reason.

ChristinaZ commented 6 months ago

Hi Alessio, thank you for finding this bug. We are looking into this issue. If this bug becomes a bottleneck in your workflow, one workaround is to set the variable mIsAirTopP to false, in which case TRT-LLM will use a different top-P sampling method. We will try to fix the bug as soon as possible.
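For reference, top-P (nucleus) sampling keeps the smallest set of highest-probability tokens whose cumulative probability reaches `top_p` and samples only from that set, so any correct implementation can only return an index in [0, vocabulary_size). The pure-Python sketch below applies temperature scaling and then top-P; it is an illustrative reference only, not TensorRT-LLM's AIR top-P kernel or its alternative method, and `softmax`, `nucleus` and `top_p_sample` are hypothetical helpers. It also shows that a high temperature such as the repro's 20.0 flattens the distribution, so the top-P=0.5 nucleus tends to contain more tokens.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled, numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus(probs: Sequence[float], top_p: float) -> List[int]:
    """Smallest set of highest-probability indices whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    cumulative = 0.0
    for idx in order:
        kept.append(idx)
        cumulative += probs[idx]
        if cumulative >= top_p:
            break
    return kept


def top_p_sample(logits: Sequence[float], top_p: float, temperature: float = 1.0,
                 rng: Optional[random.Random] = None) -> int:
    """Sample one token index; the result is always a valid index into logits."""
    rng = rng or random.Random()
    probs = softmax(logits, temperature)
    kept = nucleus(probs, top_p)
    # random.choices renormalizes the weights within the nucleus.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.5, -1.0]
    print(len(nucleus(softmax(logits, 1.0), 0.5)), len(nucleus(softmax(logits, 20.0), 0.5)))  # 1 2
```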

nv-guomingz commented 6 days ago

Hi @AlessioNetti, do you still have any further issues or questions? If not, we'll close this issue soon.

AlessioNetti commented 6 days ago

Hi, the bug was fixed a few versions back, so we can close this.