pytorch / FBGEMM

FB (Facebook) + GEMM (General Matrix-Matrix Multiplication) - https://code.fb.com/ml-applications/fbgemm/

Regression: Persistent kernels make Triton FP8 matmul much slower #2766

Closed: rosario-purple closed this issue 3 weeks ago

rosario-purple commented 3 weeks ago

This PR, https://github.com/pytorch/FBGEMM/pull/2735, appears to make FP8 Triton matmuls much slower.
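
For context on the title: a persistent GEMM kernel typically launches about one program per streaming multiprocessor (SM) and has each program loop over several output tiles, rather than launching one program per output tile. The pure-Python model below illustrates only that scheduling idea. The 128×128 tile size, the 132-SM count (H100 SXM), and reading each benchmark shape as (M, N, K) are assumptions, and `persistent_schedule` is a hypothetical helper, not code from the PR.

```python
import math


def persistent_schedule(m, n, block_m=128, block_n=128, num_sms=132):
    """Hypothetical model of a persistent GEMM launch over an M x N output.

    A non-persistent kernel would launch one program per output tile. Here at
    most `num_sms` programs are launched, and program `pid` walks the tiles
    pid, pid + num_programs, pid + 2 * num_programs, ... until all are covered.
    """
    num_tiles = math.ceil(m / block_m) * math.ceil(n / block_n)
    num_programs = min(num_sms, num_tiles)
    return {pid: list(range(pid, num_tiles, num_programs)) for pid in range(num_programs)}


for m, n in [(6144, 8192), (6144, 64)]:
    sched = persistent_schedule(m, n)
    num_tiles = sum(len(tiles) for tiles in sched.values())
    busiest = max(len(tiles) for tiles in sched.values())
    print(f"M={m} N={n}: {num_tiles} tiles on {len(sched)} programs, max tiles per program: {busiest}")
```

Under these assumptions, a large output such as 6144×8192 is spread across all 132 programs, while a narrow output such as 6144×64 has fewer tiles than SMs, so the persistent loop degenerates to one tile per program.
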

Previous benchmarks (using the version from June 11th, on an H100):

bf16:                                           shape (6144, 8192, 8192)        tflops 645.72   ms 1.277
fp8 scale + row gemm:                           shape (6144, 8192, 8192)        tflops 1012.03  ms 0.815
fp8 scale + block gemm:                         shape (6144, 8192, 8192)        tflops 593.45   ms 1.390
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 8192)        tflops 1216.34  ms 0.678
fp8 block gemm only:                            shape (6144, 8192, 8192)        tflops 941.71   ms 0.876
bf16:                                           shape (6144, 8192, 64)          tflops 146.90   ms 0.044
fp8 scale + row gemm:                           shape (6144, 8192, 64)          tflops 47.17    ms 0.137
fp8 scale + block gemm:                         shape (6144, 8192, 64)          tflops 63.33    ms 0.102
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 64)          tflops 53.19    ms 0.121
fp8 block gemm only:                            shape (6144, 8192, 64)          tflops 137.41   ms 0.047
bf16:                                           shape (6144, 64, 8192)          tflops 166.93   ms 0.039
fp8 scale + row gemm:                           shape (6144, 64, 8192)          tflops 39.50    ms 0.163
fp8 scale + block gemm:                         shape (6144, 64, 8192)          tflops 14.28    ms 0.451
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 64, 8192)          tflops 67.76    ms 0.095
fp8 block gemm only:                            shape (6144, 64, 8192)          tflops 35.58    ms 0.181
bf16:                                           shape (6144, 10240, 8192)       tflops 641.48   ms 1.607
fp8 scale + row gemm:                           shape (6144, 10240, 8192)       tflops 1026.18  ms 1.004
fp8 scale + block gemm:                         shape (6144, 10240, 8192)       tflops 608.21   ms 1.695
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 10240, 8192)       tflops 1218.73  ms 0.846
fp8 block gemm only:                            shape (6144, 10240, 8192)       tflops 903.40   ms 1.141
bf16:                                           shape (6144, 8192, 28672)       tflops 671.06   ms 4.301
fp8 scale + row gemm:                           shape (6144, 8192, 28672)       tflops 1015.83  ms 2.841
fp8 scale + block gemm:                         shape (6144, 8192, 28672)       tflops 599.41   ms 4.815
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 28672)       tflops 1280.55  ms 2.254
fp8 block gemm only:                            shape (6144, 8192, 28672)       tflops 924.42   ms 3.122
bf16:                                           shape (6144, 57344, 8192)       tflops 632.31   ms 9.129
fp8 scale + row gemm:                           shape (6144, 57344, 8192)       tflops 1078.60  ms 5.352
fp8 scale + block gemm:                         shape (6144, 57344, 8192)       tflops 682.06   ms 8.463
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 57344, 8192)       tflops 1224.64  ms 4.714
fp8 block gemm only:                            shape (6144, 57344, 8192)       tflops 907.32   ms 6.362
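
The tflops column is consistent with the usual 2·M·N·K FLOP count for a GEMM (one multiply and one add per multiply-accumulate) divided by the measured time. The product does not depend on the order of the shape tuple. The sketch below is a quick check against two bf16 rows from the table above; `gemm_tflops` is a hypothetical helper, not part of the FBGEMM benchmark script.

```python
def gemm_tflops(m: int, n: int, k: int, ms: float) -> float:
    """Throughput of an (M x K) @ (K x N) GEMM: 2*M*N*K FLOPs over `ms` milliseconds."""
    flops = 2 * m * n * k
    return flops / (ms * 1e-3) / 1e12


# (shape, reported ms, reported tflops) for two bf16 rows of the table above
rows = [((6144, 8192, 8192), 1.277, 645.72), ((6144, 8192, 64), 0.044, 146.90)]
for shape, ms, reported in rows:
    print(shape, round(gemm_tflops(*shape, ms), 2), "reported", reported)
```

Because ms is printed to only three decimals, the small K = 64 shapes match the reported TFLOPS to within a fraction of a percent rather than exactly.
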

Benchmarks on main:

bf16:                                           shape (6144, 8192, 8192)        tflops 647.22   ms 1.274
fp8 scale + row gemm:                           shape (6144, 8192, 8192)        tflops 762.71   ms 1.081
fp8 scale + block gemm:                         shape (6144, 8192, 8192)        tflops 592.08   ms 1.393
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 8192)        tflops 881.00   ms 0.936
fp8 block gemm only:                            shape (6144, 8192, 8192)        tflops 906.05   ms 0.910
bf16:                                           shape (6144, 8192, 64)          tflops 146.66   ms 0.044
fp8 scale + row gemm:                           shape (6144, 8192, 64)          tflops 84.85    ms 0.076
fp8 scale + block gemm:                         shape (6144, 8192, 64)          tflops 70.36    ms 0.092
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 64)          tflops 106.64   ms 0.060
fp8 block gemm only:                            shape (6144, 8192, 64)          tflops 167.02   ms 0.039
bf16:                                           shape (6144, 64, 8192)          tflops 166.99   ms 0.039
fp8 scale + row gemm:                           shape (6144, 64, 8192)          tflops 28.07    ms 0.230
fp8 scale + block gemm:                         shape (6144, 64, 8192)          tflops 16.17    ms 0.398
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 64, 8192)          tflops 40.56    ms 0.159
fp8 block gemm only:                            shape (6144, 64, 8192)          tflops 49.73    ms 0.130
bf16:                                           shape (6144, 10240, 8192)       tflops 640.34   ms 1.610
fp8 scale + row gemm:                           shape (6144, 10240, 8192)       tflops 767.56   ms 1.343
fp8 scale + block gemm:                         shape (6144, 10240, 8192)       tflops 609.28   ms 1.692
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 10240, 8192)       tflops 883.67   ms 1.166
fp8 block gemm only:                            shape (6144, 10240, 8192)       tflops 902.12   ms 1.143
bf16:                                           shape (6144, 8192, 28672)       tflops 671.83   ms 4.296
fp8 scale + row gemm:                           shape (6144, 8192, 28672)       tflops 747.93   ms 3.859
fp8 scale + block gemm:                         shape (6144, 8192, 28672)       tflops 601.42   ms 4.799
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 28672)       tflops 894.26   ms 3.228
fp8 block gemm only:                            shape (6144, 8192, 28672)       tflops 927.60   ms 3.111
bf16:                                           shape (6144, 57344, 8192)       tflops 637.50   ms 9.055
fp8 scale + row gemm:                           shape (6144, 57344, 8192)       tflops 807.32   ms 7.150
fp8 scale + block gemm:                         shape (6144, 57344, 8192)       tflops 689.59   ms 8.371
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 57344, 8192)       tflops 884.92   ms 6.523
fp8 block gemm only:                            shape (6144, 57344, 8192)       tflops 920.35   ms 6.272
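
To quantify "much slower", one can parse both logs and compute the per-row change in TFLOPS. The sketch below assumes the exact line format printed above. `parse_benchmarks` and `percent_change` are hypothetical helpers, not part of FBGEMM's benchmark script, and only two rows per log are included for brevity.

```python
import re

# Matches lines such as:
#   "fp8 block gemm only:   shape (6144, 8192, 8192)   tflops 941.71   ms 0.876"
LINE_RE = re.compile(
    r"^(?P<name>.+?):\s+shape \((?P<shape>[\d, ]+)\)\s+tflops (?P<tflops>[\d.]+)\s+ms (?P<ms>[\d.]+)$"
)


def parse_benchmarks(log: str) -> dict:
    """Map (benchmark name, shape tuple) -> reported TFLOPS."""
    results = {}
    for line in log.splitlines():
        match = LINE_RE.match(line.strip())
        if match:
            shape = tuple(int(dim) for dim in match["shape"].split(","))
            results[(match["name"], shape)] = float(match["tflops"])
    return results


def percent_change(before: dict, after: dict) -> dict:
    """Relative TFLOPS change (in %) for every benchmark present in both logs."""
    return {key: 100.0 * (after[key] / before[key] - 1.0) for key in before.keys() & after.keys()}


BEFORE = """
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 8192)        tflops 1216.34  ms 0.678
fp8 block gemm only:                            shape (6144, 8192, 8192)        tflops 941.71   ms 0.876
"""
AFTER = """
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 8192)        tflops 881.00   ms 0.936
fp8 block gemm only:                            shape (6144, 8192, 8192)        tflops 906.05   ms 0.910
"""

changes = percent_change(parse_benchmarks(BEFORE), parse_benchmarks(AFTER))
for (name, shape), pct in sorted(changes.items()):
    print(f"{name} {shape}: {pct:+.1f}%")
```

On these rows the row-wise FP8 GEMM drops by roughly 28% while the block-wise GEMM changes by under 4%. This matches the other large shapes in the two tables, where the row-wise variants lose roughly 25 to 30% and the block-wise ones stay roughly flat.
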

htyu commented 3 weeks ago

Did you test on top of the latest PyTorch Triton pin update (https://github.com/pytorch/pytorch/pull/126098)?

rosario-purple commented 3 weeks ago

@htyu That works! Recording the new benchmarks here for posterity (with tma_persistent=True):

bf16:                                           shape (6144, 8192, 8192)        tflops 754.66   ms 1.093
fp8 scale + row gemm:                           shape (6144, 8192, 8192)        tflops 1102.14  ms 0.748
fp8 scale + block gemm:                         shape (6144, 8192, 8192)        tflops 631.73   ms 1.305
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 8192)        tflops 1358.88  ms 0.607
fp8 block gemm only:                            shape (6144, 8192, 8192)        tflops 926.62   ms 0.890
bf16:                                           shape (6144, 8192, 64)          tflops 133.31   ms 0.048
fp8 scale + row gemm:                           shape (6144, 8192, 64)          tflops 36.19    ms 0.178
fp8 scale + block gemm:                         shape (6144, 8192, 64)          tflops 57.61    ms 0.112
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 64)          tflops 131.83   ms 0.049
fp8 block gemm only:                            shape (6144, 8192, 64)          tflops 152.69   ms 0.042
bf16:                                           shape (6144, 64, 8192)          tflops 124.41   ms 0.052
fp8 scale + row gemm:                           shape (6144, 64, 8192)          tflops 28.65    ms 0.225
fp8 scale + block gemm:                         shape (6144, 64, 8192)          tflops 21.41    ms 0.301
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 64, 8192)          tflops 172.87   ms 0.037
fp8 block gemm only:                            shape (6144, 64, 8192)          tflops 110.30   ms 0.058
bf16:                                           shape (6144, 10240, 8192)       tflops 674.44   ms 1.528
fp8 scale + row gemm:                           shape (6144, 10240, 8192)       tflops 1011.41  ms 1.019
fp8 scale + block gemm:                         shape (6144, 10240, 8192)       tflops 641.79   ms 1.606
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 10240, 8192)       tflops 1450.83  ms 0.710
fp8 block gemm only:                            shape (6144, 10240, 8192)       tflops 911.98   ms 1.130
bf16:                                           shape (6144, 8192, 28672)       tflops 665.81   ms 4.335
fp8 scale + row gemm:                           shape (6144, 8192, 28672)       tflops 1202.57  ms 2.400
fp8 scale + block gemm:                         shape (6144, 8192, 28672)       tflops 645.28   ms 4.473
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 8192, 28672)       tflops 1443.30  ms 2.000
fp8 block gemm only:                            shape (6144, 8192, 28672)       tflops 981.57   ms 2.940
bf16:                                           shape (6144, 57344, 8192)       tflops 629.60   ms 9.168
fp8 scale + row gemm:                           shape (6144, 57344, 8192)       tflops 1150.80  ms 5.016
fp8 scale + block gemm:                         shape (6144, 57344, 8192)       tflops 647.45   ms 8.916
fp8 row gemm only | fp8_fast_accum=True:        shape (6144, 57344, 8192)       tflops 1345.78  ms 4.289
fp8 block gemm only:                            shape (6144, 57344, 8192)       tflops 960.10   ms 6.012
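
To summarize how these numbers compare with the June 11th baseline, one option is a geometric mean of per-shape TFLOPS ratios. The sketch below does this for the "fp8 row gemm only | fp8_fast_accum=True" rows of the four large shapes, with values copied from the two tables. `geomean_speedup` is a hypothetical helper, and the comparison ignores any other differences between the two runs (for example, the new run uses tma_persistent=True).

```python
from statistics import geometric_mean

# "fp8 row gemm only | fp8_fast_accum=True" TFLOPS for the four large shapes,
# copied from the June 11th table and from the table above (updated Triton pin).
baseline = {(6144, 8192, 8192): 1216.34, (6144, 10240, 8192): 1218.73,
            (6144, 8192, 28672): 1280.55, (6144, 57344, 8192): 1224.64}
updated = {(6144, 8192, 8192): 1358.88, (6144, 10240, 8192): 1450.83,
           (6144, 8192, 28672): 1443.30, (6144, 57344, 8192): 1345.78}


def geomean_speedup(before: dict, after: dict) -> float:
    """Geometric mean of after/before TFLOPS ratios over shapes present in `before`."""
    return geometric_mean(after[shape] / before[shape] for shape in before)


print(f"geomean speedup vs. June 11th: {geomean_speedup(baseline, updated):.3f}x")
```

For these four rows the result is about 1.13×, so with the updated pin this kernel is faster than the June 11th baseline rather than merely back to it. A geometric mean is used because it treats a 2× speedup and a 2× slowdown symmetrically.
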

q10 commented 3 weeks ago

Updating the pinned Triton version in FBGEMM OSS accordingly: https://github.com/pytorch/FBGEMM/pull/2775