NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0
8.36k stars 940 forks source link

Enc-Dec C++ Runtime Paged KV - Inflight Batching output junks while inference with multiple input texts #1753

Open thanhlt998 opened 4 months ago

thanhlt998 commented 4 months ago

I try inference my T5 model with C++ runtime used Paged KV at the commit b777bd64750abf30ca7eda48e8b6ba3c5174aafd. Its result is normal when inference with single input text, but with multiple input texts the outputs are something weird.

My T5 model config:

{
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 3840,
  "d_kv": 64,
  "d_model": 1536,
  "decoder_start_token_id": 0,
  "dense_act_fn": "silu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "extra_ids": 256,
  "feed_forward_proj": "gated-silu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 2048,
  "num_decoder_layers": 24,
  "num_heads": 24,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.39.1",
  "use_cache": true,
  "vocab_size": 65664
}

I followed the README at enc-dec example folder:

convert checkpoint

export MODEL_DIR="path_to_t5" # or "flan-t5-small"
export MODEL_NAME="myt5"
export MODEL_TYPE="t5"

export INFERENCE_PRECISION="float16"
export TP_SIZE=1
export PP_SIZE=1
export WORLD_SIZE=1
export MAX_BEAM_WIDTH=1
python tensorrt_llm/examples/enc_dec/convert_checkpoint.py --model_type ${MODEL_TYPE} \
                --model_dir "${MODEL_DIR}" \
                --output_dir "${MODEL_DIR}/trt_models/${INFERENCE_PRECISION}/tp${TP_SIZE}/pp${PP_SIZE}" \
                --tp_size ${TP_SIZE} \
                --pp_size ${PP_SIZE} \
                --dtype ${INFERENCE_PRECISION} \
                --workers 1

build engine

export MODEL_DIR="path_to_t5_model" # or "flan-t5-small"
export MODEL_NAME="myt5"
export MODEL_TYPE="t5"
export INFERENCE_PRECISION="float16"
export TP_SIZE=1
export PP_SIZE=1
export WORLD_SIZE=1
export MAX_BEAM_WIDTH=1
export MAX_BATCH_SIZE=1
export OUTPUT_DIR="triton_model_repos/${MODEL_NAME}/tensorrt_llm/1"

trtllm-build --checkpoint_dir "${MODEL_DIR}/trt_models/${INFERENCE_PRECISION}/tp${TP_SIZE}/pp${PP_SIZE}/encoder" \
                --output_dir "${OUTPUT_DIR}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE}/encoder" \
                --paged_kv_cache enable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width ${MAX_BEAM_WIDTH} \
                --max_batch_size 2 \
                --max_input_len 2048 \
                --max_encoder_input_len 2048 \
                --max_output_len 2048 \
                --gemm_plugin ${INFERENCE_PRECISION} \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding enable \
                --context_fmha disable

# For decoder, refer to the above content and set --max_input_len correctly
trtllm-build --checkpoint_dir "${MODEL_DIR}/trt_models/${INFERENCE_PRECISION}/tp${TP_SIZE}/pp${PP_SIZE}/decoder" \
                --output_dir "${OUTPUT_DIR}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE}/decoder" \
                --paged_kv_cache enable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width ${MAX_BEAM_WIDTH} \
                --max_batch_size 2 \
                --max_output_len 2048 \
                --max_encoder_input_len 2048 \
                --gemm_plugin ${INFERENCE_PRECISION} \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding enable \
                --context_fmha disable \
                --max_input_len 1

Run C++ runtime with the built engine:

1st try

command

python tensorrt_llm/examples/run.py \
--engine_dir /tensorrtllm_backend/triton_model_repos/ul2sxl/tensorrt_llm/1/1-gpu/float16/tp1 \
--tokenizer_dir /tensorrtllm_backend/data/models/ul2_sxl_20240419/checkpoint_179000 \
--max_output_len 256 \
--input_text \
"Mẹ Bác Hồ là ai" \
"Bác Hồ là ai" \
"Diện tích của Việt Nam là bao nhiêu?"

output

[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024060400
[06/07/2024-16:08:34] [TRT-LLM] [W] This path is an encoder-decoder model. Using different handling.
Input [Text 0]: "<pad>"
Output [Text 0 Beam 0]: "Mẹ Bác Hồ là bà Nguyễn Thị Minh Khai, người phụ nữ Việt Nam đầu tiên được Bác Hồ chọn làm mẹ nuôi. Bà Nguyễn Thị Minh Khai là một trong những người phụ nữ Việt Nam đầu tiên được Bác Hồ chọn làm mẹ nuôi. Bà Nguyễn Thị Minh Khai là một trong những người phụ nữ Việt Nam đầu tiên được Bác Hồ chọn làm mẹ nuôi."
Input [Text 1]: "<pad>"
Output [Text 1 Beam 0]: "Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ Mẹ"
Input [Text 2]: "<pad>"
Output [Text 2 Beam 0]: "<extra_id_0> 10[NEWLINE]-<extra_id_1>[NEWLINE]-<extra_id_2>-[NEWLINE]-<extra_id_3>-[NEWLINE]-<extra_id_4>-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]"

2nd try: just change the order of input texts

command

python tensorrt_llm/examples/run.py \
--engine_dir /tensorrtllm_backend/triton_model_repos/ul2sxl/tensorrt_llm/1/1-gpu/float16/tp1 \
--tokenizer_dir /tensorrtllm_backend/data/models/ul2_sxl_20240419/checkpoint_179000 \
--max_output_len 256 \
--input_text \
"Bác Hồ là ai" \
"Mẹ Bác Hồ là ai" \
"Diện tích của Việt Nam là bao nhiêu?"

output

[TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024060400
[06/07/2024-16:23:48] [TRT-LLM] [W] This path is an encoder-decoder model. Using different handling.
Input [Text 0]: "<pad>"
Output [Text 0 Beam 0]: "Bác Hồ là một nhà cách mạng, nhà lãnh đạo cách mạng, và là người sáng lập và rèn luyện Đảng Cộng sản Việt Nam. Bác Hồ sinh ngày 19/5/1890 tại làng Kim Liên, xã Nam Đàn, huyện Nam Đàn, tỉnh Nghệ An. Bác Hồ là người sáng lập và rèn luyện Đảng Cộng sản Việt Nam, và là người lãnh đạo cách mạng Việt Nam trong suốt cuộc đời. Bác Hồ đã có nhiều đóng góp quan trọng cho sự phát triển của cách mạng Việt Nam, bao gồm việc thành lập Đảng Cộng sản Việt Nam, lãnh đạo cách mạng Việt Nam trong cuộc kháng chiến chống Pháp và chống Mỹ, và đưa ra các chính sách quan trọng để giải phóng dân tộc và xây dựng đất nước."
Input [Text 1]: "<pad>"
Output [Text 1 Beam 0]: "<extra_id_0> 10[NEWLINE]-<extra_id_1>[NEWLINE]-<extra_id_2>-[NEWLINE]-<extra_id_3>-[NEWLINE]-<extra_id_4>-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]"
Input [Text 2]: "<pad>"
Output [Text 2 Beam 0]: "<extra_id_0> 10[NEWLINE]-<extra_id_1>[NEWLINE]-<extra_id_2>-[NEWLINE]-<extra_id_3>-[NEWLINE]-<extra_id_4>-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]-[NEWLINE]"

May it be some bugs in the release of C++ runtime + inflight batching for Enc-Dec model?

symphonylyh commented 3 months ago

just reproduced on our end. Investigating now

symphonylyh commented 3 months ago

Hi @thanhlt998 , can you post GPU specs? Somehow we can randomly reproduce this on NVLink H100, but not on PCIe H100? Are you using a NVL machine?

thanhlt998 commented 3 months ago

Hi @symphonylyh , I am using one NVIDIA GeForce RTX 2080 Ti GPU for my experiment.

symphonylyh commented 3 months ago

@thanhlt998 fixed. It was due to missing cuda stream synchronization between encoder stream and decoder stream. The fix will be released in next week's weekly main branch update

thanhlt998 commented 3 months ago

@symphonylyh, thanks for your support!

thanhlt998 commented 3 months ago

@symphonylyh, I found the latest PR merged yesterday. Was the fix included in that PR?

owenonline commented 2 months ago

@thanhlt998 When I attempt to do this the model runner seems to look directly in the engine directory for the config files rather than in engine_dir/encoder and engine_dir/decoder. What does the config.json file you have located directly in your engine_dir look like?

0xd8b commented 2 months ago

@thanhlt998 fixed. It was due to missing cuda stream synchronization between encoder stream and decoder stream. The fix will be released in next week's weekly main branch update

For this issue, if I want to quickly modify the code, which part should I change? I look forward to your reply.