triton-inference-server / tensorrtllm_backend

The Triton TensorRT-LLM Backend

Accumulated decoding when streaming #34

Open comaniac opened 10 months ago

comaniac commented 10 months ago

I'm trying to serve a Llama-2-70b-chat-hf model using Triton Inference Server with a TRT-LLM engine. The script I used is tools/inflight_batcher_llm/end_to_end_streaming_client.py:

python3 tools/inflight_batcher_llm/end_to_end_streaming_client.py -p "What is deep learning?" -S -o 64

This script streams the generated tokens as bytes. I changed the callback function so that it prints strings:

print(output[0].decode(), flush=True, end="")
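
For reference, the full callback looks roughly like this (a sketch assuming the gRPC streaming client's (result, error) callback signature and the ensemble's text_output output tensor):

# Sketch of the modified streaming callback; tensor name assumed from the example client.
def callback(result, error):
    if error is not None:
        print(error)
        return
    # "text_output" holds the detokenized text produced by the postprocess model.
    output = result.as_numpy("text_output")
    print(output[0].decode(), flush=True, end="")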

However, the output becomes:

Deeplearningisasubsetofmachinelearningthatinvolvestheuseofartificialneuralnetworkstomodelandsolvecomplexproblems.Inadeeplearningsystem,therearetypicallymultiplelayersofneuralnetworksthatprocessandtransformthedatainahierarchicalmanner.Eachlayerbuildsonthepreviousone,allowingthesystem

We can see that the spaces are gone. This is because the postprocess model in the ensemble decodes tokens one by one. In order to get correct spacing, we should do tokenizer.decode(accumulated_tokens) instead of tokenizer.decode(this_token), and only output the delta text from the postprocess model. However, I have no idea how to maintain that state in the postprocess model, since all models in the ensemble form a single stateless forward function.
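
To illustrate what I mean by delta text, here is a minimal client-side sketch using the plain Hugging Face tokenizer API (the hard part is doing the same thing inside the stateless postprocess model):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-chat-hf")

accumulated_ids = []
previous_text = ""

def on_new_token(token_id):
    # Decode the whole sequence so far, then emit only the newly added text.
    global previous_text
    accumulated_ids.append(token_id)
    text = tokenizer.decode(accumulated_ids)
    delta = text[len(previous_text):]
    previous_text = text
    return delta  # spacing is correct because decoding sees the full context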

One solution I could think of is removing the postprocess model from the ensemble and letting the client decode the tokens itself. However, this doesn't make sense because it requires the client to know and load the tokenizer of the model it talks to.

comaniac commented 10 months ago

A bit of an update on this issue:

I found that this behavior varies from tokenizer to tokenizer. The normalizer in some tokenizers includes the space when decoding a single token. For example, here is the result from the Falcon tokenizer:

falcon_tokenizer.decode(21956)
" particle" # It has a space

And this is from the Llama tokenizer:

llama_tokenizer.decode(16445)
"particle" # No space

If we instead use .convert_ids_to_tokens to get the raw token string without the normalizer:

# Falcon
falcon_tokenizer.convert_ids_to_tokens(21956)
'Ġparticle'

# Llama
llama_tokenizer.convert_ids_to_tokens(16445)
'▁particle'

Although it's possible to buffer such raw tokens and manually apply normalization on the client side, it's still tedious because the client won't know which symbol represents a space.
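
For example, a client-side normalization step would have to hard-code the space markers of every tokenizer family it might encounter, something like the sketch below (hypothetical helper, only handling the leading-space case):

# Hypothetical client-side normalization; the markers are tokenizer-specific.
SPACE_MARKERS = ("\u0120", "\u2581")  # 'Ġ' (byte-level BPE), '▁' (SentencePiece)

def normalize_piece(piece):
    # Replace a leading space marker with a real space; everything else passes through.
    for marker in SPACE_MARKERS:
        if piece.startswith(marker):
            return " " + piece[len(marker):]
    return piece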

pcastonguay commented 9 months ago

We have added a Python BLS model called tensorrt_llm_bls: https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/model.py that can be used to replace the ensemble model. The Python BLS model has an optional parameter accumulate_tokens that can be set to true to accumulate tokens before calling the post-processor. This should fix the tokenizer decoding issues you mentioned above. Note that in that case, every streaming response will contain the full text up to the latest token.
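
Since each streaming response then contains the full accumulated text, a client that still wants to print only the new characters can diff consecutive responses, roughly like this (a sketch assuming the same text_output tensor name as above):

# Sketch: turn accumulated streaming responses (accumulate_tokens=true) back into deltas.
previous_text = ""

def callback(result, error):
    global previous_text
    if error is not None:
        print(error)
        return
    text = result.as_numpy("text_output")[0].decode()
    delta = text[len(previous_text):]  # new characters since the last response
    previous_text = text
    print(delta, flush=True, end="")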