NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Degraded performance when in-flight batching is used #1823

Closed TheCodeWrangler closed 3 months ago

TheCodeWrangler commented 3 months ago

System Info

Debian 11

nvidia-smi

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   75C    P0              62W /  72W |  20585MiB / 23034MiB |     75%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA L4                      Off | 00000000:00:04.0 Off |                    0 |
| N/A   75C    P0              66W /  72W |  20585MiB / 23034MiB |     76%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

Who can help?

@kaiyux

Reproduction

I am seeking to use a set of LoRA weights (trained with linear 1.75 rope scaling and a rotary base of 875000) on a Llama3-8B base model. I plan to deploy to 2x L4 GPUs and would like to support a 14,000-token input length.

I compiled the rel branch of the Triton Inference Server tensorrt-llm backend (which also uses the rel branch of tensorrt-llm). I took this path to ensure that the container I serve with is identical to the one I use for compilation.

git clone https://github.com/triton-inference-server/tensorrtllm_backend.git
cd tensorrtllm_backend
git checkout b92bdd79b6c50fb67203b6064e73662163012fe3

git lfs install
git submodule update --init --recursive
# Build the container. It will be used both for compiling the engine and for serving.
DOCKER_BUILDKIT=1 docker build -t triton_trt_llm:b92bdd79 -f dockerfile/Dockerfile.trt_llm_backend .

I am updating the config.json file within the Llama3-8B base model with the rope scaling parameters used to train the LoRA adapters:

import json
import os

# BASE_MODEL_DIR points to the local Llama3-8B checkpoint directory.
# Update the rope_scaling key in config.json to {"type": "dynamic", "factor": 1.75}
config_path = os.path.join(BASE_MODEL_DIR, "config.json")
with open(config_path, "r") as f:
    config = json.load(f)
config["rope_scaling"] = {"type": "dynamic", "factor": 1.75}
config["rope_theta"] = 875000
with open(config_path, "w") as f:
    json.dump(config, f)

I am then using this container to compile the Llama3-8B base model for tensor parallelism 2 with the following convert/build commands.

python3 /app/tensorrt_llm/examples/llama/convert_checkpoint.py \
--model_dir ${BASE_MODEL_DIR} \
--output_dir /converted_base_model \
--rotary_base 875000 \
--dtype bfloat16 \
--tp_size 2

trtllm-build \
--max_input_len=14000 \
--max_num_tokens=14000 \
--max_seq_len=14000 \
--tp_size 2 \
--max_batch_size 4 \
--max_beam_width 3 \
--lora_plugin bfloat16 \
--gemm_plugin bfloat16 \
--lora_target_modules attn_q attn_k attn_v attn_dense mlp_h_to_4h mlp_gate mlp_4h_to_h \
--max_lora_rank 32 \
--gpt_attention_plugin bfloat16 \
--paged_kv_cache enable \
--multi_block_mode enable \
--remove_input_padding enable \
--checkpoint_dir /converted_base_model \
--use_custom_all_reduce enable \
--cluster_key L4 \
--workers=2 \
--use_paged_context_fmha enable \
--context_fmha enable \
--lookup_plugin bfloat16 \
--enable_xqa enable \
--output_dir ${ENGINE_DIR}

I have additional conversion steps that turn my LoRA base weights into warmup files, which I use to initialize my LoRA weights. I am leaving those details out here (though I might make a PR to provide them in the backend repo).

I then start my inference server, and the warmup runs successfully.

When I send sequential, single-request traffic, all adapters produce high-quality results. When I run several concurrent requests (and in-flight batching kicks in), the results degrade. The same input run as the only request in flight gives different results than when other inferences are in flight.
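
A minimal sketch of this kind of concurrency check (assuming the stock tensorrtllm_backend "ensemble" model and its default text_input / max_tokens / text_output tensor names; adjust to your model repository, and pin beam_width / random_seed as additional inputs in the same way):

# Sketch only: compares a single-request baseline against the same prompt
# sent with several requests in flight. Tensor and model names are the
# tensorrtllm_backend defaults and may differ in your deployment.
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import tritonclient.grpc as grpcclient

URL = "localhost:8001"
PROMPT = "example prompt that exercises the LoRA adapter"


def infer_once():
    client = grpcclient.InferenceServerClient(url=URL)
    text = grpcclient.InferInput("text_input", [1, 1], "BYTES")
    text.set_data_from_numpy(np.array([[PROMPT]], dtype=object))
    max_tokens = grpcclient.InferInput("max_tokens", [1, 1], "INT32")
    max_tokens.set_data_from_numpy(np.array([[128]], dtype=np.int32))
    result = client.infer(
        model_name="ensemble",
        inputs=[text, max_tokens],
        outputs=[grpcclient.InferRequestedOutput("text_output")],
    )
    return result.as_numpy("text_output")


# Baseline: only one request in flight.
baseline = infer_once()

# Same prompt with 8 requests in flight, which triggers in-flight batching.
with ThreadPoolExecutor(max_workers=8) as pool:
    concurrent_results = list(pool.map(lambda _: infer_once(), range(8)))

mismatches = sum(not np.array_equal(r, baseline) for r in concurrent_results)
print(f"{mismatches} of 8 concurrent outputs differ from the baseline")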

Expected behavior

Inference results should be deterministic (beam size 3, and I am passing a random seed as well) and should not change when in-flight batching is active.

Actual behavior

Results are deterministic only when a request is the only inference in flight.

Additional notes

I am willing to repost in the https://github.com/triton-inference-server/tensorrtllm_backend repo if the root cause is in that code.

TheCodeWrangler commented 3 months ago

I have tried a very similar process using the v0.9.0 tag (and saw the same results as above).

I have also tried the two latest commits on the main branch (though some changes to my convert/compile args were required). On main I would get backend exceptions, so I have settled on the rel branch for this issue as it seems closest to working.

TheCodeWrangler commented 3 months ago

I have recompiled with --max_batch_size 1 and it appears to resolve my issue, but it reduces my throughput significantly.

I actually have always been a bit unclear on the interaction of batch size with inflight-fused-batching. Any light you could shed on the interaction would be appreciated.

hijkzzz commented 3 months ago

Could you try disabling use_custom_all_reduce, and using trtllm 0.10 or pip install tensorrt_llm==0.11.0.dev2024061800?

TheCodeWrangler commented 3 months ago

Disabling use_custom_all_reduce fixed the issue!
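
For reference, a sketch of the rebuild: the trtllm-build invocation is the same as the one above, with only the all-reduce flag flipped (the elided flags are unchanged):

# Rebuild sketch: identical to the trtllm-build command above except that the
# custom all-reduce kernel is disabled.
trtllm-build \
--checkpoint_dir /converted_base_model \
--tp_size 2 \
--use_custom_all_reduce disable \
--output_dir ${ENGINE_DIR}
# ...plus the remaining flags from the original command (max_input_len,
# max_num_tokens, lora_plugin, gpt_attention_plugin, etc.) left unchanged.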

TheCodeWrangler commented 3 months ago

Have not tried with newer images