Did you test on top of the latest PyTorch Triton pin update (https://github.com/pytorch/pytorch/pull/126098)?
@htyu That works! Recording the new benchmarks here for posterity (with `tma_persistent=True`):
| Kernel | Shape (M, N, K) | TFLOPS | Time (ms) |
|---|---|---|---|
| bf16 | (6144, 8192, 8192) | 754.66 | 1.093 |
| fp8 scale + row gemm | (6144, 8192, 8192) | 1102.14 | 0.748 |
| fp8 scale + block gemm | (6144, 8192, 8192) | 631.73 | 1.305 |
| fp8 row gemm only (`fp8_fast_accum=True`) | (6144, 8192, 8192) | 1358.88 | 0.607 |
| fp8 block gemm only | (6144, 8192, 8192) | 926.62 | 0.890 |
| bf16 | (6144, 8192, 64) | 133.31 | 0.048 |
| fp8 scale + row gemm | (6144, 8192, 64) | 36.19 | 0.178 |
| fp8 scale + block gemm | (6144, 8192, 64) | 57.61 | 0.112 |
| fp8 row gemm only (`fp8_fast_accum=True`) | (6144, 8192, 64) | 131.83 | 0.049 |
| fp8 block gemm only | (6144, 8192, 64) | 152.69 | 0.042 |
| bf16 | (6144, 64, 8192) | 124.41 | 0.052 |
| fp8 scale + row gemm | (6144, 64, 8192) | 28.65 | 0.225 |
| fp8 scale + block gemm | (6144, 64, 8192) | 21.41 | 0.301 |
| fp8 row gemm only (`fp8_fast_accum=True`) | (6144, 64, 8192) | 172.87 | 0.037 |
| fp8 block gemm only | (6144, 64, 8192) | 110.30 | 0.058 |
| bf16 | (6144, 10240, 8192) | 674.44 | 1.528 |
| fp8 scale + row gemm | (6144, 10240, 8192) | 1011.41 | 1.019 |
| fp8 scale + block gemm | (6144, 10240, 8192) | 641.79 | 1.606 |
| fp8 row gemm only (`fp8_fast_accum=True`) | (6144, 10240, 8192) | 1450.83 | 0.710 |
| fp8 block gemm only | (6144, 10240, 8192) | 911.98 | 1.130 |
| bf16 | (6144, 8192, 28672) | 665.81 | 4.335 |
| fp8 scale + row gemm | (6144, 8192, 28672) | 1202.57 | 2.400 |
| fp8 scale + block gemm | (6144, 8192, 28672) | 645.28 | 4.473 |
| fp8 row gemm only (`fp8_fast_accum=True`) | (6144, 8192, 28672) | 1443.30 | 2.000 |
| fp8 block gemm only | (6144, 8192, 28672) | 981.57 | 2.940 |
| bf16 | (6144, 57344, 8192) | 629.60 | 9.168 |
| fp8 scale + row gemm | (6144, 57344, 8192) | 1150.80 | 5.016 |
| fp8 scale + block gemm | (6144, 57344, 8192) | 647.45 | 8.916 |
| fp8 row gemm only (`fp8_fast_accum=True`) | (6144, 57344, 8192) | 1345.78 | 4.289 |
| fp8 block gemm only | (6144, 57344, 8192) | 960.10 | 6.012 |
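As a sanity check, the TFLOPS and ms columns are mutually consistent: an (M, N, K) GEMM performs 2·M·N·K floating-point operations. A minimal sketch verifying the first row:

```python
# Sanity check: TFLOPS = 2*M*N*K / (time in seconds) / 1e12.
def gemm_tflops(m: int, n: int, k: int, ms: float) -> float:
    """Achieved TFLOPS for an (M, N, K) GEMM that ran in `ms` milliseconds."""
    return 2 * m * n * k / (ms * 1e-3) / 1e12

# First row above: bf16, shape (6144, 8192, 8192), 1.093 ms.
print(f"{gemm_tflops(6144, 8192, 8192, 1.093):.2f}")  # ~754.47, matching 754.66 up to rounding of the ms column
```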
Updating the Triton pin in FBGEMM OSS accordingly: https://github.com/pytorch/FBGEMM/pull/2775
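For anyone reproducing these numbers, here is a minimal sketch of the usual measurement loop using Triton's `do_bench` helper (assumptions: a CUDA device is available, and the bf16 `torch.matmul` below merely stands in for the FBGEMM FP8 kernels under test, which are not shown here):

```python
import torch
from triton.testing import do_bench

M, N, K = 6144, 8192, 8192
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

# do_bench warms up, launches the kernel repeatedly, and returns a time in ms.
ms = do_bench(lambda: torch.matmul(a, b))
print(f"shape ({M}, {N}, {K}) tflops {2 * M * N * K / (ms * 1e-3) / 1e12:.2f} ms {ms:.3f}")
```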
This PR (https://github.com/pytorch/FBGEMM/pull/2735) appears to make FP8 Triton matmuls much slower.
Previous benchmarks (using the version from June 11th, on an H100):
Benchmarks on `main`: