NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

gpu memory leak when max_tokens = 1 and gather_all_token_logits #2335

Open · anaivebird opened 1 month ago

anaivebird commented 1 month ago

System Info

Who can help?

@byshiue @juney-nvidia @ncomly-nvidia

Information

Tasks

Reproduction

git clone https://github.com/NVIDIA/TensorRT-LLM
cd TensorRT-LLM/examples/qwen
export CUDA_VISIBLE_DEVICES=0

convert_script=./convert_checkpoint.py
quantize_script=../quantization/quantize.py
model_dir=/tmp/qwen-7b
output_dir=/tmp/engine
tp=1

# max_batch_size is not set here, so the default of 256 is used

# fp16 baseline
python3 $convert_script \
            --model_dir $model_dir \
            --output_dir $output_dir/qwen-checkpoint-${tp}gpu/ \
            --dtype float16 \
            --tp_size $tp

# 1. Select the best max_num_tokens; you can try adjusting it,
# e.g. increasing max_num_tokens from 8192.

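# Note: --gather_all_token_logits makes the engine gather logits for all context
# and generated tokens at runtime, which is the setting this issue is about.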
trtllm-build --checkpoint_dir $output_dir/qwen-checkpoint-${tp}gpu/ \
            --output_dir $output_dir/qwen-trt-engine-fusion-${tp}gpu/ \
            --gemm_plugin float16 \
            --gather_all_token_logits

cd ../apps
python3 ./openai_server.py /tmp/engine/qwen-trt-engine-fusion-1gpu --tokenizer /tmp/qwen-7b

Client script to send the request repeatedly:

import requests
import json

url = "http://localhost:8000/v1/chat/completions"
headers = {
    "Content-Type": "application/json"
}
payload = {
    "model": 'qwen',
    "top_k": 5,
    "top_p": 0.85,
    "temperature": 0.4,
    "messages": [
        {
            "role": "user",
            "content": "<s>Question: Students are investigating the effects of different fertilizers on plant growth. Which units would be best to measure the mass of the fertilizer used?\nAnswer: milligrams<s>Question: Students are investigating the effects of different fertilizers on plant growth. Which units would be best to measure the mass of the fertilizer used?\nAnswer: milligrams<s>Question: Students are investigating the effects of different fertilizers on plant growth. Which units would be best to measure the mass of the fertilizer used?\nAnswer: milligrams<s>Question: Students are investigating the effects of different fertilizers on plant growth. Which units would be best to measure the mass of the fertilizer used?\nAnswer: milligrams<s>Question: Students are investigating the effects of different fertilizers on plant growth. Which units would be best to measure the mass of the fertilizer used?\nAnswer: milligrams<s>Question: Students are investigating the effects of different fertilizers on plant growth. Which units would be best to measure the mass of the fertilizer used?\nAnswer: milligrams"
        }
    ],
    "max_tokens": 1,
    "repetition_penalty": 1.05,
    "stream": False,
}

# Function to send a single request
def send_request():
    response = requests.post(url, headers=headers, data=json.dumps(payload))
    return response.json()

# Send the request 1000 times
for i in range(1000):
    result = send_request()
    print(result)
    print(i)

print("Completed 1000 requests.")

Expected behavior

The 1000 requests should finish normally.

actual behavior


[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaMallocAsync(ptr, n, mMemPool->getPool(), mCudaStream->get()): out of memory (/home/work/xingwuFileSystem/24.10.15/TensorRT-LLM/cpp/tensorrt_llm/runtime/tllmBuffers.h:125)
1       0x7f1ddb698e13 void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 147
2       0x7f1ddd3d9e43 tensorrt_llm::runtime::BufferManager::gpu(nvinfer1::Dims64, nvinfer1::DataType) const + 515
3       0x7f1ddd8bcab3 tensorrt_llm::batch_manager::RuntimeBuffers::reshape(tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 1923
4       0x7f1ddd8c1618 tensorrt_llm::batch_manager::RuntimeBuffers::prepareStep[abi:cxx11](std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int, int, tensorrt_llm::batch_manager::DecoderBuffers&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::rnn_state_manager::RnnStateManager*, std::map<unsigned long, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > >, std::less<unsigned long>, std::allocator<std::pair<unsigned long const, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > > > > > const&, tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 104
5       0x7f1ddd8e3bc7 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::prepareBuffers[abi:cxx11](std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 167
6       0x7f1ddd8e3d0e tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeStep(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 62
7       0x7f1ddd8e3ece tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 222
8       0x7f1ddd8e4490 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 1376
9       0x7f1ddd9162f1 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 353
10      0x7f1ddd91b260 tensorrt_llm::executor::Executor::Impl::executionLoop() + 896
11      0x7f1f8d8b0253 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xdc253) [0x7f1f8d8b0253]
12      0x7f1fc9969ac3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7f1fc9969ac3]
13      0x7f1fc99fb850 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x126850) [0x7f1fc99fb850]
[10/15/2024-07:46:52] [TRT-LLM] [E] Error in thread await_response_thread: Cannot get the result for a response with an error (/home/jenkins/agent/workspace/LLM/main/L0_PostMerge/llm/cpp/tensorrt_llm/executor/responseImpl.h:69)
1       0x7f1ddb5a8532 /home/work/.local/lib/python3.10/site-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x75d532) [0x7f1ddb5a8532]
2       0x7f1e0e527dff /home/work/.local/lib/python3.10/site-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0xb8dff) [0x7f1e0e527dff]
3       0x7f1e0e4d04e5 /home/work/.local/lib/python3.10/site-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x614e5) [0x7f1e0e4d04e5]
4       0x559f0384d10e python3(+0x15a10e) [0x559f0384d10e]
5       0x559f03843a7b _PyObject_MakeTpCall + 603
6       0x559f0385915d python3(+0x16615d) [0x559f0385915d]
7       0x559f0384b8a8 _PyObject_GenericGetAttrWithDict + 1128
8       0x559f03849e3d PyObject_GetAttr + 77
9       0x559f0383b971 _PyEval_EvalFrameDefault + 24001
10      0x559f0385ba51 python3(+0x168a51) [0x559f0385ba51]
11      0x559f038385d7 _PyEval_EvalFrameDefault + 10791
12      0x559f0384d9fc _PyFunction_Vectorcall + 124
13      0x559f0383645c _PyEval_EvalFrameDefault + 2220
14      0x559f0384d9fc _PyFunction_Vectorcall + 124
15      0x559f0383645c _PyEval_EvalFrameDefault + 2220
16      0x559f0385ba51 python3(+0x168a51) [0x559f0385ba51]
17      0x559f03984f3a python3(+0x291f3a) [0x559f03984f3a]
18      0x559f03979ef8 python3(+0x286ef8) [0x559f03979ef8]
19      0x7f1fc9969ac3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7f1fc9969ac3]
20      0x7f1fc99fb850 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x126850) [0x7f1fc99fb850]

additional notes

no

anaivebird commented 1 month ago

gpu memory leak when max_tokens = 1

Superjomn commented 1 month ago

Can you try it without gather_all_token_logits? We need to investigate the case with gather_all_token_logits.
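
For comparison, the rebuild without the flag is the same trtllm-build command from the reproduction minus --gather_all_token_logits, e.g. (the output path here is only illustrative):

trtllm-build --checkpoint_dir $output_dir/qwen-checkpoint-${tp}gpu/ \
            --output_dir $output_dir/qwen-trt-engine-nologits-${tp}gpu/ \
            --gemm_plugin float16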

anaivebird commented 1 month ago

Thanks. To the best of my memory, it works well without gather_all_token_logits.

Superjomn commented 1 month ago

Got it. It may be an issue with gather_all_token_logits; we will reproduce and investigate it. Thanks for sharing the issue.

syuoni commented 4 weeks ago

Hi @anaivebird , thanks for reporting this issue, and I can reproduce it from my side.

The logits tensor takes a lot of memory. In your case, say the context length is 300 and the vocab size is 151851; then each context logits tensor takes 300 * 151851 * 4 = 182,221,200 bytes ≈ 0.17 GB. Note that the logits are float32.

The model is 7B, so its weights take 7 * 2 = 14 GB of memory in fp16. Say the activation memory and runtime buffers take an additional 1 GB. The free_gpu_memory_fraction is 0.8 by default in openai_server.py, so on an 80 GB GPU the KV cache pool takes (80 - 14 - 1) * 0.8 = 52 GB. That leaves 80 - 14 - 1 - 52 = 13 GB, which can hold 13 / 0.17 ≈ 76.5 requests' logits tensors.
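
A quick sketch of that arithmetic (assuming an 80 GB GPU, fp16 weights, and the rough 1 GB estimate for activations/runtime buffers used above):

GIB = 1 << 30

# Per-request context logits: context_len * vocab_size * 4 bytes (float32)
context_len = 300
vocab_size = 151851
logits_gib = context_len * vocab_size * 4 / GIB
print(f"logits per request: {logits_gib:.2f} GiB")        # ~0.17

# GPU memory budget (GB, rough figures from the comment above)
total_mem = 80                    # 80 GB GPU
weights = 7 * 2                   # 7B params * 2 bytes (fp16)
activations = 1                   # activations + runtime buffers, rough estimate
free_gpu_memory_fraction = 0.8    # default in openai_server.py
kv_cache = (total_mem - weights - activations) * free_gpu_memory_fraction
remaining = total_mem - weights - activations - kv_cache
print(f"kv cache pool: {kv_cache:.0f} GB, remaining: {remaining:.0f} GB")   # 52 GB, 13 GB
print(f"in-flight logits tensors that fit: {remaining / logits_gib:.1f}")   # ~76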

In my experiments, it works if I set --max_batch_size 64 when calling trtllm-build. Could you please try with this? Alternatively, you may use a smaller value for free_gpu_memory_fraction in openai_server.py, which allows a larger max_batch_size.
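
Concretely, that means rebuilding the engine from the reproduction steps with the extra flag, roughly:

trtllm-build --checkpoint_dir $output_dir/qwen-checkpoint-${tp}gpu/ \
            --output_dir $output_dir/qwen-trt-engine-fusion-${tp}gpu/ \
            --gemm_plugin float16 \
            --gather_all_token_logits \
            --max_batch_size 64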

cc @yweng0828 for viz.

Thanks!

anaivebird commented 4 weeks ago

It seems necessary to reserve the corresponding memory based on vocabulary size * prompt token length * max_batch_size, right?

syuoni commented 4 weeks ago

> It seems necessary to reserve the corresponding memory based on vocabulary size * prompt token length * max_batch_size, right?

Yes. But currently we don't reserve the maximum size of the logits the way we do for activation memory and other runtime buffers, because the logits would typically take too much memory.
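
As a rough illustration of the scale involved (256 is the default max_batch_size from the build script above, and 8192 is just the max_num_tokens value mentioned there, used as an illustrative prompt-length bound):

# Hypothetical worst-case pre-allocation for context logits:
# max_batch_size * prompt_len * vocab_size * 4 bytes (float32)
max_batch_size = 256      # default used in the reproduction above
prompt_len = 8192         # illustrative bound, not a measured value
vocab_size = 151851
reserved_gib = max_batch_size * prompt_len * vocab_size * 4 / (1 << 30)
print(f"{reserved_gib:.0f} GiB")   # ~1186 GiB, far more than any single GPU has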

anaivebird commented 4 weeks ago

> In my experiments, it works if I set --max_batch_size 64 when calling trtllm-build. Could you please try with this? Alternatively, you may use a smaller value for free_gpu_memory_fraction in openai_server.py, which allows a larger max_batch_size.

As shown in https://github.com/NVIDIA/TensorRT-LLM/issues/2350, changing free_gpu_memory_fraction does not increase free GPU memory; even changing it to 0.1 makes no difference.

syuoni commented 4 weeks ago

> As shown in #2350, changing free_gpu_memory_fraction does not increase free GPU memory; even changing it to 0.1 makes no difference.

Yes, you are right. I tested without the openai_server.py script, so I did not find the issue. Thanks again for reporting this!