NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

I'm benchmarking llama-7b with batch size 8 on an A40, but OOM happened. I'm curious why a 7B model needs so much memory? #67

Closed qihang720 closed 10 months ago

qihang720 commented 10 months ago

python benchmark.py -m llama_7b --batch_size "8" --mode plugin --input_output_len '2048,2048' --csv --max_input_len 2048 --max_output_len 2048

log error: [TRT-LLM] [E] Exception CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacty of 44.35 GiB of which 147.88 MiB is free. Process 1107750 has 44.16 GiB memory in use. Of the allocated memory 12.00 GiB is allocated by PyTorch, and 983.50 KiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF caught while allocating memory; skipping (8, 2048, 2048)

The A40 has 46GB of memory, which I think should be enough to support batch size 8 for llama-7b.

Also, can you tell me how to estimate memory usage for different batch sizes of different models? For example, for a 7B model, how much memory do the weights take, and if the input/output length is 2k/2k, how much memory does each batch size add?

77h2l commented 10 months ago

I meet the same issue on a single A10 when the input/output length exceeds the recommended default value; it seems TRT-LLM is not well optimized for the long-sequence scenario. BTW, do you know how to calculate the precise QPS value with this benchmark script? Thanks.

xuwei320 commented 10 months ago

Try the paged_kv_cache option.
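
For context, a paged KV cache allocates the cache in fixed-size blocks as sequences grow, instead of preallocating max_batch_size x max_seq_len up front. Below is a rough Python sketch of the potential saving; the block size and per-request lengths are made-up illustrative values, not TensorRT-LLM defaults.

# Illustrative comparison of contiguous vs. paged KV-cache sizing for Llama-7B.
# tokens_per_block and the "actual" sequence lengths are hypothetical examples.
GIB = 1024 ** 3
n_layers, n_heads, head_dim, fp16 = 32, 32, 128, 2
bytes_per_token = 2 * n_layers * n_heads * head_dim * fp16   # K and V for one token

batch_size, max_seqlen = 8, 4096
contiguous = batch_size * max_seqlen * bytes_per_token        # preallocated for the worst case

tokens_per_block = 64                                          # hypothetical block granularity
actual_lens = [600, 900, 1200, 1500, 2000, 2500, 3000, 3500]   # hypothetical per-request lengths
blocks = sum((l + tokens_per_block - 1) // tokens_per_block for l in actual_lens)  # ceil-divide
paged = blocks * tokens_per_block * bytes_per_token            # grows with what is actually used

print(f"contiguous: {contiguous / GIB:.1f} GiB")  # 16.0 GiB
print(f"paged     : {paged / GIB:.1f} GiB")       # ~7.6 GiB for these example lengths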

qihang720 commented 10 months ago

I meet the same issue on a single A10 when the input/output length exceeds the recommended default value; it seems TRT-LLM is not well optimized for the long-sequence scenario. BTW, do you know how to calculate the precise QPS value with this benchmark script? Thanks.

I thought tokens_per_sec might be the qps value.

jdemouth-nvidia commented 10 months ago

In terms of memory consumption, you can expect Llama 7B to take 14GB for the weights (7B parameters at 2 bytes each). The KV cache, when it's not set to paged KV cache, should take

8 (batch size) x 2 (K and V) x 4,096 (max seqlen) x 32 (layers) x 32 (heads) x 128 (hidden size) x 2B (per element) = 16GB.

In total, that's nearly 32GB for TensorRT-LLM. If PyTorch reserves 12GB, as the error message seems to imply, we are getting very close to the 44GB mentioned in the error message. No?
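
For reference, a minimal Python sketch of this back-of-the-envelope estimate (the shapes are the standard Llama-7B configuration; this only reproduces the arithmetic above, it is not a TensorRT-LLM API):

# Rough memory estimate for Llama-7B at batch size 8, max seqlen 4096 (2048 in + 2048 out).
# Illustrative arithmetic only; real allocations also include activations and runtime overhead.
GIB = 1024 ** 3

n_params      = 7e9        # ~7B parameters
bytes_per_el  = 2          # fp16
weights_bytes = n_params * bytes_per_el                      # ~14 GB

batch_size = 8
max_seqlen = 2048 + 2048   # input + output tokens
n_layers   = 32
n_heads    = 32
head_dim   = 128           # hidden_size 4096 / 32 heads

# Non-paged KV cache: K and V per layer, preallocated for the full sequence.
kv_cache_bytes = 2 * batch_size * max_seqlen * n_layers * n_heads * head_dim * bytes_per_el

print(f"weights : {weights_bytes / GIB:.1f} GiB")   # ~13 GiB (the ~14 GB above)
print(f"kv cache: {kv_cache_bytes / GIB:.1f} GiB")  # 16.0 GiB
print(f"total   : {(weights_bytes + kv_cache_bytes) / GIB:.1f} GiB")  # ~29 GiB, i.e. ~32GB with overhead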

qihang720 commented 10 months ago

Yes, it's nearly 44GB in total, and the 32GB for TensorRT-LLM matches my understanding. So why does PyTorch reserve 12GB?

juney-nvidia commented 10 months ago

allocated by PyTorch

There might be some under-the-hood behavior of PyTorch that allocates the additional memory. BTW, if you want to do a perf benchmark, it is suggested to try the C++ benchmark workflow. Can you try it to measure the performance and memory consumption?

June

wm2012011492 commented 10 months ago

Hi @qihang720 Could you specify --max_batch_size as well? I can benchmark llama_7b with bs=8 and input/output_len=2048, which takes 34.78GB.

python benchmark.py -m llama_7b --batch_size "8" --mode plugin --input_output_len '2048,2048' --csv --max_input_len 2048 --max_output_len 2048 --max_batch_size 32

model_name,world_size,num_heads,num_kv_heads,num_layers,hidden_size,vocab_size,precision,batch_size,input_length,output_length,gpu_peak_mem(gb),build_time(s),tokens_per_sec,percentile95(ms),percentile99(ms),latency(ms),compute_cap
llama_7b,1,32,32,32,4096,32000,float16,8,2048,2048,34.78,29.17,161.64,101385.156,101385.156,101359.884,sm86

jdemouth-nvidia commented 10 months ago

I see that @litaotju is assigned to this issue. His analysis has concluded that it's not a bug. Quoting him:

Tracing down to TRT memory allocation: TRT needs at least 4 big blocks to hold the internal tensors for the decoder block.

hidden_states before qkv_gemm = MAX_BS x MAX_INP_LEN x HIDDEN = 128 x 2048 x 4096 x 2 bytes (fp16) = 2147483648

qkv gemm output = MAX_BS x MAX_INP_LEN x HIDDEN x 3 (QKV) = 128 x 2048 x 4096 x 3 x 2 bytes (fp16) = 6442450944

mlp 1st FC out = MAX_BS x MAX_INP_LEN x inter_size = 128 x 2048 x 11008 x 2 bytes (fp16) = 5771362304

mlp activation out (assume not fused in plugin mode) = 1st FC out = 5771362304

In this case, the memory consumption is correct, and there are no bugs. The "activation" memory handled by TRT is always computed from the max shapes, e.g. the max batch size of 128, even though the user requests bs = 8 at runtime.

If the user builds the engine with max bs == 8, then the activation memory allocated by TRT is reduced by 128/8 = 16x, to about 1GB.
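
A small Python sketch that reproduces this activation-buffer arithmetic and the 16x reduction (it only mirrors the analysis above; it is not how TRT sizes its workspace internally):

# Activation buffers for the Llama-7B decoder block, sized by the engine's max shapes.
GIB = 1024 ** 3
max_inp_len = 2048
hidden      = 4096
inter_size  = 11008
fp16        = 2      # bytes per element

def activation_bytes(max_bs):
    hidden_states = max_bs * max_inp_len * hidden * fp16          # before qkv_gemm
    qkv_out       = max_bs * max_inp_len * hidden * 3 * fp16      # QKV GEMM output
    fc1_out       = max_bs * max_inp_len * inter_size * fp16      # MLP 1st FC output
    act_out       = fc1_out                                        # activation output (not fused)
    return hidden_states + qkv_out + fc1_out + act_out

print(f"max_bs=128: {activation_bytes(128) / GIB:.2f} GiB")  # ~18.75 GiB
print(f"max_bs=8  : {activation_bytes(8) / GIB:.2f} GiB")    # ~1.17 GiB, i.e. 16x smaller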