NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0
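For context, here is a minimal sketch of the high-level Python API documented in recent TensorRT-LLM releases; the workflow discussed in the issue below uses examples/llama/build.py instead, and the model path here is purely illustrative.

from tensorrt_llm import LLM, SamplingParams

# Sketch only: load a checkpoint (HF id or local path) and generate greedily.
llm = LLM(model="codellama/CodeLlama-7b-hf")
sampling = SamplingParams(max_tokens=64)

for out in llm.generate(["def binary_search(arr, target):"], sampling):
    print(out.outputs[0].text)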

Why are the human eval scores of smoothquant and int8_weight_only very very low? #453

Closed activezhao closed 7 months ago

activezhao commented 7 months ago

We use an A10 GPU and the CodeLlama-7B model from HuggingFace.

I use the latest tensorrtllm_backend and TensorRT-LLM from the main branch.

We tested three cases.

1. Normal

The HumanEval score is 30.48, which is within expectations.

python build.py --model_dir /tensorrtllm_backend/CodeLlama-7b-hf/  \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --paged_kv_cache \
                --use_inflight_batching \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --output_dir /tensorrtllm_backend/trt_llama_7b_fp16_kv_cache_inflight_batching_stop/4-gpu/  \
                --vocab_size 32016  \
                --rotary_base 1000000  \
                --max_batch_size 32  \
                --world_size 4 \
                --tp_size 4

2. SmoothQuant

https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/llama/README.md#smoothquant

The HumanEval score is 4.87, which is not within expectations.

python3 hf_llama_convert.py -i /tensorrtllm_backend/CodeLlama-7b-hf/   \
                -o /tensorrtllm_backend/smooth_llama_7B/sq0.5/ \
                -sq 0.5 \
                --tensor-parallelism 4 \
                --storage-type fp16

# Build model for SmoothQuant in the _per_token_ + _per_channel_ mode
python3 build.py --ft_model_dir=/tensorrtllm_backend/smooth_llama_7B/sq0.5/4-gpu/ \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --paged_kv_cache \
                --use_inflight_batching \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --max_batch_size 64  \
                --vocab_size 32016  \
                --rotary_base 1000000  \
                --use_smooth_quant \
                --per_token \
                --per_channel \
                --world_size 4 \
                --tp_size 4 \
                --output_dir /tensorrtllm_backend/trt_llama_7b_sq_pt_pc/sq0.5/4-gpu/
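For reference on the -sq 0.5 flag above: assuming it corresponds to the SmoothQuant migration strength α from the paper, the per-channel smoothing scales it controls look like the following sketch (NumPy, toy numbers, purely illustrative).

import numpy as np

def smoothquant_scales(act_absmax, weight_absmax, alpha=0.5):
    # SmoothQuant: s_j = max|X_j|**alpha / max|W_j|**(1 - alpha).
    # Activations are divided by s_j and weights are multiplied by s_j,
    # shifting quantization difficulty from activations into weights.
    return act_absmax ** alpha / weight_absmax ** (1 - alpha)

# Toy calibration statistics for 4 input channels (made-up values).
act_absmax = np.array([20.0, 3.0, 0.5, 8.0])
weight_absmax = np.array([0.4, 0.6, 0.2, 0.5])
print(smoothquant_scales(act_absmax, weight_absmax, alpha=0.5))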

3. int8_weight_only

https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/llama/README.md#int8-kv-cache

The HumanEval score is 7.31, which is also not within expectations.

python3 hf_llama_convert.py -i /tensorrtllm_backend/CodeLlama-7b-hf/    \
                -o /tensorrtllm_backend/smooth_llama_7B_int8_kv_cache/ \
                --calibrate-kv-cache \
                -t fp16 \
                --tensor-parallelism 4

# Build model with both INT8 weight-only and INT8 KV cache enabled
python build.py --ft_model_dir=/tensorrtllm_backend/smooth_llama_7B_int8_kv_cache/4-gpu/ \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --paged_kv_cache \
                --use_inflight_batching \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --max_batch_size 64  \
                --vocab_size 32016  \
                --output_dir /tensorrtllm_backend/trt_llama_7b_int8_kv_cache_weight_only/4-gpu \
                --int8_kv_cache \
                --use_weight_only \
                --world_size 4 \
                --tp_size 4

In fact, when we use smoothquant and int8_weight_only, the inference results contain many "\n" characters.
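For what it's worth, completions that are mostly "\n" would fail the functional-correctness check, which matches the score collapse. Assuming the scores come from OpenAI's human-eval harness (the issue does not say how they were computed), a minimal pass@1 scoring sketch looks like this; generate_one_completion is a hypothetical helper that calls the deployed engine.

from human_eval.data import read_problems, write_jsonl

def generate_one_completion(prompt: str) -> str:
    # Hypothetical helper: send the prompt to the deployed TensorRT-LLM engine
    # and return only the generated completion text.
    raise NotImplementedError

problems = read_problems()

# One greedy sample per task for pass@1.
samples = [
    dict(task_id=task_id,
         completion=generate_one_completion(problems[task_id]["prompt"]))
    for task_id in problems
]
write_jsonl("samples.jsonl", samples)
# Then score with: evaluate_functional_correctness samples.jsonl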


I think these results are abnormal. Is there something wrong?

Thanks.

lyers179 commented 7 months ago

Same question on A40 with CodeLlama-13B.

activezhao commented 7 months ago

Same question on A40 with CodeLlama-13B.

Hi @lyers179 Do you have any solutions?

jdemouth-nvidia commented 7 months ago

Are you using the main branch or the release branch? Have you tried using those quantization techniques outside of TensorRT-LLM (to make sure the techniques work for those models and the issue is “with” TensorRT-LLM)? Thanks!
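One rough way to do that check outside TensorRT-LLM is to load the same checkpoint with 8-bit weights via HuggingFace Transformers + bitsandbytes and spot-check a few completions. Note this is only a sketch: bitsandbytes' LLM.int8() is not the same algorithm as TensorRT-LLM's weight-only INT8, so it merely verifies that 8-bit quantization in general does not collapse quality on this model.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "/tensorrtllm_backend/CodeLlama-7b-hf"  # same local checkpoint as above
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)

prompt = "def fibonacci(n):"
inputs = tok(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=128, do_sample=False)
print(tok.decode(out[0], skip_special_tokens=True))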

activezhao commented 7 months ago

Are you using the main branch or the release branch? Have you tried using those quantization techniques outside of TensorRT-LLM (to make sure the techniques work for those models and the issue is “with” TensorRT-LLM)? Thanks!

@jdemouth-nvidia I am using the latest main branch from about 2 weeks ago.

I used int8_weight_only with FasterTransformer on A10 in the past, and the HumanEval score was not this low.

What's more, do you have CodeLlama-7B HumanEval scores for TensorRT-LLM with SmoothQuant and int8_weight_only? Are they relatively normal?

jdemouth-nvidia commented 7 months ago

I’ve just asked one member of our team to take a look at this issue. Thanks

activezhao commented 7 months ago

I’ve just asked one member of our team to take a look at this issue. Thanks

@jdemouth-nvidia OK, thank u.

Tracin commented 7 months ago

@activezhao Could you please build another, simpler version in order to debug? I ran this on an A100 and got a HumanEval score of 30.49.

python build.py --ft_model_dir= \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --vocab_size 32016 \
                --rotary_base 1000000 \
                --output_dir \
                --use_weight_only

activezhao commented 7 months ago

@activezhao Could you please build another, simpler version in order to debug? I ran this on an A100 and got a HumanEval score of 30.49.

python build.py --ft_model_dir= \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --vocab_size 32016 \
                --rotary_base 1000000 \
                --output_dir \
                --use_weight_only

Hi @Tracin What is your hf_llama_convert.py command? Could you please give the command for every step?

And do you use the latest TensorRT-LLM from the main branch?

Tracin commented 7 months ago

@activezhao Sorry, you do not need hf_llama_convert if we skip int8_kv_cache, you can change ft_model_dir to model_dir. Please try it.

activezhao commented 7 months ago

@activezhao Sorry, you do not need hf_llama_convert if we skip int8_kv_cache, you can change ft_model_dir to model_dir. Please try it.

@Tracin OK, thanks a lot, I am trying it now.

But, if you add int8_kv_cache, will anything unusual happen?

activezhao commented 7 months ago

Hi @Tracin I just skipped int8_kv_cache, and the HumanEval score is 29.9 on A10, so it's normal now.

However,

the throughput with use_weight_only + int8_kv_cache is 24% higher than with fp16,

while the throughput with use_weight_only alone is only 10% higher than with fp16.

Is there a problem in the int8_kv_cache logic? If so, when can it be fixed?

python build.py --model_dir=/tensorrtllm_backend/CodeLlama-7b-hf/ \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --paged_kv_cache \
                --use_inflight_batching \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --max_batch_size 64  \
                --vocab_size 32016  \
                --rotary_base 1000000  \
                --output_dir /tensorrtllm_backend/trt_llama_7b_int8_weight_only_no_kv_cache \
                --use_weight_only \
                --world_size 4 \
                --tp_size 4
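For what it's worth, here is a rough sketch of how such a throughput comparison can be probed against the tensorrtllm_backend ensemble endpoint; the endpoint path and request fields follow the backend README, the prompt and counts are made up, and this sequential loop only gives a ballpark number.

import time
import requests

url = "http://localhost:8000/v2/models/ensemble/generate"
payload = {"text_input": "def quicksort(arr):", "max_tokens": 256,
           "bad_words": "", "stop_words": ""}

n_requests = 16
start = time.time()
for _ in range(n_requests):
    r = requests.post(url, json=payload)
    r.raise_for_status()
elapsed = time.time() - start

# Upper-bound estimate: assumes every request generates the full max_tokens.
print(f"~{n_requests * payload['max_tokens'] / elapsed:.1f} output tokens/sec")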

Tracin commented 7 months ago

@activezhao int8_kv_cache might have bugs; however, the SQ model also has bad accuracy and it does not use int8_kv_cache, so I think they are different issues. I will start with int8_kv_cache first.

activezhao commented 7 months ago

@activezhao int8_kv_cache might have bugs; however, the SQ model also has bad accuracy and it does not use int8_kv_cache, so I think they are different issues. I will start with int8_kv_cache first.

@Tracin OK, got it, and looking forward to your good news.

Tracin commented 7 months ago

@activezhao I have reproduced the bad accuracy of the SQ and INT8-KV models. Could you try building the SQ model on the release-0.5.0 branch, please?

activezhao commented 7 months ago

@activezhao I have reproduced the bad accuracy of the SQ and INT8-KV models. Could you try building the SQ model on the release-0.5.0 branch, please?

@Tracin Cool, have you fixed this problem already?

Do you need me to rebuild the SQ model on the release-0.5.0 branch for testing?

In fact, I just use the latest main branch of tensorrtllm_backend and TensorRT-LLM because of the stop_words problem; is the main branch OK?

https://github.com/triton-inference-server/tensorrtllm_backend/issues/47

activezhao commented 7 months ago

@activezhao I have reproduced the bad accuracy of the SQ and INT8-KV models. Could you try building the SQ model on the release-0.5.0 branch, please?

Hi @Tracin Is there any new progress on SQ and INT8-KV?

Tracin commented 7 months ago

@activezhao SQ has been fixed and now scores 31.7 on HumanEval; the fix will be pushed to the GitHub main branch this week or next. INT8-KV works correctly with v0.6.1.

wjueyao commented 7 months ago

@Tracin I followed the above steps to see if INT8-KV works, using the main branch. However, I got an error when running build.py:

Traceback (most recent call last):
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/build.py", line 839, in <module>
    build(0, args)
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/build.py", line 783, in build
    engine = build_rank_engine(builder, builder_config, engine_name,
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/build.py", line 641, in build_rank_engine
    load_from_binary(tensorrt_llm_llama,
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/weight.py", line 917, in load_from_binary
    if not use_gemm_woq_plugin:
NameError: name 'use_gemm_woq_plugin' is not defined

Related issue https://github.com/triton-inference-server/tensorrtllm_backend/issues/211

As for SQ, I also followed the above steps, and got an error:

Traceback (most recent call last):
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/build.py", line 828, in <module>
    args = parse_arguments()
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/build.py", line 468, in parse_arguments
    n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head = parse_ft_config(
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/weight.py", line 171, in parse_ft_config
    n_embd = gpt_config.getint('llama', 'hidden_size')
  File "/usr/lib/python3.10/configparser.py", line 820, in getint
    return self._get_conv(section, option, int, raw=raw, vars=vars,
  File "/usr/lib/python3.10/configparser.py", line 810, in _get_conv
    return self._get(section, conv, option, raw=raw, vars=vars,
  File "/usr/lib/python3.10/configparser.py", line 805, in _get
    return conv(self.get(section, option, **kwargs))
  File "/usr/lib/python3.10/configparser.py", line 783, in get
    d = self._unify_values(section, vars)
  File "/usr/lib/python3.10/configparser.py", line 1154, in _unify_values
    raise NoSectionError(section) from None
configparser.NoSectionError: No section: 'llama'

The parameters I used are identical to the above steps; however, both SQ and INT8 KV raise errors.

Tracin commented 7 months ago

@wjueyao Hi, for the SmoothQuant problem, I think you did not give the right path for ft_model_dir (reminder: a '{}-gpu' folder is created under the converter's output directory).
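A quick way to verify the path, as a sketch: assuming the converter writes an FT-style config.ini under the '<tp>-gpu' subfolder (which is what parse_ft_config reads, and what produces the NoSectionError above when the path is wrong), you can check it directly. The directory below is the one from the earlier SmoothQuant commands.

import configparser
import os

# Note the trailing '4-gpu' folder created by hf_llama_convert.py.
ft_model_dir = "/tensorrtllm_backend/smooth_llama_7B/sq0.5/4-gpu/"
cfg_path = os.path.join(ft_model_dir, "config.ini")
print("config.ini exists:", os.path.exists(cfg_path))

cfg = configparser.ConfigParser()
cfg.read(cfg_path)                 # reading a missing file simply yields no sections
print("sections:", cfg.sections())
if cfg.has_section("llama"):
    print("hidden_size:", cfg.getint("llama", "hidden_size"))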

wjueyao commented 7 months ago

@wjueyao Hi, for the SmoothQuant problem, I think you did not give the right path for ft_model_dir (reminder: a '{}-gpu' folder is created under the converter's output directory).

@Tracin Thanks for the quick reply! You are right. There was a typo in my ft_model_dir.

However, after I changed it to the correct path, I got the following error:

Traceback (most recent call last):
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/build.py", line 839, in <module>
    build(0, args)
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/build.py", line 783, in build
    engine = build_rank_engine(builder, builder_config, engine_name,
  File "/tensorrtllm_backend/tensorrt_llm/examples/llama/build.py", line 710, in build_rank_engine
    tensorrt_llm_llama(*inputs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/module.py", line 40, in __call__
    return self.forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/model.py", line 379, in forward
    hidden_states = super().forward(input_ids, position_ids, use_cache,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/model.py", line 260, in forward
    hidden_states = layer(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/module.py", line 40, in __call__
    return self.forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/model.py", line 117, in forward
    attention_output = self.attention(hidden_states,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/module.py", line 40, in __call__
    return self.forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/quantization/layers.py", line 1103, in forward
    context, past_key_value = gpt_attention(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/graph_rewriting.py", line 564, in wrapper
    outs = f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/functional.py", line 3548, in gpt_attention
    plug_inputs = [i.trt_tensor for i in plug_inputs]
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/functional.py", line 3548, in <listcomp>
    plug_inputs = [i.trt_tensor for i in plug_inputs]
AttributeError: 'NoneType' object has no attribute 'trt_tensor'

Tracin commented 7 months ago

@wjueyao I think that is a separate issue; please open a new issue and @me there!