triton-inference-server / tensorrtllm_backend

The Triton TensorRT-LLM Backend

Limited batched streaming when using inflight batching #404

Closed: vnkc1 closed this issue 2 months ago

vnkc1 commented 2 months ago

System Info

p4d (8 x A100 40 GB GPUs)

TensorRT-LLM: 0.8.0 release
Triton server: 24.02

Who can help?

@kaiyux @byshiue

Reproduction

  1. Create the checkpoint:
     python ./TensorRT-LLM/llama/convert_checkpoint.py --model_dir ./Llama-2-13b-chat-hf/ --output_dir ./checkpoint --dtype float16 --tp_size 4 --workers 8

  2. Build the engine:
     trtllm-build --checkpoint_dir ./checkpoint --output_dir ./engine --gemm_plugin float16 --remove_input_padding enable --paged_kv_cache enable --context_fmha enable --workers 8 --max_batch_size 32 --max_input_len 4096 --max_output_len 512

  3. Load the engine into Triton Inference Server in inflight batching mode, using the model repository at https://github.com/triton-inference-server/tensorrtllm_backend/tree/v0.7.1/all_models/inflight_batcher_llm

  4. Test batched streaming inference with the client script below (substitute the get_prompt function with a test prompt; a stand-in sketch follows the script)

import queue
import functools
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype, InferenceServerException

def prepare_tensor(name, input):
    t = grpcclient.InferInput(name, input.shape, np_to_triton_dtype(input.dtype))
    t.set_data_from_numpy(input)
    return t

CONCURRENCY = 4
INPUT_TOKENS = 3000
MAX_TOKENS = 40
client = grpcclient.InferenceServerClient(url="localhost:8001")

# get_prompt function returns a prompt of input size 3000
prompts: list[str] = [get_prompt(INPUT_TOKENS)] * CONCURRENCY
inputs = [
    [
        prepare_tensor("text_input", np.array([[prompt]], dtype=object)),
        prepare_tensor("max_tokens", np.array([[MAX_TOKENS]], dtype=np.int32)),
        prepare_tensor("stream", np.array([[True]], dtype=bool)),
    ]
    for prompt in prompts
]

# Start stream connection
response_queue = queue.Queue()
callback = lambda queue, result, error: queue.put(error if error else result)
client.start_stream(callback=functools.partial(callback, response_queue))

# Send requests
for index, inp in enumerate(inputs):
    client.async_stream_infer(model_name="ensemble", inputs=inp, request_id=str(index))

# Fetch responses
while True:
    try:
        response = response_queue.get(timeout=1)
        if not isinstance(response, InferenceServerException):
            request_id = response.get_response().id
            print(request_id, end=", ", flush=True)
    except queue.Empty:
        break

# Stop stream
client.stop_stream()
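
For completeness, here is one possible stand-in for the get_prompt placeholder referenced in step 4 (hypothetical helper; any prompt that tokenizes to roughly INPUT_TOKENS tokens will do):

def get_prompt(n_tokens: int) -> str:
    # Rough approximation: repeat a short word n_tokens times.
    # The exact token count depends on the model's tokenizer.
    return " ".join(["hello"] * n_tokens)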

Expected behavior

Tokens for multiple requests should be streamed out simultaneously

Example output (request ids): 0, 1, 3, 2, 4, 6, 5, 0, 1, 5, 9, 3, 2, 20, 10, 4, 6, 7, 8, 11, 12, 5, 9, 13, 14, 20, 10, 23, 14, 23, 15, 16, 15, 16, 5, 9, 13, 17, 18, 19, 21, 22, 0, 1, 3, 1, 3, 1, 3, 2, 4, 6, 4, 6, 4, 6, 7, 8, 11, 1, 3, 12, 10, 20, 10.....

Actual behavior

Tokens for each request are streamed one after the other

Output (request ids): 0, 0, 0, ..... 1, 1, 1, 1, 1, ..... 2, 2, 2, 2, .....

Additional notes

For fewer input tokens (32), streaming interleaves as expected. The sequential behavior only appears for larger input sizes.
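
To make the two regimes easier to compare, here is a sketch of a drop-in replacement for the fetch loop in the script above (it assumes the same client, inputs, and response_queue) that records the arrival order of request ids instead of printing them:

# Collect the arrival order of request ids so the interleaved vs.
# sequential pattern can be inspected after the stream finishes.
arrival_order = []
while True:
    try:
        response = response_queue.get(timeout=1)
        if not isinstance(response, InferenceServerException):
            arrival_order.append(response.get_response().id)
    except queue.Empty:
        break

# Interleaved streaming shows different ids alternating early on;
# the reported behavior shows long runs of the same id.
print(arrival_order)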

Saigut commented 2 months ago

In my case, the output words of every request are returned together rather than one by one, and the callback is only called once per request, so it does not really stream.

vnkc1 commented 2 months ago

@Saigut could you expand on the callback issue?

I see that the end-to-end gRPC client example, which supports streaming, uses the same callback: https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/inflight_batcher_llm/client/end_to_end_grpc_client.py#L31-L35

Saigut commented 2 months ago

@vnkc1 The issue I hit was my own fault: I didn't provide the "stream" input to the server.
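
For anyone hitting the same symptom: per-token streaming requires the boolean "stream" input to be sent with each request, as the client script in the issue already does:

# The "stream" input enables per-token streaming for this request.
prepare_tensor("stream", np.array([[True]], dtype=bool))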