FlagOpen / FlagGems

FlagGems is an operator library for large language models, implemented in the Triton language.
Apache License 2.0

[Operator] Add masked_select op [IluvatarCorex] #198

Closed by zfu82 1 month ago

zfu82 commented 1 month ago

PR Category

Operator

Type of Change

New Feature

Description

Add the masked_select operator.
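
For readers unfamiliar with the operator, here is a minimal pure-Python sketch of what masked_select is generally expected to return: the elements whose mask entry is true, flattened in row-major order. The helper name `masked_select_reference` is hypothetical and is not part of this PR; the sketch assumes 2-D nested lists of identical shape and leaves out broadcasting and the Triton kernel entirely.

```python
from itertools import chain


def masked_select_reference(values, mask):
    """Return elements of `values` whose mask entry is truthy, in row-major order.

    Hypothetical reference helper (not the FlagGems implementation).
    Assumes `values` and `mask` are 2-D nested lists of the same shape.
    """
    same_rows = len(values) == len(mask)
    if not same_rows or any(len(v) != len(m) for v, m in zip(values, mask)):
        raise ValueError("values and mask must have the same shape")
    # Flatten both inputs row by row, then keep the values the mask marks.
    flat_values = chain.from_iterable(values)
    flat_mask = chain.from_iterable(mask)
    return [v for v, keep in zip(flat_values, flat_mask) if keep]


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
m = [[True, False, True], [False, False, True]]
selected = masked_select_reference(x, m)
print(selected)  # [1.0, 3.0, 6.0]
```

The output length depends on how many mask entries are true, so the result is always one-dimensional regardless of the input shape.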

Issue

Progress

Performance

benchmark/test_reduction_perf.py::test_masked_select
Operator masked_select Performance Test (torch.float16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.080896            0.229376
6144                  0.223232            0.415744
11264                 0.365568              0.4608
16384                 0.503808            0.515072
21504                 0.646144            0.596992
26624                  0.78336            0.713728
31744                 0.924672            0.830464
36864                  1.06394            0.948224
41984                  1.20218             1.06598
47104                  1.33632             1.18272
52224                  1.47661             1.30048
57344                  1.61485             1.41824
62464                  1.75411             1.53702
67584                  1.88928             1.65376
72704                  2.03264             1.77152
77824                  2.16576             1.88826
Operator masked_select Performance Test (torch.float32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.075776            0.229376
6144                    0.2304            0.410624
11264                  0.37888            0.461824
16384                 0.524288            0.519168
21504                  0.67072            0.632832
26624                 0.817152            0.758784
31744                  0.96256            0.884736
36864                  1.10592             1.01171
41984                  1.24723             1.13766
47104                  1.39059             1.26464
52224                  1.53395             1.39162
57344                  1.68038             1.51757
62464                  1.82784             1.64352
67584                  1.97018             1.77152
72704                  2.11558             1.89747
77824                   2.2528             2.02342
Operator masked_select Performance Test (torch.bfloat16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.075776            0.229376
6144                  0.223232            0.412672
11264                 0.364544            0.459776
16384                 0.505856            0.513024
21504                  0.64512            0.596992
26624                 0.785408            0.713728
31744                 0.923648            0.830464
36864                  1.06291            0.948224
41984                   1.1991             1.06496
47104                  1.33734             1.18272
52224                  1.47456             1.30048
57344                  1.61382             1.41824
62464                  1.75411             1.53702
67584                  1.89542             1.65478
72704                  2.03366             1.77152
77824                  2.16371             1.88826
PASSED
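
To make the trend in the tables easier to read, the following sketch computes the Torch/Gems latency ratio for each row of the torch.float16 table above and finds the first size at which Gems is faster. The rows are copied from that table; the names `FLOAT16_ROWS`, `speedup`, and `crossover` are illustrative only and do not come from the benchmark script.

```python
# (size, torch_ms, gems_ms) rows copied from the torch.float16 table above.
FLOAT16_ROWS = [
    (1024, 0.080896, 0.229376),
    (6144, 0.223232, 0.415744),
    (11264, 0.365568, 0.4608),
    (16384, 0.503808, 0.515072),
    (21504, 0.646144, 0.596992),
    (26624, 0.78336, 0.713728),
    (31744, 0.924672, 0.830464),
    (36864, 1.06394, 0.948224),
    (41984, 1.20218, 1.06598),
    (47104, 1.33632, 1.18272),
    (52224, 1.47661, 1.30048),
    (57344, 1.61485, 1.41824),
    (62464, 1.75411, 1.53702),
    (67584, 1.88928, 1.65376),
    (72704, 2.03264, 1.77152),
    (77824, 2.16576, 1.88826),
]


def speedup(torch_ms, gems_ms):
    """Ratio above 1.0 means Gems is faster than Torch for that row."""
    return torch_ms / gems_ms


speedups = {size: speedup(t, g) for size, t, g in FLOAT16_ROWS}

# First listed size at which the Gems latency drops below the Torch latency.
crossover = next(size for size, t, g in FLOAT16_ROWS if g < t)

for size, ratio in speedups.items():
    print(f"{size:>6}  {ratio:.3f}x")
print("first size where Gems is faster:", crossover)
```

For float16, Gems is slower at the four smallest sizes and faster from size 21504 onward, with the ratio reaching roughly 1.15x at the largest size measured.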