NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0
8.2k stars 909 forks

INT8 GEMM Support? #682

Closed eycheung closed 9 months ago

eycheung commented 9 months ago

Are there any plans to support INT8 GEMM? In the SmoothQuant paper, one of the main stated benefits is that by quantizing both weights and activations, specific integer kernels can be used:

However, to speed up the inference, we need to quantize both weights and activations into INT8 (i.e., W8A8) to utilize the integer kernels (e.g., INT8 GEMM), which are supported by a wide range of hardware (e.g., NVIDIA GPUs, Intel CPUs, Qualcomm DSPs, etc.).

However, it seems like most of the TRT-LLM builds only support ['float16', 'bfloat16', 'float32'] for the GEMM plugin.

  1. Is INT8/INT4 GEMM support on the roadmap?
  2. Will there be any W4A4 quantization algorithms added to AMMO in the future? It seems like this should also be doable with SmoothQuant, but it was left as future work.
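For concreteness, the W8A8 path referred to above looks roughly like the following sketch (plain NumPy for illustration, not TensorRT-LLM code): quantize both the activations and the weights to INT8 so the matrix multiply itself runs on an integer kernel with an INT32 accumulator, then rescale the result back to floating point.

    import numpy as np

    def quantize_per_tensor(x):
        # Symmetric per-tensor INT8 quantization: scale so that max|x| maps to 127.
        scale = np.abs(x).max() / 127.0
        q = np.clip(np.round(x / scale), -128, 127).astype(np.int8)
        return q, scale

    def w8a8_gemm(x, w):
        # "W8A8": both activations and weights are INT8, so the matmul itself can
        # run on an integer kernel (INT8 GEMM) with an INT32 accumulator.
        xq, sx = quantize_per_tensor(x)
        wq, sw = quantize_per_tensor(w)
        acc = xq.astype(np.int32) @ wq.astype(np.int32).T
        # Rescale the INT32 accumulator back to floating point.
        return acc.astype(np.float32) * (sx * sw)

    x = np.random.randn(8, 512).astype(np.float32)     # activations
    w = np.random.randn(1024, 512).astype(np.float32)  # weights
    print(w8a8_gemm(x, w).shape)  # (8, 1024)

Real INT8 GEMM kernels fuse the rescaling and typically use per-channel (weight) and per-token (activation) scales, but the data path is the same.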

byshiue commented 9 months ago

  1. SQ is supported, as https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/gpt#smoothquant shows.
  2. W4A4 is still under study.

eycheung commented 9 months ago

Thanks @byshiue! Sorry, I misunderstood what this does; I see that there is indeed a specialized GEMM being set for SmoothQuant:

network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)

As a follow-up, my benchmark of SmoothQuant INT8 seems slower than weight-only INT8 for Llama 2, which is surprising to me if SQ is indeed using a specialized INT8 kernel. Is this expected, or is it due to my build options (maybe it could be faster if I only used per_tensor?) or my hardware (I am using g5 instances with A10s)?

byshiue commented 9 months ago

Could you share the details of your benchmark? Also, SQ is only faster than weight-only when the batch size is large enough.

jdemouth-nvidia commented 9 months ago

It would be nice if you could share the commands to reproduce the issue, indeed. However, the result is not necessarily surprising. SmoothQuant (SQ) requires a bit of extra work to be performed (such as the smoothing of activations). Both INT8 weight-only (W/O) and INT8 SQ use INT8 weights, so when performance is limited by weight loading (or KV-cache loading), having the activations in INT8 as well, which is the unique runtime-performance advantage of SQ, does not make a huge difference.
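To make the memory-bound argument concrete, here is a back-of-the-envelope sketch using rough A10 datasheet peaks (the numbers and the crossover point are illustrative only; they ignore kernel efficiency, SQ's extra smoothing/quantization work, and everything outside a single decode-step GEMM):

    # Back-of-the-envelope roofline for a single decode-step GEMM on an A10.
    # Rough datasheet peaks; real kernels will not reach them.
    BW        = 600e9    # ~600 GB/s memory bandwidth
    PEAK_FP16 = 125e12   # ~125 TFLOPS FP16 tensor cores
    PEAK_INT8 = 250e12   # ~250 TOPS INT8 tensor cores

    K, N = 4096, 4096            # one Llama-2-7B-sized projection matrix
    weight_bytes = K * N         # INT8 weights (1 byte each) in both W/O and SQ

    for batch in (1, 8, 32, 128, 512):
        flops = 2 * batch * K * N                 # one new token per sequence
        t_mem = weight_bytes / BW                 # time to stream the weights
        t_wo = max(t_mem, flops / PEAK_FP16)      # W/O: INT8 weights, FP16 math
        t_sq = max(t_mem, flops / PEAK_INT8)      # SQ:  INT8 weights, INT8 math
        print(f"batch={batch:4d}  W/O ~{t_wo * 1e6:6.1f} us  SQ ~{t_sq * 1e6:6.1f} us")

Below the crossover, both variants spend essentially all their time streaming the same INT8 weights, so SQ's faster INT8 math is hidden; it only pays off once the GEMM becomes compute-bound at larger batch sizes.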

eycheung commented 9 months ago

Gotcha, thank you both! I have only done very naive benchmarking so far, e.g. just looping through a list of prompts and measuring token metrics at various batch sizes. I'll check where the crossover point is at which SQ might outperform W/O. It makes sense that, since this is memory-bound, the optimization only starts to pay off once the batch becomes large enough.
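Roughly, the naive harness looks like this sketch, where generate is a hypothetical stand-in for whatever call actually runs the engine (not a real TRT-LLM API):

    import time

    def benchmark(generate, prompts, batch_sizes, output_len=128):
        # `generate(batch, max_new_tokens=...)` is a hypothetical stand-in for
        # the engine invocation; swap in your own runner.
        for bs in batch_sizes:
            batches = [prompts[i:i + bs] for i in range(0, len(prompts), bs)]
            start = time.perf_counter()
            for batch in batches:
                generate(batch, max_new_tokens=output_len)
            elapsed = time.perf_counter() - start
            total_tokens = len(prompts) * output_len
            print(f"batch_size={bs:4d}  {total_tokens / elapsed:10.1f} output tokens/s")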

Thanks again, and I'll close this as resolved.

xiangxinhello commented 1 month ago


I hit a problem with the same SmoothQuant GEMM plugin discussed above. I set smooth_quant_gemm_plugin to "int8" in set_smooth_quant_plugins:

    def set_smooth_quant_plugins(self, dtype: str = "auto"):
        self.smooth_quant_gemm_plugin = "int8"
        self.rmsnorm_quantization_plugin = dtype
        self.layernorm_quantization_plugin = dtype
        self.quantize_per_token_plugin = True
        self.quantize_tensor_plugin = True
        return self

but the build fails with this error:

    [08/16/2024-08:54:52] [TRT-LLM] [I] Set smooth_quant_gemm_plugin to int8.
    [08/16/2024-08:54:52] [TRT-LLM] [I] Set rmsnorm_quantization_plugin to float16.
    [08/16/2024-08:54:52] [TRT-LLM] [I] Set layernorm_quantization_plugin to float16.
    [08/16/2024-08:54:52] [TRT-LLM] [I] Set quantize_per_token_plugin to True.
    [08/16/2024-08:54:52] [TRT-LLM] [I] Set quantize_tensor_plugin to True.
    [08/16/2024-08:54:52] [TRT-LLM] [I] Set nccl_plugin to None.
    [08/16/2024-08:54:52] [TRT-LLM] [I] Set use_custom_all_reduce to True.
    [08/16/2024-08:54:52] [TRT] [W] IElementWiseLayer with inputs QWenForCausalLM/transformer/layers/0/attention/qkv/smooth_quant_gemm/PLUGIN_V2_SmoothQuantGemm_0_output_0 and QWenForCausalLM/transformer/layers/0/attention/qkv/add/elementwise_binary/broadcast_helper/expand_dims_like/expand_dims/view/SHUFFLE_0_output_0: first input has type Int8 but second input has type Half.
    [08/16/2024-08:54:52] [TRT] [E] ITensor::getDimensions: Error Code 4: Internal Error (QWenForCausalLM/transformer/layers/0/attention/qkv/add/elementwise_binary/ELEMENTWISE_SUM_0: ElementWiseOperation SUM must have same input types. But they are of types Int8 and Half.)
    Traceback (most recent call last):
      File "/root/anaconda3/envs/trt_llm/bin/trtllm-build", line 8, in <module>
        sys.exit(main())
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 551, in main
        parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 373, in parallel_build
        passed = build_and_save(rank, rank % workers, ckpt_dir,
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 340, in build_and_save
        engine = build_model(build_config,
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 333, in build_model
        return build(model, build_config)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/builder.py", line 890, in build
        model(**inputs)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/module.py", line 40, in __call__
        output = self.forward(*args, **kwargs)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/models/modeling_utils.py", line 713, in forward
        hidden_states = self.transformer.forward(**kwargs)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/models/qwen/model.py", line 196, in forward
        hidden_states = self.layers.forward(hidden_states,
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/models/modeling_utils.py", line 327, in forward
        hidden_states = layer(
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/module.py", line 40, in __call__
        output = self.forward(*args, **kwargs)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/models/qwen/model.py", line 121, in forward
        attention_output = self.attention(
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/module.py", line 40, in __call__
        output = self.forward(*args, **kwargs)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/quantization/layers.py", line 1222, in forward
        qkv = self.qkv(hidden_states)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/module.py", line 40, in __call__
        output = self.forward(*args, **kwargs)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/quantization/layers.py", line 147, in forward
        x = x + self.bias.value
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/functional.py", line 321, in __add__
        return add(self, b)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/functional.py", line 2825, in elementwise_binary
        return _create_tensor(layer.get_output(0), layer)
      File "/root/anaconda3/envs/trt_llm/lib/python3.10/site-packages/tensorrt_llm/functional.py", line 607, in _create_tensor
        assert trt_tensor.shape.__len__(
    AssertionError: tensor QWenForCausalLM/transformer/layers/0/attention/qkv/add/elementwise_binary/ELEMENTWISE_SUM_0_output_0 has an invalid shape