gau-nernst closed this issue 5 months ago
Thanks for your efforts. I am delighted that our kernel also works well on 4070Ti GPUs, as I never tested it on them myself. The behavior you observed is exactly what we would expect; please refer to Figure 1 in our paper. Once the batch size (BS) grows past a certain point (e.g., 256 for A100), the GEMM becomes compute-bound, and the computational throughput of the Tensor Cores becomes the bottleneck of kernel execution. Our FP6-LLM kernel and cuBLAS both use FP16 Tensor Cores, so the theoretical peak speed of the two kernels at large batch sizes should be the same. We have tried to keep FP6-LLM from being significantly slower than cuBLAS at large batch sizes, but it can never be faster than cuBLAS as long as it also uses FP16 Tensor Cores for the core computation. Note that the breaking point can differ between GPUs. For H100 it would be larger than 256, since the ratio of Tensor Core throughput to DRAM bandwidth is higher than on earlier GPUs. Our kernel is not optimized for H100 yet, but I will do that if I get more spare time.
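To make the compute-bound argument above concrete, here is a minimal roofline-style sketch. It assumes the GEMM time is the larger of the Tensor Core time and the time to stream the weights from DRAM, and it ignores activation traffic, dequantization overhead, and split-K effects. The function names and the example `PEAK_FLOPS` and `DRAM_BW` values are illustrative assumptions (round numbers, not measurements of any GPU or of the FP6-LLM kernel).

```python
def gemm_time_s(m, n, k, bytes_per_weight, peak_flops, dram_bw):
    """Roofline estimate for an (m, k) x (k, n) GEMM.

    Assumes time = max(compute time, weight-load time); activation
    traffic and kernel overheads are ignored.
    """
    compute_s = 2.0 * m * n * k / peak_flops        # one multiply-add = 2 FLOPs
    memory_s = n * k * bytes_per_weight / dram_bw   # each weight read once
    return max(compute_s, memory_s)


def speedup_over_fp16(m, n, k, peak_flops, dram_bw, weight_bits=6):
    """Estimated speedup of a weight-only quantized GEMM over FP16.

    Both kernels are assumed to share the same FP16 Tensor Core peak.
    """
    fp16_s = gemm_time_s(m, n, k, 2.0, peak_flops, dram_bw)
    quant_s = gemm_time_s(m, n, k, weight_bits / 8.0, peak_flops, dram_bw)
    return fp16_s / quant_s


def break_even_batch(peak_flops, dram_bw):
    """Batch size at which the FP16 baseline turns compute-bound.

    Solves 2*m*n*k / peak_flops == 2*n*k / dram_bw for m.
    """
    return peak_flops / dram_bw


# Illustrative, assumed hardware numbers (not taken from the thread).
PEAK_FLOPS = 300e12   # FP16 Tensor Core throughput, FLOP/s
DRAM_BW = 1.5e12      # DRAM bandwidth, bytes/s

if __name__ == "__main__":
    print(break_even_batch(PEAK_FLOPS, DRAM_BW))
    for m in (1, 64, 256, 512):
        print(m, speedup_over_fp16(m, 8192, 8192, PEAK_FLOPS, DRAM_BW))
```

Under these assumptions the estimated speedup stays at 16/6 for small batch sizes, shrinks as the quantized kernel becomes compute-bound, and reaches 1.0 once the FP16 baseline is compute-bound too. A GPU with a higher throughput-to-bandwidth ratio pushes that break-even batch size up, which matches the remark about H100.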
Thank you for your reply. It's good to know that it is indeed a known limitation.
In the case of LLMs, the batch size seen by the matmul is batch size x sequence length. So how would this kernel be faster than FP16 for LLM inference? Specifically, I'm referring to the End2End Inference results (Section 7.3), Figures 12, 13, and 14. Does the "batch size" in Figures 12 and 13 refer to the batch size of the matmul (BxD), or the batch size of the transformer activations (BxLxD)? Because of the sentence "We set the prefill/prompt length of each request to 0.5K, and generate 1.5K tokens", I have the impression that it is the latter (BxLxD).
The shape of activations is (B*L, Hidden) for the prompt/prefill processing phase, but the shape of activation becomes (B*1, Hidden) for the decoding (token generation) phase if KV-Cache is used. Our kernel mainly accelerates the token generation phase of LLM inference. For LLM inference, we only need to execute a prompt processing phase once (all input tokens are processed in parallel), but we must execute a decoding step to generate each output token. Thus, the decoding phase can easily dominate the overall LLM inference execution time. That is the reason why we can get end-to-end speedups.
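The shape argument above can be sketched in a few lines. The snippet lists the number of activation rows (the matmul "batch size") for each forward pass of one batch of requests, assuming a KV cache, assuming every generated token costs one decode step, and reading "0.5K" and "1.5K" as 512 and 1536. The function name `activation_rows` is hypothetical and not part of FP6-LLM.

```python
def activation_rows(batch, prompt_len, gen_tokens):
    """Return the matmul row count for each forward pass with a KV cache.

    The first entry is the prefill pass, where all prompt tokens are
    processed in parallel; every later entry is one decode step, where
    each request contributes a single new token.
    """
    rows = [batch * prompt_len]          # prefill: (B*L, Hidden)
    rows += [batch * 1] * gen_tokens     # decode:  (B*1, Hidden) per step
    return rows


if __name__ == "__main__":
    rows = activation_rows(batch=8, prompt_len=512, gen_tokens=1536)
    print("prefill rows:", rows[0])
    print("decode steps:", len(rows) - 1, "with", rows[1], "rows each")
```

With these assumed sizes there is one large prefill matmul and 1536 small decode matmuls per linear layer, so the small-batch regime where the kernel is faster accounts for nearly all forward passes.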
I see, that makes perfect sense! Thank you again. I will close this issue.
Another small question. Do you plan to add support for bias (A @ W.T + b) in the future? It's a bit inconvenient to launch a separate kernel just to add bias.
Thanks for your valuable suggestion. I will add this feature.
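For reference, the requested semantics `A @ W.T + b` can be written as a small pure-Python sketch in which the bias is added in the same pass that finishes each output element, rather than in a second pass over the output. This only illustrates the math; it is not how the FP6-LLM kernel is implemented, and `linear_with_bias` is a hypothetical name.

```python
def linear_with_bias(a, w, b):
    """Compute a @ w.T + b using nested lists.

    a: (M, K) activations, w: (N, K) weights, b: (N,) bias.
    Returns an (M, N) list of lists.
    """
    out = []
    for a_row in a:
        out_row = []
        for w_row, bias in zip(w, b):
            acc = sum(x * y for x, y in zip(a_row, w_row))  # dot product over K
            out_row.append(acc + bias)  # bias folded into the same pass
        out.append(out_row)
    return out


if __name__ == "__main__":
    a = [[1, 2], [3, 4]]
    w = [[1, 0], [0, 1], [1, 1]]
    b = [10, 20, 30]
    print(linear_with_bias(a, w, b))
```

In PyTorch the same result is what `torch.nn.functional.linear(A, W, b)` computes, which can serve as a correctness reference when a fused bias is added to a custom kernel.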
Hello,
Thank you for the great work. I'm integrating the FP6 linear kernel into torchao (https://github.com/pytorch/ao/pull/223). One thing I have observed is that the kernel is slower than PyTorch's default at large batch sizes. On my 4070Ti Super, the breaking point is at batch size = 256. You can see the detailed benchmark reports in my torchao PR above. @msaroufim also had similar results with H100. I use the splitK values specified in tests/python/run.sh. Have you observed similar results? I believe the kernel was tuned for A100. I don't have access to an A100, so I can't check.
I have also run tests/python/run.sh with the kernel compiled in this repo. Results are below.