Open sryap opened 11 months ago
Thanks for sharing the analysis. The performance seems lower than what I had measured on my H100, as I was able to get over 1.3 PF, but we haven't enabled perf regression testing on H100, so maybe we did regress. What commit did you use?
I realize now that fp8_fast_accum=False is not implemented efficiently. It is meant to be equivalent to CUBLASLT_MATMUL_DESC_FAST_ACCUM, but since the implementation of that knob is not public, we don't know whether the two are equivalent.
I'll work on adding fp8 perf tests in the future, and I can confirm the perf I'm getting at that time.
@ThomasRaoux Thank you for your response.
I was able to get over 1.3 PF, but we haven't enabled perf regression testing on H100, so maybe we did regress.
For the numbers that we shared, we limited the power to 500W. Perhaps your H100 was running at 700W? At 700W, we also see 1.3 PF. Triton performs about 10% worse than CUBLAS in this case, so it looks like we are getting the expected performance?
I realize now that fp8_fast_accum=False is not implemented efficiently. It is meant to be equivalent to CUBLASLT_MATMUL_DESC_FAST_ACCUM, but since the implementation of that knob is not public, we don't know whether the two are equivalent.
You meant fp8_fast_accum=False is meant to be equivalent to CUBLASLT_MATMUL_DESC_FAST_ACCUM=0?
What commit did you use?
I used 768fc1fcd98e
For the numbers that we shared, we limited the power to 500W. Perhaps your H100 was running at 700W? At 700W, we also see 1.3 PF. Triton performs about 10% worse than CUBLAS in this case, so it looks like we are getting the expected performance?
Ah ok, then you are most likely getting the expected performance. Note that we haven't enabled all the performance features (like TMA, warp specialization, etc.), so it is possible that we are slower than CUBLAS because of that.
You meant fp8_fast_accum=False is meant to be equivalent to CUBLASLT_MATMUL_DESC_FAST_ACCUM=0?
right
@ThomasRaoux Thank you! Do you plan to enable all the performance features in the future?
We also ran BF16 with matrix B transposed and not transposed and observed about a 10% difference in performance. Is this expected? (For cublas, the performance is the same even when matrix B is transposed.)
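As a side note, a minimal sketch of how that layout comparison could be reproduced, assuming the triton.ops.matmul op used elsewhere in this thread (shapes are illustrative; the op may not exist in newer Triton releases):

```python
import torch
from triton.ops import matmul as triton_matmul  # assumption: the op the benchmark uses
from triton.testing import do_bench

M = N = K = 8192  # illustrative shapes
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b_nn = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)        # B not transposed (row-major)
b_nt = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t()    # transposed view of B

for name, b in [("B not transposed", b_nn), ("B transposed", b_nt)]:
    ms = do_bench(lambda: triton_matmul(a, b))
    print(f"{name}: {2 * M * N * K / (ms * 1e-3) / 1e12:.0f} TFLOP/s")
```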
Hey @ThomasRaoux, I was profiling the performance of triton fp8 gemm as well and came across this issue. I'm still observing the same performance degradation as above when using fp8_fast_accum=False. The max performance is around 280 TFLOPs on an H100 without fp8 fast accum, and around 1300 TFLOPs on an H100 with fp8 fast accum enabled. Other functions that use CUBLAS, like _scaled_mm and TransformerEngine's te_gemm, see performance >1000 TFLOPs with these matrix sizes, and without fp8 fast accum. What could be the cause of the discrepancy here?
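For reference, the cuBLAS-side comparison point here could look roughly like the sketch below, using PyTorch's private torch._scaled_mm; its signature has changed across releases (older versions return an (out, amax) tuple and take the scales as keyword-only arguments), so the call may need adjusting:

```python
import torch

M, N, K = 8192, 8192, 8192  # illustrative shapes
# _scaled_mm wants a row-major fp8 A and a column-major fp8 B (the NT layout).
a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn).t()
scale = torch.tensor(1.0, device="cuda")

# use_fast_accum=False requests the accumulation path that periodically promotes
# intermediate results to higher precision (the slower but more accurate mode).
out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale,
                       out_dtype=torch.bfloat16, use_fast_accum=False)
```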
Do you know what kind of accumulation is done for the cases you are comparing? Is it accumulating everything in fp32, or is it doing one accumulation in fp32 every given number of additions? You'll want to make sure to do apples-to-apples comparisons.
The perf with fast accumulation off is indeed slower than expected. When doing one fp32 accumulation per K additions I had measured perf close to 1000 TF; the code in our current matmul op is suboptimal as it does the last accumulation in a separate op. Unfortunately I don't have time to fix this. It should probably be easy to tweak the kernel a bit to get better performance by setting max_num_imprecise_acc to get the right precision.
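To make the suggested tweak concrete, here is a minimal, hypothetical FP8 matmul kernel sketch (it is not the library's matmul op; the block sizes, the value 32, the fp16 output, and the absence of boundary masks are all simplifying assumptions) in which tl.dot receives max_num_imprecise_acc so the promotion into the fp32 accumulator happens inside the K loop rather than in a separate op at the end:

```python
import triton
import triton.language as tl

@triton.jit
def fp8_matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                      stride_am, stride_ak, stride_bk, stride_bn,
                      stride_cm, stride_cn,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # One program computes one BLOCK_M x BLOCK_N tile of C; M, N, K are assumed
    # to be multiples of the block sizes, so no boundary masks are used.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # Promote into the fp32 accumulator after at most 32 low-precision
        # additions, instead of accumulating the whole K range imprecisely and
        # fixing it up afterwards.
        acc = tl.dot(a, b, acc, max_num_imprecise_acc=32)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16))
```

Lowering max_num_imprecise_acc should trade throughput for accuracy (a larger value moves back toward the fast-accum behaviour); the launcher, strides, and grid setup are omitted for brevity.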
Great, thanks. I'll try to take a look. As for the comparisons, at least with te_gemm, the accumulation types were the same. I'm a bit new to triton -- it would be great if you could point me to where I can modify max_num_imprecise_acc and the separate op for the last accum. Thanks!
Great, thanks. I'll try to take a look. As for the comparisons, at least with te_gemm, the accumulation types were the same.
Same as what? The problem is that the internal accumulation precision for fp8 -> fp32 tensor core is lower than 32 bits.
I'm a bit new to triton -- it would be great if you could point me to where I can modify max_num_imprecise_acc and the separate op for the last accum. Thanks!
Here is an example: https://github.com/openai/triton/blob/main/python/test/unit/language/test_core.py#L3295
Sorry, I meant that the accumulation precision was the same for both te_gemm and for the triton matmul kernel in my benchmarking.
Will take a look at that example, thanks!
We had some earlier discussions around FP8 accuracy. @htyu pointed out a cublas doc that may explain why the accuracy is bad for fp8_fast_accum=True (https://docs.nvidia.com/cuda/cublas/): "CUBLASLT_MATMUL_DESC_FAST_ACCUM: Flag for managing FP8 fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision."
There is also an input to tl.dot that controls how many fp8 accumulations can be imprecise (max_num_imprecise_acc) https://triton-lang.org/main/python-api/generated/triton.language.dot.html#triton.language.dot
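To see the accuracy side of that trade-off concretely, one rough check (assuming the triton.ops.matmul op exposes an fp8_fast_accum keyword; the shape is illustrative) is to compare both settings against an fp32 reference computed on the already-quantized inputs, so only the accumulation behaviour differs:

```python
import torch
from triton.ops import matmul as triton_matmul  # assumption: op exposing fp8_fast_accum

M = N = K = 4096  # illustrative shape
a8 = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b8 = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn).t()  # NT layout

# fp32 reference on the already-quantized fp8 values: any difference below comes
# from the accumulation path, not from the fp8 quantization itself.
ref = torch.matmul(a8.to(torch.float32), b8.to(torch.float32))

for fast_accum in (True, False):
    out = triton_matmul(a8, b8, fp8_fast_accum=fast_accum).to(torch.float32)
    rel_err = ((out - ref).abs().max() / ref.abs().max()).item()
    print(f"fp8_fast_accum={fast_accum}: max relative error {rel_err:.3e}")
```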
https://github.com/triton-lang/triton/pull/3973 should improve performance for fp8_fast_accum=False.
Hello, we have measured the FP8 GEMM performance using Triton on NVIDIA H100 (500 W, 1980 MHz). We would like to request your help in understanding whether the performance is expected.
Since H100 FP8 only supports NT GEMM (matrix B is transposed), we tested the performance with matrix B transposed and not transposed. We also tested with fp8_fast_accum=True and fp8_fast_accum=False.
Results overview: fp8_fast_accum=True has a significant positive impact on the performance.
Our benchmark: fp8_matmul.py.zip
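For completeness, a sketch of what toggling this knob in such a benchmark could look like, assuming the triton.ops.matmul op exposes an fp8_fast_accum keyword (the shapes are illustrative and may differ from fp8_matmul.py; the op may not exist in newer Triton releases):

```python
import torch
from triton.ops import matmul as triton_matmul  # assumption: op exposing fp8_fast_accum
from triton.testing import do_bench

# Illustrative square shapes; e4m3 inputs with B as a transposed (column-major) view, i.e. NT GEMM.
M = N = K = 8192
a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn).t()

for fast_accum in (True, False):
    ms = do_bench(lambda: triton_matmul(a, b, fp8_fast_accum=fast_accum))
    tflops = 2 * M * N * K / (ms * 1e-3) / 1e12
    print(f"fp8_fast_accum={fast_accum}: {tflops:.0f} TFLOP/s")
```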
Questions: What does fp8_fast_accum=True do? Does it set CUBLASLT_MATMUL_DESC_FAST_ACCUM=1 and/or do something differently?
Thank you very much!