triton-inference-server / tensorrtllm_backend

The Triton TensorRT-LLM Backend
Apache License 2.0

Investigate mismatch output from Triton server and TensorRT-LLM #143

Open matichon-vultureprime opened 9 months ago

matichon-vultureprime commented 9 months ago

Hi folks,

Recently, I carried out a test that I'd like to share with all of you.

Hypothesis: Llama2 int4 (weight-only) quantization should work across all architectures (SM70, SM75, SM80, SM86, SM89, SM90).

Results:

- T4 (SM75), int4 TRT-LLM backend: incorrect output.
- T4 (SM75), fp16 TRT-LLM backend: correct output.
- V100 (SM70), int4 TRT-LLM backend: correct output.
- V100 (SM70), fp16 TRT-LLM backend: correct output.
- A10G (SM86), int4 TRT-LLM backend: correct output.
- A10G (SM86), fp16 TRT-LLM backend: correct output.
- A100 (SM80), int4 TRT-LLM backend: correct output.
- A100 (SM80), fp16 TRT-LLM backend: correct output.

Why this is important: Cloud Service Providers (CSPs) own a large number of T4 GPUs and sell them at a good price. In addition, in my tests with Llama2-7b int4, it reached 50-60 tokens per second at batch size 1. Considering performance and cost, the T4 has the best performance-to-cost ratio in its category.

My setup steps (T4). I noticed a strange result when I compared the Triton output with the TensorRT-LLM output. My original plan was to deploy Llama2-7b int4 on an NVIDIA T4 with Triton server. To do that, I built the model with the following script.

make -C docker release_build CUDA_ARCHS="70-real;75-real"

python build.py --model_dir model_weights/models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852 \
               --dtype float16 \
               --remove_input_padding \
               --use_gpt_attention_plugin float16 \
               --use_gemm_plugin float16 \
               --max_output_len 1024 \
               --use_weight_only \
               --weight_only_precision 'int4'\
               --output_dir ./tmp_64/llama/7B/trt_engines/weight_only/1-gpu/
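
For the fp16 comparison engines in the results above, the build presumably just drops the weight-only flags; a minimal sketch under that assumption (the fp16 output path below is hypothetical):

# sketch: same build as above without weight-only quantization; the output path is assumed
python build.py --model_dir model_weights/models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852 \
               --dtype float16 \
               --remove_input_padding \
               --use_gpt_attention_plugin float16 \
               --use_gemm_plugin float16 \
               --max_output_len 1024 \
               --output_dir ./tmp_64/llama/7B/trt_engines/fp16/1-gpu/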

Script to verify the output:

python3 run.py --engine_dir=/code/tensorrt_llm/examples/llama/tmp_64/llama/7B/trt_engines/weight_only/1-gpu/ --max_output_len 100 --tokenizer_dir model_weights/models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852 --input_text "Hello how are you?"

Running the float16 engine ...
Input: "Hello how are you?"
Output: "
 nobody is perfect, but I think I'm pretty close.
I'm a 20 year old guy from the Netherlands. I'm a very open minded person, I like to have fun and I like to make people happy. I'm a very outgoing person, I like to go out and have fun. I'm a very romantic person, I like to give gifts and I like to make people happy. I'm a very car"

Triton input and output.

curl -X POST localhost:8000/v2/models/ensemble/generate -d \
'{
"text_input": "Hello how are you?",
"parameters": {
"max_tokens": 100,
"bad_words":[""],
"stop_words":[""],
"stream": false,
"temperature": 1
}
}'

{"model_name":"ensemble",
"model_version":"1",
"sequence_end":false,
"sequence_id":0,
"sequence_start":false,
"text_output":
"<s> Hello how are you?reso Vid пу пуlitlitlitlitouteouteouterouterouterouterelslitlitouterouterouterouterouterouteribaibaibaibaibaibaasionasion ga gaalandalandaland列列列列列列ovokö roman roman roman roman roman roman roman roman roman roman Roman roman roman Heinrich Roman Roman Heinrich Roman Roman Roman Apple Apple Warner Warner Warner Warner decengoengo bat roman Alc AlcVD->{->{->{->{->{->{->{->{->{->{->{RuntimeRuntimeaml->{amlaml Alc angularjs angularjsanjeVD"}

Report with love ❤️ by Mati

juney-nvidia commented 9 months ago

Thanks for reporting this. My understanding is that there are two issues here:

For the first one, I will discuss Turing support with the product team. We haven't done thorough validation of TensorRT-LLM on Turing yet.

For the second one, do you see the inconsistency on other, non-T4 hardware?

June

matichon-vultureprime commented 9 months ago

Currently, I haven't found any other inconsistencies.

If I have new information, I will let you know.

Mati

yuanphoenix commented 8 months ago

I have the same problem, and I solved it with this: https://github.com/triton-inference-server/tensorrtllm_backend/issues/234#issuecomment-1863851491

matichon-vultureprime commented 8 months ago

@yuanphoenix That looks like a helpful solution, let me try it.

matichon-vultureprime commented 8 months ago

@yuanphoenix Sadly, it does not work for me.

juney-nvidia commented 7 months ago

@matichon-vultureprime Sorry for the late reply. Have you tried the latest main branch to see whether the issue still exists?

June

matichon-vultureprime commented 7 months ago

@juney-nvidia Let me try.

matichon-vultureprime commented 7 months ago

@juney-nvidia

CPU architecture: x86_64 (g4dn.8xlarge)
CPU/Host memory: 128 GB RAM, 16 GB swap
GPU properties:
  GPU name: T4
  GPU memory size: 16 GB
Libraries:
  TensorRT-LLM backend hash: f51f50ce77f1634e8bdea93f247f39f92313d110
  Container used (Option 3): DOCKER_BUILDKIT=1 docker build -t triton_trt_llm -f dockerfile/Dockerfile.trt_llm_backend .
NVIDIA driver version: 535.104.12
OS: Ubuntu 20.04 (AMI ami-0531914ef1f93b99e)
LLM model: Llama-2-7b-chat-hf

Reproduction Steps

  1. Build the TensorRT-LLM engine.
    python build.py --model_dir model_weights/models--meta-llama--Llama-2-7b-chat-hf/snapshots/c1b0db933684edbfe29a06fa47eb19cc48025e93 \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --use_weight_only \
                --max_batch_size 4 \
                --use_inflight_batching \
                --paged_kv_cache \
                --weight_only_precision 'int4'\
                --output_dir ./tmp_64/llama/7B/trt_engines/weight_only/1-gpu/
  2. Run test without Triton.
    python3 ../run.py --engine_dir=/code/tensorrt_llm/examples/llama/tmp_64/llama/7B/trt_engines/weight_only/1-gpu/ --max_output_len 100 --tokenizer_dir model_weights/models--meta-llama--Llama-2-7b-chat-hf/snapshots/c1b0db933684edbfe29a06fa47eb19cc48025e93 --input_text "[INST]Hello how are you?[/INST] "
  3. The response.
    /usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/generation.py:869: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.)
    torch.nested.nested_tensor(split_ids_list,
    Input [Text 0]: "<s> [INST]Hello how are you?[/INST] "
    Output [Text 0 Beam 0]: "Hello! I'm just an AI, I don't have feelings or emotions like humans do, so I can't say that I'm feeling any particular way. However, I'm here to help you with any questions or tasks you may have, so feel free to ask me anything!"

    During inference, I saw VRAM consumption of about 7.8 GB.

  4. python /opt/scripts/launch_triton_server.py --model_repo /all_models/inflight_batcher_llm --world_size 1
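
Once the server is up, the failure below is presumably triggered by a generate request such as the following sketch, which mirrors the earlier curl example (the prompt formatting is assumed to match step 2):

curl -X POST localhost:8000/v2/models/ensemble/generate \
  -d '{"text_input": "[INST]Hello how are you?[/INST] ", "parameters": {"max_tokens": 100, "bad_words": [""], "stop_words": [""], "stream": false, "temperature": 1}}'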

Expected Behavior

Triton should return a response like the one in step 3.

Actual Behavior

[TensorRT-LLM][ERROR] CUDA runtime error in cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda, strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType, mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP): CUBLAS_STATUS_NOT_SUPPORTED
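
A hedged follow-up idea, not a confirmed fix: the container in this run was built from dockerfile/Dockerfile.trt_llm_backend, while the first post built the release image with CUDA_ARCHS="70-real;75-real" passed explicitly. To rule out a backend build that omits SM75 (Turing), the release image could be rebuilt with Turing pinned, mirroring that earlier command:

# sketch: rebuild the release image with the Turing (SM75) architecture pinned, as in the first post
make -C docker release_build CUDA_ARCHS="75-real"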