triton-inference-server / tensorrtllm_backend

The Triton TensorRT-LLM Backend
Apache License 2.0

Exception when disabling "inflight_fused_batching" #511

Open TheCodeWrangler opened 1 week ago

TheCodeWrangler commented 1 week ago

System Info

Debian 11

nvidia-smi

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   75C    P0              62W /  72W |  20585MiB / 23034MiB |     75%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA L4                      Off | 00000000:00:04.0 Off |                    0 |
| N/A   75C    P0              66W /  72W |  20585MiB / 23034MiB |     76%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

Who can help?

@kaiyux

Reproduction

I am working on the issue mentioned in:

https://github.com/NVIDIA/TensorRT-LLM/issues/1823

To work around it, I am attempting to run without "inflight_fused_batching" by setting gpt_model_type to "v1" in the "tensorrtllm" config.pbtxt:

parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "v1"
  }
}
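
For anyone reproducing this, here is a minimal illustrative sketch for flipping that parameter without hand-editing the file. The config path and the use of the tritonclient Python package's ModelConfig protobuf are assumptions on my part, not part of the original setup; adjust the path to your model repository (note that round-tripping through text_format drops comments):

# Illustrative helper: rewrite gpt_model_type in config.pbtxt.
# Assumptions: `pip install tritonclient[grpc]` and the path below (hypothetical).
from google.protobuf import text_format
from tritonclient.grpc import model_config_pb2

CONFIG_PATH = "triton_model_repo/tensorrt_llm/config.pbtxt"  # hypothetical path

def set_gpt_model_type(value: str) -> None:
    # Parse the existing config.pbtxt into a ModelConfig message.
    config = model_config_pb2.ModelConfig()
    with open(CONFIG_PATH) as f:
        text_format.Parse(f.read(), config)
    # "parameters" is a map<string, ModelParameter>; overwrite its string_value.
    config.parameters["gpt_model_type"].string_value = value
    with open(CONFIG_PATH, "w") as f:
        f.write(text_format.MessageToString(config))

if __name__ == "__main__":
    set_gpt_model_type("v1")  # switch back with "inflight_fused_batching"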

My setup starts and runs when using inflight_fused_batching, but initialization with v1 results in:

[TensorRT-LLM][WARNING] TrtGptModelType::V1 is deprecated and will be removed in a future release. Please use TrtGptModelType::InflightBatching or TrtGptModelType::InflightFusedBatching instead.
[TensorRT-LLM][INFO] TRTGptModel maxNumSequences: 4
[TensorRT-LLM][INFO] TRTGptModel maxBatchSize: 4
[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: 15024
[TensorRT-LLM][INFO] TRTGptModel enableTrtOverlap: 0
[TensorRT-LLM][INFO] TRTGptModel normalizeLogProbs: 1
[TensorRT-LLM][INFO] Loaded engine size: 8172 MiB
[TensorRT-LLM][INFO] Detecting local TP group for rank 0
[TensorRT-LLM][INFO] Detecting local TP group for rank 1
[TensorRT-LLM][INFO] TP group is intra-node for rank 1
[TensorRT-LLM][INFO] TP group is intra-node for rank 0
[TensorRT-LLM][INFO] Allocated 2687.77 MiB for execution context memory.
[TensorRT-LLM][WARNING] GptSession is deprecated and will be removed in a future release. Please use the executor API instead (cpp/include/tensorrt_llm/executor).
[TensorRT-LLM][INFO] Allocated 2687.77 MiB for execution context memory.
[TensorRT-LLM][WARNING] GptSession is deprecated and will be removed in a future release. Please use the executor API instead (cpp/include/tensorrt_llm/executor).
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 8163 (MiB)
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 8163 (MiB)
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 235
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 235
[TensorRT-LLM][INFO] Max tokens in paged KV cache: 130624. Allocating 8560574464 bytes.
[TensorRT-LLM][INFO] Max tokens in paged KV cache: 130624. Allocating 8560574464 bytes.
[TensorRT-LLM][WARNING] cancellation_check_period_ms is not specified, will be set to 100 (ms)
[TensorRT-LLM][WARNING] stats_check_period_ms is not specified, will be set to 100 (ms)
[TensorRT-LLM][WARNING] cancellation_check_period_ms is not specified, will be set to 100 (ms)
[TensorRT-LLM][WARNING] stats_check_period_ms is not specified, will be set to 100 (ms)
[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: Input tensor 'attn_q_lora_weights_pointers_0' not found; expected shape: (-1, 2) (/app/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmRuntime.cpp:151)
1       0x7f3b64075d4b tensorrt_llm::runtime::TllmRuntime::setInputTensors(int, std::unordered_map<std::string, std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::shared_ptr<tensorrt_llm::runtime::ITensor> > > > const&) + 651
2       0x7f3b64023ab8 tensorrt_llm::runtime::GptSession::executeContextStep(std::vector<tensorrt_llm::runtime::GenerationInput, std::allocator<tensorrt_llm::runtime::GenerationInput> > const&, std::vector<int, std::allocator<int> > const&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager const*) + 984
3       0x7f3b64024b25 tensorrt_llm::runtime::GptSession::generateBatched(std::vector<tensorrt_llm::runtime::GenerationOutput, std::allocator<tensorrt_llm::runtime::GenerationOutput> >&, std::vector<tensorrt_llm::runtime::GenerationInput, std::allocator<tensorrt_llm::runtime::GenerationInput> > const&, tensorrt_llm::runtime::SamplingConfig const&, std::function<void (int, bool)> const&, std::shared_ptr<tensorrt_llm::runtime::GptSession::GenerationProfiler>) + 2261
4       0x7f3b64026885 tensorrt_llm::runtime::GptSession::generate(tensorrt_llm::runtime::GenerationOutput&, tensorrt_llm::runtime::GenerationInput const&, tensorrt_llm::runtime::SamplingConfig const&, std::shared_ptr<tensorrt_llm::runtime::GptSession::GenerationProfiler>) + 3381
5       0x7f3b642c67cc tensorrt_llm::batch_manager::TrtGptModelV1::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 3036
6       0x7f3b642dcfe4 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 100
7       0x7f3b642dfd0c tensorrt_llm::executor::Executor::Impl::executionLoop() + 380
8       0x7f3c66de1253 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xdc253) [0x7f3c66de1253]
9       0x7f3c66b70ac3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7f3c66b70ac3]
10      0x7f3c66c02850 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x126850) [0x7f3c66c02850]
[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: Input tensor 'attn_q_lora_weights_pointers_0' not found; expected shape: (-1, 2) (/app/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmRuntime.cpp:151)
1       0x7f2228075d4b tensorrt_llm::runtime::TllmRuntime::setInputTensors(int, std::unordered_map<std::string, std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::shared_ptr<tensorrt_llm::runtime::ITensor> > > > const&) + 651
2       0x7f2228023ab8 tensorrt_llm::runtime::GptSession::executeContextStep(std::vector<tensorrt_llm::runtime::GenerationInput, std::allocator<tensorrt_llm::runtime::GenerationInput> > const&, std::vector<int, std::allocator<int> > const&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager const*) + 984
3       0x7f2228024b25 tensorrt_llm::runtime::GptSession::generateBatched(std::vector<tensorrt_llm::runtime::GenerationOutput, std::allocator<tensorrt_llm::runtime::GenerationOutput> >&, std::vector<tensorrt_llm::runtime::GenerationInput, std::allocator<tensorrt_llm::runtime::GenerationInput> > const&, tensorrt_llm::runtime::SamplingConfig const&, std::function<void (int, bool)> const&, std::shared_ptr<tensorrt_llm::runtime::GptSession::GenerationProfiler>) + 2261
4       0x7f2228026885 tensorrt_llm::runtime::GptSession::generate(tensorrt_llm::runtime::GenerationOutput&, tensorrt_llm::runtime::GenerationInput const&, tensorrt_llm::runtime::SamplingConfig const&, std::shared_ptr<tensorrt_llm::runtime::GptSession::GenerationProfiler>) + 3381
5       0x7f22282c67cc tensorrt_llm::batch_manager::TrtGptModelV1::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 3036
6       0x7f22282dcfe4 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 100
7       0x7f22282dfd0c tensorrt_llm::executor::Executor::Impl::executionLoop() + 380
8       0x7f2327de1253 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xdc253) [0x7f2327de1253]
9       0x7f2327b70ac3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7f2327b70ac3]
10      0x7f2327c02850 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x126850) [0x7f2327c02850]
E0621 13:20:02.074615 129 backend_model.cc:691] ERROR: Failed to create instance: failed to run warmup sample 'lora_pbap_warmup': Executor failed process requestId 1 due to the following error: Encountered an error in forwardAsync function: Input tensor 'attn_q_lora_weights_pointers_0' not found; expected shape: (-1, 2) (/app/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmRuntime.cpp:151)
1       0x7f3b64075d4b tensorrt_llm::runtime::TllmRuntime::setInputTensors(int, std::unordered_map<std::string, std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::shared_ptr<tensorrt_llm::runtime::ITensor> > > > const&) + 651
2       0x7f3b64023ab8 tensorrt_llm::runtime::GptSession::executeContextStep(std::vector<tensorrt_llm::runtime::GenerationInput, std::allocator<tensorrt_llm::runtime::GenerationInput> > const&, std::vector<int, std::allocator<int> > const&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager const*) + 984
3       0x7f3b64024b25 tensorrt_llm::runtime::GptSession::generateBatched(std::vector<tensorrt_llm::runtime::GenerationOutput, std::allocator<tensorrt_llm::runtime::GenerationOutput> >&, std::vector<tensorrt_llm::runtime::GenerationInput, std::allocator<tensorrt_llm::runtime::GenerationInput> > const&, tensorrt_llm::runtime::SamplingConfig const&, std::function<void (int, bool)> const&, std::shared_ptr<tensorrt_llm::runtime::GptSession::GenerationProfiler>) + 2261
4       0x7f3b64026885 tensorrt_llm::runtime::GptSession::generate(tensorrt_llm::runtime::GenerationOutput&, tensorrt_llm::runtime::GenerationInput const&, tensorrt_llm::runtime::SamplingConfig const&, std::shared_ptr<tensorrt_llm::runtime::GptSession::GenerationProfiler>) + 3381
5       0x7f3b642c67cc tensorrt_llm::batch_manager::TrtGptModelV1::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 3036
6       0x7f3b642dcfe4 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 100
7       0x7f3b642dfd0c tensorrt_llm::executor::Executor::Impl::executionLoop() + 380
8       0x7f3c66de1253 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xdc253) [0x7f3c66de1253]
9       0x7f3c66b70ac3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7f3c66b70ac3]
10      0x7f3c66c02850 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x126850) [0x7f3c66c02850]; 

Expected behavior

In-flight batching can be disabled.

actual behavior

An exception is thrown when attempting to disable it.

additional notes

I can see from the logs that the v1 argument is intended to be deprecated. I would be much happier with a fix to the in-flight batching issue itself, which would also resolve my problem.