FlagOpen / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.
Apache License 2.0

[Operator] Codegen scatter&gather #226

GwokHiujin closed this pull request 1 month ago

GwokHiujin commented 2 months ago

In this submission, we use code generation to handle scatter and gather on input tensors of different ranks. Performance tests show that this change brings both operators close to Torch: scatter stays within roughly 0.85x to 1.03x of Torch across all tested sizes and dtypes, while gather is noisier, ranging from about 0.36x to 1.57x :)
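The PR does not show the generator itself, so the following is only a plain-Python analogue of the idea: emit source code for one function specialized to a fixed rank (and, for simplicity here, a fixed `dim`), with the per-dimension index arithmetic unrolled, then compile and cache it. All names (`gen_gather_source`, `get_gather_kernel`) are hypothetical; FlagGems works in Triton, and this sketch only mirrors `torch.gather` semantics on flat, row-major lists.

```python
# Plain-Python analogue of rank-specialized code generation.
# Hypothetical names; this is NOT FlagGems' generator, which targets Triton.

def gen_gather_source(rank: int, dim: int) -> str:
    """Emit source for a gather specialized to one (rank, dim) pair.

    Tensors are flat, row-major lists; the index tensor is contiguous.
    Semantics follow torch.gather: out[c] = inp[c with c[dim] := idx[c]].
    """
    name = f"gather_r{rank}_d{dim}"
    src = [f"def {name}(inp, inp_strides, idx, idx_shape):",
           "    out = [None] * len(idx)",
           "    for n in range(len(idx)):",
           "        rem = n"]
    # Unrolled decomposition of the linear index into coordinates,
    # last (fastest-varying) dimension first.
    for k in reversed(range(rank)):
        src.append(f"        c{k} = rem % idx_shape[{k}]")
        src.append(f"        rem //= idx_shape[{k}]")
    # Along `dim` the coordinate is read from the index tensor instead.
    terms = []
    for k in range(rank):
        coord = "idx[n]" if k == dim else f"c{k}"
        terms.append(f"{coord} * inp_strides[{k}]")
    src.append("        out[n] = inp[" + " + ".join(terms) + "]")
    src.append("    return out")
    return "\n".join(src)


_kernel_cache = {}


def get_gather_kernel(rank: int, dim: int):
    """Generate, compile, and cache one specialized function per (rank, dim)."""
    key = (rank, dim)
    if key not in _kernel_cache:
        namespace = {}
        exec(gen_gather_source(rank, dim), namespace)
        _kernel_cache[key] = namespace[f"gather_r{rank}_d{dim}"]
    return _kernel_cache[key]


# Same result as torch.gather(torch.tensor([[1, 2], [3, 4]]), 1,
#                             torch.tensor([[0, 0], [1, 0]]))
gather_2d_dim1 = get_gather_kernel(rank=2, dim=1)
print(gather_2d_dim1([1, 2, 3, 4], (2, 1), [0, 0, 1, 0], (2, 2)))  # [1, 1, 4, 3]
```

Generating one straight-line body per rank avoids looping over a runtime-sized shape array inside the hot path, which is the usual motivation for this kind of specialization.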

Some of the performance test results are as follows:

scatter_perf

test_reduction_perf.py::test_perf_scatter 
Operator scatter Performance Test (dtype=torch.float16, mode=cuda)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.017408             0.018432           0.944
6144              0.032768             0.032768             1.0
11264             0.048128               0.0512            0.94
16384             0.062464             0.065536           0.953
21504             0.075776             0.077824           0.974
26624             0.089088             0.090112           0.989
31744             0.103424             0.105472           0.981
36864             0.116736             0.118784           0.983
41984             0.130048             0.131072           0.992
47104              0.14336             0.144384           0.993
52224             0.157696             0.161792           0.975
57344             0.172032             0.175104           0.982
62464             0.186368             0.192512           0.968
67584             0.197632             0.198656           0.995
72704             0.212992             0.216064           0.986
77824             0.224256              0.22528           0.995
Operator scatter Performance Test (dtype=torch.float32, mode=cuda)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.017408              0.02048            0.85
6144                0.0512               0.0512             1.0
11264             0.077824             0.077824             1.0
16384             0.104448             0.105472            0.99
21504              0.13312              0.13312             1.0
26624             0.159744             0.159744             1.0
31744              0.18944             0.192512           0.984
36864             0.214016             0.214016             1.0
41984             0.242688             0.244736           0.992
47104             0.268288             0.269312           0.996
52224              0.29696             0.299008           0.993
57344             0.325632              0.32768           0.994
62464             0.350208             0.349184             1.0
67584             0.376832             0.377856           0.997
72704             0.411648             0.418816           0.983
77824              0.43008             0.427008            1.01
Operator scatter Performance Test (dtype=torch.bfloat16, mode=cuda)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.017408             0.018432           0.944
6144              0.032768             0.033792            0.97
11264             0.048128             0.050176           0.959
16384             0.062464             0.064512           0.968
21504             0.075776             0.078848           0.961
26624             0.090112             0.093184           0.967
31744             0.103424             0.106496           0.971
36864             0.116736             0.119808           0.974
41984             0.129024             0.124928            1.03
47104              0.14336             0.145408           0.986
52224             0.156672              0.15872           0.987
57344             0.172032             0.175104           0.982
62464             0.183296             0.185344           0.989
67584             0.197632              0.19968            0.99
72704             0.211968             0.214016            0.99
77824             0.227328             0.232448           0.978
PASSED

gather_perf

test_reduction_perf.py::test_perf_gather 
Operator gather Performance Test (dtype=torch.float16, mode=cuda)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.011264             0.012288           0.917
6144              0.011264             0.008192            1.38
11264             0.011264             0.009216            1.22
16384             0.013312             0.022528           0.591
21504             0.011264              0.01024             1.1
26624             0.018432             0.034816           0.529
31744             0.031744             0.074752           0.425
36864             0.012288             0.016384            0.75
41984             0.016384             0.031744           0.516
47104             0.013312              0.02048            0.65
52224             0.012288             0.012288             1.0
57344             0.011264             0.013312           0.846
62464              0.02048             0.038912           0.526
67584             0.013312             0.024576           0.542
72704             0.011264             0.009216            1.22
77824              0.01024             0.008192            1.25
Operator gather Performance Test (dtype=torch.float32, mode=cuda)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.011264              0.01024             1.1
6144              0.012288             0.014336           0.857
11264             0.013312             0.017408           0.765
16384             0.011264             0.007168            1.57
21504             0.011264             0.008192            1.38
26624             0.012288             0.013312           0.923
31744             0.018432             0.036864             0.5
36864             0.011264              0.01024             1.1
41984             0.017408             0.026624           0.654
47104             0.017408             0.026624           0.654
52224             0.016384              0.03072           0.533
57344             0.011264             0.007168            1.57
62464             0.014336             0.017408           0.824
67584             0.032768             0.059392           0.552
72704             0.018432             0.027648           0.667
77824             0.021504             0.034816           0.618
Operator gather Performance Test (dtype=torch.bfloat16, mode=cuda)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.011264             0.012288           0.917
6144              0.011264             0.008192            1.38
11264             0.012288              0.01024             1.2
16384             0.013312             0.023552           0.565
21504             0.012288             0.014336           0.857
26624             0.011264             0.007168            1.57
31744             0.013312             0.014336           0.929
36864             0.013312             0.026624             0.5
41984              0.01024             0.008192            1.25
47104             0.013312             0.021504           0.619
52224             0.021504             0.039936           0.538
57344             0.013312             0.016384           0.812
62464             0.011264             0.011264             1.0
67584             0.027648               0.0768            0.36
72704             0.023552              0.04096           0.575
77824             0.014336             0.019456           0.737
PASSED
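In both tables, the "Gems Speedup" column equals Torch latency divided by Gems latency, rounded to three significant figures, so values below 1.0 mean Gems is slower. The small sketch below recomputes a few float16 rows copied from the tables; the helper name `speedup` is illustrative only. Because most gather latencies are around 10 to 30 µs, a difference of about 1 µs already moves the ratio by roughly 10%.

```python
# (size, torch_ms, gems_ms) rows copied from the float16 tables above.
SCATTER_FP16 = [(1024, 0.017408, 0.018432),
                (6144, 0.032768, 0.032768),
                (77824, 0.224256, 0.22528)]
GATHER_FP16 = [(6144, 0.011264, 0.008192),
               (31744, 0.031744, 0.074752),
               (36864, 0.012288, 0.016384)]


def speedup(torch_ms: float, gems_ms: float) -> float:
    """Torch latency / Gems latency; above 1.0 means Gems is faster."""
    return torch_ms / gems_ms


for op, rows in (("scatter", SCATTER_FP16), ("gather", GATHER_FP16)):
    for size, torch_ms, gems_ms in rows:
        print(f"{op:<8}{size:>7}  {speedup(torch_ms, gems_ms):.3f}")
```

The printed values (0.944, 1.000, 0.995 for scatter; 1.375, 0.425, 0.750 for gather) match the tables up to rounding.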

Once this branch is merged, we can build select_scatter and slice_scatter on top of scatter, and index_select and nll_loss on top of gather. These implementations have already been validated on a private development branch, where their performance approaches Torch's, a clear improvement over the previous poor results.
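The PR only names these follow-up operators. One common way to express them, sketched here as a pure-Python 2-D reference rather than FlagGems' actual implementation (all function names are hypothetical), is to expand the 1-D index of index_select into a full gather index, and to write the 1-D source of select_scatter into position `k` along `dim` with a scatter.

```python
def gather2d(inp, dim, index):
    """torch.gather semantics on 2-D nested lists (output has index's shape)."""
    out = []
    for i, row in enumerate(index):
        if dim == 0:
            out.append([inp[k][j] for j, k in enumerate(row)])  # inp[index[i][j]][j]
        else:
            out.append([inp[i][k] for k in row])                # inp[i][index[i][j]]
    return out


def scatter2d(inp, dim, index, src):
    """Out-of-place torch.scatter semantics on 2-D nested lists."""
    out = [row[:] for row in inp]
    for i, row in enumerate(index):
        for j, k in enumerate(row):
            if dim == 0:
                out[k][j] = src[i][j]
            else:
                out[i][k] = src[i][j]
    return out


def index_select2d(inp, dim, idx):
    """index_select via gather: broadcast the 1-D index to the output shape."""
    if dim == 0:
        index = [[k] * len(inp[0]) for k in idx]      # shape (len(idx), cols)
    else:
        index = [list(idx) for _ in range(len(inp))]  # shape (rows, len(idx))
    return gather2d(inp, dim, index)


def select_scatter2d(inp, src, dim, k):
    """select_scatter via scatter: embed 1-D src at position k along dim."""
    if dim == 0:
        return scatter2d(inp, 0, [[k] * len(src)], [list(src)])
    return scatter2d(inp, 1, [[k] for _ in src], [[v] for v in src])


inp = [[1, 2, 3], [4, 5, 6]]
print(index_select2d(inp, 1, [2, 0]))                  # [[3, 1], [6, 4]]
print(select_scatter2d([[0] * 3, [0] * 3], [7, 8, 9], 0, 1))  # [[0, 0, 0], [7, 8, 9]]
```

nll_loss follows the same gather pattern: pick each sample's log-probability at its target class along dim 1, negate, and reduce.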

Any advice is welcome.

tongxin commented 2 months ago

Quoting the gather (float16) result: "36864 0.012288 0.016384 0.75"

Thanks, Xiaoyan! That was an extraordinary improvement.