NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Llama7b Int4 on Nvidia T4. Output from Triton is incorrect. #396

Open matichon-vultureprime opened 1 year ago

matichon-vultureprime commented 1 year ago

Hello folks,

I am looking to build a Llama-7B int4 weight-only engine and serve it via Triton. I built the engine and verified that its int4 output is correct.

However, when I rebuilt it with use_inflight_batching and paged_kv_cache and served it via Triton, I got a different output from what I previously had.

NVIDIA-SMI 545.23.06, Driver Version: 545.23.06, CUDA Version: 12.3
Branch release/0.5.0

I have provided my build steps and output below.

1. Build the release container:

make -C docker release_build CUDA_ARCHS="70-real;75-real"

2. Build the int4 weight-only engine:

python build.py --model_dir model_weights/models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852 \
               --dtype float16 \
               --remove_input_padding \
               --use_gpt_attention_plugin float16 \
               --use_gemm_plugin float16 \
               --max_output_len 1024 \
               --use_weight_only \
               --weight_only_precision 'int4' \
               --output_dir ./tmp_64/llama/7B/trt_engines/weight_only/1-gpu/

3. Run inference with the built engine:

python3 run.py --engine_dir=/code/tensorrt_llm/examples/llama/tmp_64/llama/7B/trt_engines/weight_only/1-gpu/ --max_output_len 100 --tokenizer_dir model_weights/models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852 --input_text "Hello how are you?"
Running the float16 engine ...
Input: "Hello how are you?"
Output: "
 nobody is perfect, but I think I'm pretty close.
I'm a 20 year old guy from the Netherlands. I'm a very open minded person, I like to have fun and I like to make people happy. I'm a very outgoing person, I like to go out and have fun. I'm a very romantic person, I like to give gifts and I like to make people happy. I'm a very car"
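
(For reference, one way to sanity-check an engine's output is to generate a reference completion with the Hugging Face fp16 checkpoint and compare. This is only a rough sketch, not part of the original report; it assumes transformers and accelerate are installed, uses the same snapshot path as above, and decodes greedily, so sampled outputs will not match it exactly.)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "model_weights/models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16, device_map="auto")

# Greedy reference completion for the same prompt used with run.py above.
inputs = tokenizer("Hello how are you?", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))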

When I build with in-flight batching and the paged KV cache:

1. Build the engine:

python build.py --model_dir model_weights/models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852 \
               --dtype float16 \
               --remove_input_padding \
               --use_gpt_attention_plugin float16 \
               --use_gemm_plugin float16 \
               --use_weight_only \
               --use_inflight_batching \
               --paged_kv_cache \
               --weight_only_precision 'int4' \
               --output_dir ./tmp_prod/llama/7B/trt_engines/weight_only/1-gpu/

2. Start the Triton server:

python /opt/scripts/launch_triton_server.py --model_repo /all_models/inflight_batcher_llm --world_size 1
3. Curl request (a Python equivalent is sketched after the response below):
    curl -X POST localhost:8000/v2/models/ensemble/generate -d \
    '{
    "text_input": "Hello how are you?",
    "parameters": {
    "max_tokens": 100,
    "bad_words":[""],
    "stop_words":[""],
    "stream": false,
    "temperature": 1
    }
    }'
4. Response:
    {"model_name":"ensemble",
    "model_version":"1",
    "sequence_end":false,
    "sequence_id":0,
    "sequence_start":false,
    "text_output":
    "<s> Hello how are you?reso Vid пу пуlitlitlitlitouteouteouterouterouterouterelslitlitouterouterouterouterouterouteribaibaibaibaibaibaasionasion ga gaalandalandaland列列列列列列ovokö roman roman roman roman roman roman roman roman roman roman Roman roman roman Heinrich Roman Roman Heinrich Roman Roman Roman Apple Apple Warner Warner Warner Warner decengoengo bat roman Alc AlcVD->{->{->{->{->{->{->{->{->{->{->{RuntimeRuntimeaml->{amlaml Alc angularjs angularjsanjeVD"}
jfolz commented 12 months ago

From experience: the basic int4 weight-only mode (not GPTQ or AWQ) should not be used. Even the 70B model generated nonsense when using it. Try the int8 mode instead; maybe it fits. LLM runtimes in general have limited support for the older Volta and Turing generations. Maybe give llama.cpp a try. It wasn't particularly fast compared to TensorRT-LLM, but it's easy to use and its 5-bit quantization mode is quite accurate.
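
(For reference, switching to int8 weight-only only changes the precision flag in the build command from the original report; a sketch based on the flags already used in this thread, with an arbitrary output directory. Whether the int8 engine fits in the T4's 16 GB is not verified here.)

python build.py --model_dir model_weights/models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852 \
               --dtype float16 \
               --remove_input_padding \
               --use_gpt_attention_plugin float16 \
               --use_gemm_plugin float16 \
               --use_weight_only \
               --use_inflight_batching \
               --paged_kv_cache \
               --weight_only_precision 'int8' \
               --output_dir ./tmp_prod/llama/7B/trt_engines/weight_only_int8/1-gpu/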

matichon-vultureprime commented 12 months ago

Hi folks, I recently carried out a test that I'd like to share with all of you.

Hypothesis: Llama2 int4 weight-only quantization should work across all architectures (SM70, SM75, SM80, SM86, SM89, SM90).

Results:
- T4 (SM75), int4, TRT-LLM backend: incorrect output.
- T4 (SM75), fp16, TRT-LLM backend: correct output.
- V100 (SM70), int4, TRT-LLM backend: correct output.
- V100 (SM70), fp16, TRT-LLM backend: correct output.
- A10G (SM86), int4, TRT-LLM backend: correct output.
- A10G (SM86), fp16, TRT-LLM backend: correct output.
- A100 (SM80), int4, TRT-LLM backend: correct output.
- A100 (SM80), fp16, TRT-LLM backend: correct output.
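
(For reference, the compute capability / SM version of the GPU under test can be confirmed as shown below; a minimal sketch assuming PyTorch is installed, not part of the original report.)

import torch

# Print the CUDA compute capability of the first visible GPU,
# e.g. (7, 5) for a T4 (SM75) or (8, 0) for an A100 (SM80).
major, minor = torch.cuda.get_device_capability(0)
print(f"{torch.cuda.get_device_name(0)}: SM{major}{minor}")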

I hope this report is helpful.

juney-nvidia commented 12 months ago

@matichon-vultureprime Thanks for reporting this. We haven't thoroughly validated TensorRT-LLM on Turing hardware.

Let me bring this to our product team's attention first.