Open comaniac opened 10 months ago
A bit of an update on this issue:
I found that the behavior varies from tokenizer to tokenizer. The normalizer in some tokenizers includes the space when decoding a single token. For example, here is the result from the Falcon tokenizer:
```python
>>> falcon_tokenizer.decode(21956)
' particle'  # It has a leading space
```
And this is from the Llama tokenizer:

```python
>>> llama_tokenizer.decode(16445)
'particle'  # No space
```
If we instead use `.convert_ids_to_tokens` to get the raw string without the normalizer:

```python
# Falcon
>>> falcon_tokenizer.convert_ids_to_tokens(21956)
'Ġparticle'
# Llama
>>> llama_tokenizer.convert_ids_to_tokens(16445)
'▁particle'
```
Although it's possible to buffer such decoded tokens and manually apply normalization on the client side, that is tedious because the client doesn't know which symbol the tokenizer uses to represent a space (`Ġ` for Falcon, `▁` for Llama).
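For illustration, the accumulate-and-diff approach looks roughly like this. This is only a sketch: `fake_decode` and the `PIECES` table are made-up stand-ins for a real `tokenizer.decode`, which is what would actually be called on the accumulated ids.

```python
def stream_detokenize(decode, token_ids):
    """Yield only the newly added text after each token, by decoding the
    full accumulated sequence each step and diffing against the previous
    decoded text. `decode` stands in for tokenizer.decode."""
    prev_text = ""
    acc = []
    for tid in token_ids:
        acc.append(tid)
        text = decode(acc)
        yield text[len(prev_text):]
        prev_text = text

# Stand-in decode: maps ids to SentencePiece-style pieces and applies the
# normalizer ('▁' -> space) over the whole sequence, like a real tokenizer.
PIECES = {1: "▁The", 2: "▁particle", 3: "s", 4: "▁move"}

def fake_decode(ids):
    return "".join(PIECES[i] for i in ids).replace("▁", " ").lstrip()

deltas = list(stream_detokenize(fake_decode, [1, 2, 3, 4]))
print(deltas)  # ['The', ' particle', 's', ' move']
```

Because each step decodes the whole sequence, the normalizer sees the full context and the spaces come out right; decoding each token in isolation is what loses them.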
We have added a Python BLS model called `tensorrt_llm_bls` (https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/model.py) that can be used to replace the ensemble model. The Python BLS model has an optional parameter `accumulate_tokens` that can be set to `true` to accumulate tokens before calling the post-processor. This should fix the tokenizer decoding issues you mentioned above. Note that in that case, every streaming response will contain the full text up to the latest token.
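Since each streaming response then carries the full text so far, a client that wants per-token deltas can diff consecutive responses. A minimal sketch (the example chunks are made up):

```python
def deltas_from_cumulative(responses):
    """Given streaming responses that each contain the full text so far
    (as with accumulate_tokens=true), recover the per-response deltas."""
    prev = ""
    for text in responses:
        yield text[len(prev):]
        prev = text

chunks = ["The", "The particle", "The particles", "The particles move"]
print(list(deltas_from_cumulative(chunks)))
# ['The', ' particle', 's', ' move']
```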
I'm trying to serve a Llama-2-70b-chat-hf model using Triton Inference Server with a TRT-LLM engine. The script I used is `tools/inflight_batcher_llm/end_to_end_streaming_client.py`. This script streams the generated tokens as bytes. I changed the callback function so that it prints strings. However, the output becomes:
We can see that the spaces are gone. This is because the postprocess model in the ensemble decodes tokens one by one. To get correct spacing, we should call `tokenizer.decode(accumulated_tokens)` instead of `tokenizer.decode(this_token)`, and have the postprocess model output only the delta text. However, I have no idea how to maintain that state in the postprocess model, because all models in the ensemble form a single, stateless forward function.

One solution I can think of is removing the postprocess model from the ensemble and letting the client decode the tokens with the tokenizer itself. However, this doesn't make much sense, because it requires the client to know and load the tokenizer of the model it talks to.
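To make the "maintain state" part concrete, the bookkeeping a stateful postprocessor would need can be sketched as below. This is illustrative only, not a Triton Python backend model: `decode` stands in for `tokenizer.decode`, and the piece table is invented for the demo.

```python
class StatefulPostprocessor:
    """Sketch of per-request state for delta-text streaming: accumulated
    token ids plus the last decoded text, so each new token yields only
    the newly appended text."""

    def __init__(self, decode):
        self.decode = decode   # stands in for tokenizer.decode
        self.tokens = {}       # request_id -> accumulated token ids
        self.texts = {}        # request_id -> last decoded text

    def on_token(self, request_id, token_id):
        acc = self.tokens.setdefault(request_id, [])
        acc.append(token_id)
        text = self.decode(acc)                          # decode everything
        delta = text[len(self.texts.get(request_id, "")):]
        self.texts[request_id] = text
        return delta

    def on_complete(self, request_id):
        # Drop the state when the request finishes streaming.
        self.tokens.pop(request_id, None)
        self.texts.pop(request_id, None)

# Made-up SentencePiece-style pieces, just for the demo.
pieces = {1: "▁no", 2: "▁spaces", 3: "▁lost"}

def toy_decode(ids):
    return "".join(pieces[i] for i in ids).replace("▁", " ").lstrip()

pp = StatefulPostprocessor(toy_decode)
print([pp.on_token("req-0", t) for t in [1, 2, 3]])
# ['no', ' spaces', ' lost']
pp.on_complete("req-0")
```

The hard part in Triton is where to keep the `tokens`/`texts` dictionaries across calls; the `tensorrt_llm_bls` model mentioned in the comment above sidesteps this by accumulating tokens before the postprocess step instead.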