NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Ada `FP8xint4` performance issue #1906

Open jcao-ai opened 3 months ago

jcao-ai commented 3 months ago

Since Ada GPUs like the 4090 limit FP8 arithmetic to fp32 accumulation, FP8 only achieves the same max TFLOPs as fp16xfp16 with fp16 accumulation. Furthermore, according to my tests, fp8xint4 even performs worse than fp16xint4 in terms of TPOT (time per output token).

I wonder whether you have run into the same situation. Does this suggest that it is not a good idea to deploy FP8 on such GPUs?
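
For context, here is a minimal sketch of how TPOT can be measured for the two engine variants being compared; the streaming `generate_fn` callable is a hypothetical placeholder, not a TensorRT-LLM API:

```python
import time

def measure_ttft_and_tpot(generate_fn, prompt, max_new_tokens):
    """Rough latency breakdown for one request.

    `generate_fn` is a hypothetical callable that streams one token per
    iteration; plug in whichever runtime (fp8xint4 vs. fp16xint4 engine)
    you want to compare.
    """
    token_times = []
    start = time.perf_counter()
    for _ in generate_fn(prompt, max_new_tokens):  # yields one token per step
        token_times.append(time.perf_counter())
    if len(token_times) < 2:
        raise ValueError("need at least two output tokens to compute TPOT")
    ttft = token_times[0] - start                              # time to first token
    tpot = (token_times[-1] - token_times[0]) / (len(token_times) - 1)  # avg inter-token time
    return ttft, tpot
```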

Njuapp commented 3 months ago

To be more precise, all weight-only kernels use fp32 accumulation, regardless of the arch version. We have done extensive experiments on L20 and L40S, where fp8xint4 is better than fp16xint4. Would you mind sharing your reproduction steps and results? I wonder under which conditions (batch_size, input_output_len, etc.) fp8xint4 performs worse.

white-wolf-tech commented 1 month ago

> To be more precise, all weight-only kernels use fp32 accumulation, regardless of the arch version. We have done extensive experiments on L20 and L40S, where fp8xint4 is better than fp16xint4. Would you mind sharing your reproduction steps and results? I wonder under which conditions (batch_size, input_output_len, etc.) fp8xint4 performs worse.

I have conducted many experiments. INT4xFP8 on L20 has a speed close to that of FP16 on A100. But the problem is that after INT4xFP8 quantization, the output results are uncontrollable and the error rate is above 30%. Why is that? Is it because it doesn't support the Ada architecture?
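
For what it's worth, one way such an error rate could be measured is sketched below: generate with an FP16 reference and with the quantized model, and count how often the outputs diverge. The two generate callables and the exact-match criterion are assumptions for illustration, not the commenter's actual evaluation:

```python
def output_error_rate(reference_generate, quantized_generate, prompts, max_new_tokens=128):
    """Fraction of prompts whose quantized output diverges from the FP16 reference.

    `reference_generate` and `quantized_generate` are hypothetical callables that
    return a decoded string; exact string match is a deliberately strict criterion
    and can be swapped for a task-specific check.
    """
    mismatches = 0
    for prompt in prompts:
        ref = reference_generate(prompt, max_new_tokens)
        quant = quantized_generate(prompt, max_new_tokens)
        if ref.strip() != quant.strip():
            mismatches += 1
    return mismatches / len(prompts)
```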

Njuapp commented 1 month ago

> > To be more precise, all weight-only kernels use fp32 accumulation, regardless of the arch version. We have done extensive experiments on L20 and L40S, where fp8xint4 is better than fp16xint4. Would you mind sharing your reproduction steps and results? I wonder under which conditions (batch_size, input_output_len, etc.) fp8xint4 performs worse.
>
> I have conducted many experiments. INT4xFP8 on L20 has a speed close to that of FP16 on A100. But the problem is that after INT4xFP8 quantization, the output results are uncontrollable and the error rate is above 30%. Why is that? Is it because it doesn't support the Ada architecture?

INT4xFP8 has known unstable accuracy issues, which may or may not occur depending on the specific model. We are working on FP4xFP8 to overcome this issue. It is supported on Ada.

Njuapp commented 1 month ago

@x-transformers did you measure the error rate through PyTorch or through trtllm? I suggest you first try accuracy verification in PyTorch; if it turns out fine, then it's the kernel's fault, otherwise it is an inherent weakness of INT4xFP8.
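
As a rough illustration of that PyTorch-side check: quantize the Hugging Face model with ModelOpt (which only simulates W4A8) and inspect the generations before building any engine. The model name is a placeholder, and the `W4A8_AWQ_BETA_CFG` config name and calibration loop are assumptions that may differ across ModelOpt / TensorRT-LLM versions:

```python
import torch
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda()

calib_prompts = ["The capital of France is", "Explain KV cache in one sentence:"]

def forward_loop(m):
    # Calibration pass used by ModelOpt to collect activation statistics.
    for p in calib_prompts:
        ids = tokenizer(p, return_tensors="pt").input_ids.cuda()
        m(ids)

# Simulated (fake-quant) W4A8 AWQ quantization, no TensorRT engine involved.
mtq.quantize(model, mtq.W4A8_AWQ_BETA_CFG, forward_loop)

# If generations already degrade here, the problem is the W4A8 recipe itself,
# not the TensorRT-LLM kernel.
ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
print(tokenizer.decode(model.generate(ids, max_new_tokens=32)[0]))
```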

white-wolf-tech commented 1 month ago

> @x-transformers did you measure the error rate through PyTorch or through trtllm? I suggest you first try accuracy verification in PyTorch; if it turns out fine, then it's the kernel's fault, otherwise it is an inherent weakness of INT4xFP8.

With PyTorch, or with FP16/W8A8 in trtllm, the model's output is correct. Isn't L20 supposed to not support w4a8_awq quantization?

Njuapp commented 1 month ago

> > @x-transformers did you measure the error rate through PyTorch or through trtllm? I suggest you first try accuracy verification in PyTorch; if it turns out fine, then it's the kernel's fault, otherwise it is an inherent weakness of INT4xFP8.
>
> With PyTorch, or with FP16/W8A8 in trtllm, the model's output is correct. Isn't L20 supposed to not support w4a8_awq quantization?

First of all, I can assure you that L20 absolutely supports the w4a8 AWQ kernel at inference time. The only problem is that ModelOpt (invoked by examples/quantize.py) does not produce satisfactory accuracy with W4A8 AWQ. This has nothing to do with the GPU arch, because ModelOpt only simulates W4A8 computation. The same thing will happen even if you change the GPU to H20.

To improve the accuracy of W4A8, one workaround you can try is mixed-precision quantization: make some layers fall back to FP8 and keep the rest in W4A8. Typically, mlp.fc2 is the most quantization-sensitive layer and needs to be in FP8.

Here is a tool for mixed-precision quantization from our NVIDIA colleagues: https://github.com/Jackch-NV/TRTLLM-w4afp8-fp8-mix-inference. It runs both W4A8 and FP8 quantization, and then you can merge the two checkpoints with the mixing strategy you want, i.e., choosing which layers are W4A8 and which are FP8.
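
As a rough illustration of the merge step (not the linked tool's actual implementation): take every tensor of the quantization-sensitive layers from the FP8 checkpoint and every other layer's tensors from the W4A8 checkpoint. The checkpoint file layout, key naming, and the fallback pattern below are assumptions:

```python
import safetensors.torch as st

FP8_FALLBACK_PATTERNS = ("mlp.fc2",)  # quantization-sensitive layers kept in FP8

def layer_prefix(key):
    # Strip the trailing tensor name (weight, weights_scaling_factor, ...) to get
    # the layer it belongs to, e.g. "transformer.layers.0.mlp.fc2".
    return key.rsplit(".", 1)[0]

def merge_checkpoints(w4a8_path, fp8_path, out_path):
    w4a8 = st.load_file(w4a8_path)
    fp8 = st.load_file(fp8_path)

    # Layers whose tensors should come from the FP8 checkpoint.
    fp8_layers = {layer_prefix(k) for k in fp8
                  if any(p in k for p in FP8_FALLBACK_PATTERNS)}

    merged = {}
    # Take every tensor of a sensitive layer from the FP8 checkpoint ...
    for key, tensor in fp8.items():
        if layer_prefix(key) in fp8_layers:
            merged[key] = tensor
    # ... and every tensor of all other layers from the W4A8 checkpoint.
    for key, tensor in w4a8.items():
        if layer_prefix(key) not in fp8_layers:
            merged[key] = tensor

    st.save_file(merged, out_path)

merge_checkpoints("w4a8_awq/rank0.safetensors", "fp8/rank0.safetensors",
                  "mixed/rank0.safetensors")
```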