NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

int8 gemm slower than fp16 on A100. #935

Closed: beegerous closed this issue 8 months ago

beegerous commented 8 months ago

I need a Python operator that supports int8 GEMM with per-token/per-channel quantization, so I wrapped the code in https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h into something like https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu#L476

Then I tested the speed of the int8 GEMM against pure fp16 (torch.nn.functional.linear). The timing covers only the GEMM, not the quantization. I repeat the GEMM 1e5 times (1e6 times in the smallest case).

| M | N | K | torch cost (s) | trt-llm cost (s) |
| --- | --- | --- | --- | --- |
| 1024 | 1024 | 1024 | 19.999567985534668 | 22.63555335998535 |
| 2048 | 2048 | 2048 | 8.349320888519287 | 7.4798314571380615 |
| 4096 | 4096 | 4096 | 65.3261570930481 | 45.53051781654358 |
| 8192 | 4096 | 4096 | 125.70239543914795 | 137.0772671699524 |
| 4096 | 8192 | 4096 | 125.74432516098022 | 117.87490010261536 |
| 4096 | 4096 | 8192 | 118.52623224258423 | 86.75222182273865 |

The test code is something like this:

```python
import torch

# gen_input_tensor, cost, mylib, bias, and M/N/K come from my own harness;
# gen_input_tensor returns a quantized tensor plus its scale.
x, alpha = gen_input_tensor([M, K])   # activation and per-token scale
y, beta  = gen_input_tensor([N, K])   # weight and per-channel scale
n = 100000

with cost("int8gemm"):
    for _ in range(n):
        d = mylib.linear_a8_w8_bofp16(x, y, beta, alpha, bias)

# Dequantize so the fp16 reference sees equivalent inputs.
x = x * alpha
y = y * beta

with cost("torch"):
    for _ in range(n):
        c = torch.nn.functional.linear(x, y)
```
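(The `cost` helper is not shown above; a minimal sketch of how such a timing context manager might be implemented, assuming it synchronizes the GPU before reading the clock, is:)

```python
import time
from contextlib import contextmanager

import torch

@contextmanager
def cost(label):
    # Flush any previously queued kernels before starting the clock.
    torch.cuda.synchronize()
    start = time.time()
    yield
    # CUDA kernel launches are asynchronous, so wait for the benchmarked
    # kernels to finish before stopping the clock.
    torch.cuda.synchronize()
    print(f"{label} cost: {time.time() - start:.3f}s")
```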

At first I thought the reason was that the input tensors were too small, and indeed at M = N = K = 4096 the int8 GEMM is finally faster than torch. But when I set just one dim to 8192 (the [8192, 4096, 4096] case), the int8 GEMM is slower than fp16 again. The last three cases should involve roughly the same amount of computation; the torch results reflect that, but the int8 GEMM cost is quite unstable. I also expected the int8 GEMM to be about 2x faster than fp16, according to https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf .
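As a rough sanity check on these numbers (not part of the original report, and assuming the timings above are wall-clock seconds for 1e5 iterations), the achieved throughput for the 4096-cubed case can be estimated from 2*M*N*K operations per GEMM, against the A100 datasheet peaks of 312 TFLOPS (fp16 Tensor Core) and 624 TOPS (int8):

```python
# Back-of-envelope check for the M = N = K = 4096 case above.
M = N = K = 4096
iters = 100_000
ops = 2 * M * N * K * iters            # one multiply-add counted as 2 ops

fp16_s, int8_s = 65.33, 45.53          # measured wall-clock times from the table
print(f"fp16: {ops / fp16_s / 1e12:.0f} TFLOPS achieved (datasheet peak 312)")
print(f"int8: {ops / int8_s / 1e12:.0f} TOPS achieved   (datasheet peak 624)")
print(f"speedup: {fp16_s / int8_s:.2f}x (vs. the ~2x the peak ratio suggests)")
```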

(nsys screenshot attached) I checked the [4096, 8192, 4096] case with nsys and didn't find anything weird.

I'm confused by these results and would like to understand the underlying reason. Or, if it is caused by some build misconfiguration, how would I check that?

Env: Ubuntu 20.04.6 LTS, NVIDIA A100-SXM4-80GB, base commit c89653021e66ca78c55f02b366f404455bc12e8d

I build and run the code in a Docker image built from docker/Dockerfile.multi. I build mylib based on scripts/build_wheel.py.

nekorobov commented 8 months ago

Hi @beegerous, thank you for reporting the issue. When using https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h one should specify ThreadblockShape, WarpShape and Stages to configure the GEMM. The choice of these parameters can influence performance significantly. Meanwhile, torch uses runtime heuristics (e.g. in the cuDNN library) to choose the optimal GEMM configuration for the given problem size. So I suspect the performance difference you see comes from comparing a suboptimal int8 GEMM against an optimal fp16 GEMM.

In TensorRT-LLM we use a profiler to find the best GEMM configuration for int8 for a given problem size: https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp. For example, this profiler is used here; could you please try it out? Let me know if you have more questions.

beegerous commented 8 months ago

@nekorobov Thanks for the reply.

Sorry, I missed some information. Before computing the GEMM, I use this function (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp#L196) to get the best tile config, and I make sure it only runs once for a specific [m, n, k].

Is there a difference between cutlass_heuristic and gemmPluginProfiler? From the I/O variables, I think they both exist to get the best CutlassGemmConfig.

By the way, the code here (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L268) is a little confusing: it is the only shape that differs from what the case variable name says. But I changed it to 64 and re-ran those cases, and there was no difference in performance.

nekorobov commented 8 months ago

> I use this function (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp#L196) to get the best tile config, and I make sure it only runs once for a specific [m, n, k].

Do you include this function in your time measurements? It might take a significant amount of time. Which candidate_configs do you provide?

> Is there a difference between cutlass_heuristic and gemmPluginProfiler? From the I/O variables, I think they both exist to get the best CutlassGemmConfig.

Yes. cutlass_heuristic uses estimation heuristics to find the best config without running it; it is a relatively fast estimation method, but it is not always accurate. gemmPluginProfiler simply profiles each GEMM config and chooses the fastest. It has to be executed offline, but it gives more accurate results.
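To illustrate the idea, here is a minimal Python sketch of exhaustive profiling (this is not the gemmPluginProfiler API; `run_gemm` stands for a hypothetical binding that launches the int8 kernel once with a given config id):

```python
import time

import torch

def profile_best_config(run_gemm, config_ids, warmup=5, iters=50):
    """Time every candidate GEMM config on the real problem size, keep the fastest."""
    best_cfg, best_time = None, float("inf")
    for cfg in config_ids:
        for _ in range(warmup):       # warm-up so one-time setup is not measured
            run_gemm(cfg)
        torch.cuda.synchronize()      # kernel launches are asynchronous
        start = time.time()
        for _ in range(iters):
            run_gemm(cfg)
        torch.cuda.synchronize()
        elapsed = time.time() - start
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg
```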

> It is the only shape that differs from what the case variable name says. But I changed it to 64 and re-ran those cases, and there was no difference in performance.

Indeed, this mismatch is a bug. Thank you for reporting it. I believe it does make a difference, just not for the cases you've profiled.

beegerous commented 8 months ago

@nekorobov thanks for your advice. I switched from cutlass_heuristic to the profiler, and the int8 GEMM on A100 became much faster. Here are the latest results (the trt-llm cost does not include the profiler time).

| M | N | K | torch cost (s) | trt-llm cost (s) |
| --- | --- | --- | --- | --- |
| 1024 | 1024 | 1024 | 19.89363145828247 | 20.7686026096344 |
| 2048 | 2048 | 2048 | 8.329726219177246 | 7.037260055541992 |
| 4096 | 4096 | 4096 | 63.426324129104614 | 44.36107349395752 |
| 8192 | 4096 | 4096 | 125.48689246177673 | 96.80497765541077 |
| 4096 | 8192 | 4096 | 126.33650040626526 | 85.49396276473999 |
| 4096 | 4096 | 8192 | 118.56665873527527 | 86.97948837280273 |
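A quick calculation on the numbers above (not part of the original comment) shows the int8 speedup over fp16 after profiling:

```python
# fp16 time / int8 time for the profiler-tuned results above
cases = {
    "4096 x 4096 x 4096": (63.43, 44.36),
    "8192 x 4096 x 4096": (125.49, 96.80),
    "4096 x 8192 x 4096": (126.34, 85.49),
    "4096 x 4096 x 8192": (118.57, 86.98),
}
for shape, (fp16_s, int8_s) in cases.items():
    print(f"{shape}: int8 is {fp16_s / int8_s:.2f}x faster than fp16")
# Roughly 1.3x-1.5x: better than the heuristic-picked configs, but still
# short of the ~2x that the A100 int8/fp16 peak ratio suggests.
```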
foreverlms commented 4 days ago

> @nekorobov thanks for your advice. I switched from cutlass_heuristic to the profiler, and the int8 GEMM on A100 became much faster. Here are the latest results (the trt-llm cost does not include the profiler time).

> (results table quoted above)

Have you mistaken the unit? Is it microseconds (us), not seconds?