NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Cannot support long-context input vs. vLLM #2122

Closed: zhaocc1106 closed this issue 2 weeks ago

zhaocc1106 commented 4 weeks ago

Environments:

GPU: 4 * 2080Ti (4 * 12 GB = 48 GB total)
CUDA: 12.4
TensorRT-LLM: 0.10.0
xinference: 0.14.1
vLLM: 0.5.4
CPU: Intel(R) Xeon(R) Gold 6242R CPU @ 3.10GHz
Mem: 128 GB
LLM: Qwen2-7B

Question

Qwen2 supports a 32k context, and I need long context input. When I use trtllm-build to build a TP4 engine, I can only set max_batch_size=6, max_input_len=4096, and max_output_len=512; anything larger OOMs, so the maximum context length is only 4k. But with such a small max_batch_size, many requests are processed sequentially and throughput is much lower than with the vLLM server.
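
For reference, here is a rough back-of-the-envelope estimate (a sketch only, assuming the published Qwen2-7B config of 28 layers, 4 KV heads, head dim 128, about 7.6B parameters in fp16, and that tp_size=4 splits both the weights and the KV heads evenly). It suggests the paged KV cache itself is not what exhausts the 12 GB cards:

# Rough per-GPU memory estimate for Qwen2-7B with tp_size=4
# (config values assumed, not taken from the build log).
layers, n_kv_heads, head_dim, dtype_bytes, tp = 28, 4, 128, 2, 4
params = 7.6e9  # approximate parameter count of Qwen2-7B

weights_gib = params * dtype_bytes / tp / 2**30
kv_per_token = 2 * (n_kv_heads // tp) * head_dim * dtype_bytes * layers  # K and V, per GPU

for batch, seq in [(6, 4096 + 512), (1, 32256 + 512)]:
    kv_gib = batch * seq * kv_per_token / 2**30
    print(f"batch={batch}, seq={seq}: weights ~{weights_gib:.1f} GiB, KV cache ~{kv_gib:.2f} GiB per GPU")

# Both KV-cache figures are well under 1 GiB, so the 12 GB cards are not limited
# by the paged KV cache; the limit shows up at engine build time instead (see the
# attention scratch-space error below).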

It OOMs even when I build with --max_batch_size 1, using the following commands:

python convert_checkpoint.py --model_dir /data/docker_ceph/llm/Qwen2-7B-Instruct --output_dir /data/docker_ceph/llm/Qwen2-7B-Instruct/tllm_checkpoint_4gpu_tp4/ --dtype float16 --tp_size 4

trtllm-build --checkpoint_dir /data/docker_ceph/llm/Qwen2-7B-Instruct/tllm_checkpoint_4gpu_tp4/ --output_dir /data/docker_ceph/llm/Qwen2-7B-Instruct/trt_engines/fp16_4gpu/ --gemm_plugin float16 --context_fmha disable --use_custom_all_reduce disable --max_batch_size 1 --paged_kv_cache enable --max_input_len 32256 --max_output_len 512 --remove_input_padding enable

It still OOMs, with a log like:

Error Code 4: Internal Error (Internal error: plugin node QWenForCausalLM/transformer/layers/0/attention/PLUGIN_V2_GPTAttention_0 requires 45945584128 bytes of scratch space, but only 11539054592 is available. Try increasing the workspace size with IBuilderConfig::setMemoryPoolLimit()

Why does each GPU need such a large amount of memory (~42 GB) for a 7B LLM with a 32k context?
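
As a sanity check on that number, here is a sketch that assumes the non-FMHA attention path (forced by --context_fmha disable above) materializes a full seq x seq score matrix for each of the 7 query heads a rank owns under tp_size=4; the head count comes from the Qwen2-7B config, not from the log:

# Why the GPTAttention plugin requests tens of GB of scratch when context FMHA is off.
heads, tp, seq, fp16_bytes = 28, 4, 32256, 2

score_buffer = (heads // tp) * seq * seq * fp16_bytes  # one fp16 score matrix per rank
print(f"one score buffer: {score_buffer / 2**30:.1f} GiB")   # ~13.6 GiB
print(f"reported scratch: {45945584128 / 2**30:.1f} GiB")    # ~42.8 GiB

# The reported scratch is roughly a few such buffers (e.g. fp32 softmax
# intermediates would double the size), so it scales with seq**2 rather than
# with the 7B model size, which is why 32k blows past 12 GB per GPU.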

Building with a large max_output_len and a small max_input_len works fine, though:

trtllm-build --checkpoint_dir /data/docker_ceph/llm/Qwen2-7B-Instruct/tllm_checkpoint_4gpu_tp4/ --output_dir /data/docker_ceph/llm/Qwen2-7B-Instruct/trt_engines/fp16_4gpu/ --gemm_plugin float16 --context_fmha disable --use_custom_all_reduce disable --max_batch_size 16 --paged_kv_cache enable --max_input_len 512 --max_output_len 32256 --remove_input_padding enable

When I use vLLM, it is very easy to support a 32k context. Is there any solution to support long context here?

zhaocc1106 commented 3 weeks ago

Any update?

Kefeng-Duan commented 3 weeks ago

Hi @zhaocc1106, could you update to the latest TensorRT-LLM version?

zhaocc1106 commented 3 weeks ago


The same issue occurs with v0.11.0:

trtllm-build --checkpoint_dir /data/docker_ceph/llm/Qwen2-7B-Instruct/tllm_checkpoint_4gpu_tp4/ --output_dir /data/docker_ceph/llm/Qwen2-7B-Instruct/trt_engines/fp16_4gpu/ --gemm_plugin float16 --context_fmha disable --use_custom_all_reduce disable --max_batch_size 1 --paged_kv_cache enable --max_input_len 32256 --max_output_len 512 --remove_input_padding enable

Error log:

Error Code: 4: Internal error: plugin node QWenForCausalLM/transformer/layers/0/attention/wrapper/gpt_attention/PLUGIN_V2_GPTAttention_0 requires 45945584128 bytes of scratch space, but only 11539054592 is available. Try increasing the workspace size with IBuilderConfig::setMemoryPoolLimit().

Why does each GPU need such a large amount of memory (~42 GB) for only a 7B LLM with a 32k context?

Kefeng-Duan commented 3 weeks ago

@zhaocc1106 Could you double-check that you have successfully rebuilt and reinstalled v0.11.0? I think we removed the '--use_custom_all_reduce' knob from the build flow, so if your trtllm-build were really based on v0.11.0 you would get an error up front saying the knob does not exist.

zhaocc1106 commented 3 weeks ago


I use the Docker image nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3.

It still has --use_custom_all_reduce, as shown below:

[screenshot: trtllm-build --help output showing the --use_custom_all_reduce option]

Kefeng-Duan commented 2 weeks ago

@zhaocc1106 could you try to enable --context_fmha?

zhaocc1106 commented 2 weeks ago


When I use v0.11.0 on 4 x 2080Ti GPUs with max_batch_size 1 and max_input_len 32k, with the following command:

trtllm-build --checkpoint_dir /tmp/Qwen2-7B-Instruct/tllm_checkpoint_4gpu_tp4/ \
--output_dir /tmp/Qwen2-7B-Instruct/trt_engines/fp16_4gpu/ \
--gemm_plugin float16 --max_batch_size 1 --paged_kv_cache enable \
--max_input_len 32560 --max_output_len 512 \
--use_custom_all_reduce disable

I get the following error log:

[TensorRT-LLM][WARNING] Fall back to unfused MHA because of unsupported head size 128 in sm_75.
[TensorRT-LLM][WARNING] Fall back to unfused MHA because of unsupported head size 128 in sm_75.
(... the same warning repeated several more times ...)
[08/25/2024-14:12:55] [TRT-LLM] [I] Build TensorRT engine Unnamed Network 0
[08/25/2024-14:12:55] [TRT] [I] BuilderFlag::kTF32 is set but hardware does not support TF32. Disabling TF32.
[08/25/2024-14:12:55] [TRT] [W] Unused Input: position_ids
[08/25/2024-14:12:55] [TRT] [W] Detected layernorm nodes in FP16.
[08/25/2024-14:12:55] [TRT] [W] Running layernorm after self-attention in FP16 may cause overflow. Exporting the model to the latest available ONNX opset (later than opset 17) to use the INormalizationLayer, or forcing layernorm layers to run in FP32 precision can help with preserving accuracy.
[08/25/2024-14:12:55] [TRT] [W] [RemoveDeadLayers] Input Tensor position_ids is unused or used only at compile-time, but is not being removed.
[08/25/2024-14:12:55] [TRT] [I] BuilderFlag::kTF32 is set but hardware does not support TF32. Disabling TF32.
[08/25/2024-14:12:55] [TRT] [I] Global timing cache in use. Profiling results in this builder pass will be stored.
[08/25/2024-14:12:56] [TRT] [E] [defaultAllocator.cpp::allocate::19] Error Code 1: Cuda Runtime (out of memory)
[08/25/2024-14:12:56] [TRT] [W] Requested amount of GPU memory (1090001920 bytes) could not be allocated. There may not be enough free memory for allocation to succeed.
[08/25/2024-14:12:56] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [region-alloc.cpp:allocate:60] 1090001920-byte region '__mye243-consts' allocation failed.
[08/25/2024-14:12:56] [TRT] [E] Error Code: 10: Could not find any implementation for node {ForeignNode[QWenForCausalLM/transformer/vocab_embedding/value/constant/CONSTANT_0...QWenForCausalLM/transformer/layers/0/input_layernorm/rms_norm/__mul__/elementwise_binary/ELEMENTWISE_PROD_0]}.
[08/25/2024-14:12:56] [TRT] [E] IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[QWenForCausalLM/transformer/vocab_embedding/value/constant/CONSTANT_0...QWenForCausalLM/transformer/layers/0/input_layernorm/rms_norm/__mul__/elementwise_binary/ELEMENTWISE_PROD_0]}.)

But when I use v0.11.0 on a single 4090D GPU with max_batch_size 16 and max_input_len 32k, the build succeeds:

trtllm-build --checkpoint_dir /tmp/Qwen2-7B-Instruct/tllm_checkpoint_4gpu_tp4/ \
--output_dir /tmp/Qwen2-7B-Instruct/trt_engines/fp16_4gpu/ \
--gemm_plugin float16 --max_batch_size 16 --paged_kv_cache enable --max_input_len 32560 --max_output_len 512

Is that because the 2080Ti does not support context FMHA?
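
(A quick way to confirm the architecture gap, assuming PyTorch is available inside the container:)

import torch

# The build warning above says the fused context FMHA kernel does not support
# head size 128 on sm_75 (Turing, e.g. 2080Ti); the 4090D reports sm_89 (Ada).
major, minor = torch.cuda.get_device_capability(0)
print(f"sm_{major}{minor}")  # prints sm_75 on the 2080Ti box, sm_89 on the 4090D box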

Kefeng-Duan commented 2 weeks ago

@zhaocc1106 Right, we don't support the 2080Ti.

zhaocc1106 commented 2 weeks ago


Alright, I'll close the issue.