Closed beegerous closed 8 months ago
Hi @beegerous , thank you for reporting the issue. When using https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h one should specify `ThreadblockShape`, `WarpShape`, and `Stages` to configure the GEMM. The choice of these parameters can influence performance significantly. Meanwhile, torch uses runtime heuristics (e.g. in the cuDNN library) to choose the optimal GEMM shape for the given problem size. Thus, I suspect the performance difference you see comes from comparing a suboptimal int8 GEMM against an optimal fp16 GEMM.
In TensorRT-LLM we use a profiler to find the best int8 GEMM configuration for a given problem size: https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp. For example, this profiler is used here; could you please try it out? Let me know if you have more questions.
@nekorobov Thanks for the reply.
Sorry, I missed some information. Before I run the GEMM, I use this function (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp#L196) to get the best tile config, and I make sure it only runs once for a specific [m, n, k].
Is there a difference between cutlass_heuristic and gemmPluginProfiler? Judging from their input/output variables, I think both exist to get the best CutlassGemmConfig.
By the way, the code here (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L268) is a little confusing. It is the only shape that differs from what the case's variable name suggests. But I changed it to 64 and re-ran those cases, and got no difference in performance.
> I use this function (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp#L196) to get the best tile config, and I make sure it only runs once for a specific [m, n, k].
Do you include this function in your time measurements? It might take a significant amount of time. Which `candidate_configs` do you provide?
> Is there a difference between cutlass_heuristic and gemmPluginProfiler? Judging from their input/output variables, I think both exist to get the best CutlassGemmConfig.
Yes: cutlass_heuristic uses estimation heuristics to find the best config without running it. While this estimation method is relatively fast, it is not always accurate. gemmPluginProfiler simply profiles each GEMM config and chooses the fastest one. gemmPluginProfiler has to be executed offline, but it gives more accurate results.
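To illustrate the difference, here is a toy Python sketch of what a profiler-style selection does (this is not TensorRT-LLM's actual API; the config list and the `run_gemm` callable are hypothetical stand-ins): it simply times every candidate and keeps the fastest, instead of estimating cost analytically.

```python
import time

def pick_config_by_profiling(configs, run_gemm, iters=10):
    """Toy analogue of gemmPluginProfiler: time every candidate config
    and keep the fastest. Accurate, but meant to be run offline since
    it executes every candidate several times."""
    best_cfg, best_time = None, float("inf")
    for cfg in configs:
        run_gemm(cfg)  # warm-up run, excluded from the measurement
        start = time.perf_counter()
        for _ in range(iters):
            run_gemm(cfg)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg
```

A heuristic, by contrast, would rank `configs` by an analytic cost model without ever calling `run_gemm`, which is why it can mispredict the true winner.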
> It is the only shape that differs from what the case's variable name suggests. But I changed it to 64 and re-ran those cases, and got no difference in performance.
Indeed, this mismatch is a bug. Thank you for reporting it. I believe it does make a difference, just not in the cases you've profiled.
@nekorobov Thanks for your advice. I changed cutlass_heuristic to the profiler, and the int8 GEMM on A100 became much faster. Here is the latest result. (The trt-llm cost does not include profiler time.)
m | n | k | torch cost (s) | trt-llm cost (s) |
---|---|---|---|---|
1024 | 1024 | 1024 | 19.89363145828247 | 20.7686026096344 |
2048 | 2048 | 2048 | 8.329726219177246 | 7.037260055541992 |
4096 | 4096 | 4096 | 63.426324129104614 | 44.36107349395752 |
8192 | 4096 | 4096 | 125.48689246177673 | 96.80497765541077 |
4096 | 8192 | 4096 | 126.33650040626526 | 85.49396276473999 |
4096 | 4096 | 8192 | 118.56665873527527 | 86.97948837280273 |
Have you mistaken the unit? Should it be µs, not seconds?
I need a Python operator that supports int8 GEMM with per-token/per-channel quantization. So I wrapped the code at https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h into something like https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu#L476
Then I tested the speed of the int8 GEMM against pure fp16 (torch.nn.functional.linear). The measured time covers only the GEMM, not quantization. I repeat the GEMM 1e5 times (1e6 times for the small cases).
The test code is something like:
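Since the actual snippet is not included above, here is a minimal sketch of this kind of timing harness (the `int8_gemm` binding in the usage comment is hypothetical; for real GPU measurements one would also call `torch.cuda.synchronize()` before reading the clock, since CUDA kernel launches are asynchronous):

```python
import time

def bench(fn, repeat):
    """Time `repeat` calls of `fn`, excluding one warm-up call.
    The warm-up excludes lazy initialization and any one-time
    tile-config selection from the measurement."""
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    return time.perf_counter() - start

# usage (hypothetical bindings, GPU sync omitted for brevity):
#   t_fp16 = bench(lambda: torch.nn.functional.linear(x, w), 100_000)
#   t_int8 = bench(lambda: int8_gemm(x_q, w_q, scales), 100_000)
```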
At first I thought the reason was that the input tensors were too small, and indeed once m, n, and k all equal 4096, the int8 GEMM is finally faster than torch. But when I increase just one dim to 8192, the int8 GEMM is slower again. The last three cases should require a similar amount of computation, and the torch results reflect that, but the int8 GEMM cost is quite unstable. And I expected the int8 GEMM to be about 2x faster than fp16 according to https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf .
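The 2x expectation follows from the A100 datasheet's peak dense Tensor Core numbers (312 TFLOPS for FP16 vs. 624 TOPS for INT8), and the last three cases do involve identical operation counts; a quick arithmetic check:

```python
# Peak dense Tensor Core throughput from the A100 datasheet (no sparsity)
FP16_TFLOPS = 312
INT8_TOPS = 624

def gemm_ops(m, n, k):
    # a GEMM performs roughly 2*m*n*k operations (one multiply + one add
    # per element of the inner product)
    return 2 * m * n * k

# ideal speedup if both kernels ran at peak throughput
print(INT8_TOPS / FP16_TFLOPS)  # → 2.0

# the last three benchmark cases have identical operation counts
print(gemm_ops(8192, 4096, 4096) == gemm_ops(4096, 8192, 4096)
      == gemm_ops(4096, 4096, 8192))  # → True
```

In practice neither kernel hits peak, so the realized ratio depends on how close each implementation gets to its own roofline, which is consistent with the instability observed above.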
I checked the [4096, 8192, 4096] case with nsys and didn't find anything weird.
I'm confused by these results and am trying to understand the underlying reason. Or, if it is caused by some compilation error, how would I check for that?
Env: Ubuntu 20.04.6 LTS, NVIDIA A100-SXM4-80GB; base commit: c89653021e66ca78c55f02b366f404455bc12e8d
I build and run the code in a Docker image built from docker/Dockerfile.multi, and I build my lib based on scripts/build_wheel.py.