FlagOpen / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.
Apache License 2.0

Floordiv int #204

Closed tongxin closed 1 month ago

tongxin commented 1 month ago

PR Category

Op

Type of Change

Bug fix and new feature

Description

Added the PyTorch/NumPy-style remainder and floor_divide for integers. The conversion from Triton's truncating integer division to PyTorch's floored semantics is summarized by the following pseudocode:

# Note: in this pseudocode, `//` and `%` denote Triton's C-style truncating
# division and remainder (rounding toward zero), not Python's floored ones.

def floordiv(x, y):
    # The truncated quotient is one too large when the operands' signs
    # differ and the division is inexact; correct it toward -inf.
    if x % y != 0 and (x < 0) ^ (y < 0):
        return x // y - 1
    else:
        return x // y

def _remainder(x, y):
    r = x % y  # truncated remainder, takes the sign of x
    # PyTorch/NumPy remainder takes the sign of y; shift inexact results.
    if r != 0 and (x < 0) ^ (y < 0):
        return r + y
    else:
        return r
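
Below is a minimal sketch of how the same correction can be expressed branchlessly in a Triton kernel via tl.where; the kernel name, signature, and launch code are illustrative assumptions, not the PR's actual implementation.

# Illustrative sketch, not the PR's kernel. Assumes contiguous 1-D int32
# inputs and relies on the PR's premise that Triton's integer // and %
# truncate toward zero.
import torch
import triton
import triton.language as tl

@triton.jit
def floordiv_kernel(X, Y, Out, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask)
    y = tl.load(Y + offs, mask=mask)
    q = x // y  # truncating signed division
    r = x % y   # truncating signed remainder
    # Round toward -inf when the signs differ and the division is inexact.
    q = tl.where((r != 0) & ((x < 0) ^ (y < 0)), q - 1, q)
    tl.store(Out + offs, q, mask=mask)

# Usage sketch: the result should agree with torch.floor_divide elementwise.
x = torch.randint(-100, 100, (4096,), dtype=torch.int32, device="cuda")
y = torch.randint(1, 100, (4096,), dtype=torch.int32, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
floordiv_kernel[grid](x, y, out, x.numel(), BLOCK=1024)
assert torch.equal(out, torch.floor_divide(x, y))

The branchless form matters in practice: a Triton program operates on a whole block of values at once, so the per-element `if` in the pseudocode becomes a masked select rather than control flow.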

Issue

Progress

Performance

Performance for bs=1024 on an A100:

benchmark/test_pointwise_perf.py Operator floor_div Performance Test (torch.int16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.016096            0.014752
6144                  0.042016            0.040128
11264                 0.067136            0.066112
16384                 0.089888            0.089728
21504                 0.114112            0.113856
26624                 0.136992            0.136608
31744                 0.160384            0.160576
36864                 0.184832            0.254016
41984                  0.20832            0.264128
47104                   0.2312            0.273408
52224                 0.254592            0.284448
57344                 0.278144            0.293824
62464                 0.302272            0.303968
67584                 0.325568            0.368288
72704                 0.349024            0.383584
77824                   0.3728            0.397408
Operator floor_div Performance Test (torch.int32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.018592            0.018432
6144                  0.068288            0.067712
11264                 0.113728            0.114752
16384                  0.15904            0.159264
21504                  0.20624              0.2064
26624                 0.252416            0.251648
31744                 0.297088             0.29792
36864                 0.343968            0.344928
41984                  0.38992            0.390944
47104                 0.436352            0.435488
52224                   0.4832            0.481312
57344                  0.52992            0.526112
62464                 0.576384            0.571968
67584                  0.62144            0.617504
72704                 0.667744            0.662816
77824                 0.714848            0.707328

benchmark/test_pointwise_perf.py Operator remainder Performance Test (torch.int16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                   0.01504            0.013792
6144                   0.03936            0.040768
11264                   0.0656            0.066432
16384                  0.08928            0.089376
21504                  0.11216            0.112736
26624                 0.134464            0.136032
31744                 0.158048            0.159904
36864                  0.18032            0.248384
41984                 0.203264            0.259488
47104                 0.226496            0.269952
52224                 0.249952             0.28176
57344                 0.272832            0.291872
62464                 0.295104            0.302176
67584                 0.318432             0.36064
72704                 0.340352             0.37616
77824                 0.364448            0.391904
Operator remainder Performance Test (torch.int32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.018112            0.018304
6144                  0.066912            0.067456
11264                 0.112992            0.113664
16384                 0.158304             0.15904
21504                 0.205536            0.206144
26624                 0.250944            0.251488
31744                 0.296448             0.29776
36864                 0.343392            0.344064
41984                 0.389312            0.390112
47104                 0.436384            0.436096
52224                 0.481824             0.48032
57344                  0.52912             0.52592
62464                  0.57488             0.57184
67584                 0.620736            0.617632
72704                 0.667008            0.662144
77824                 0.714208             0.70736
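
For reference, a hedged sketch of how one latency point might be reproduced outside the harness; the tensor shapes, value ranges, and iteration counts are assumptions, and benchmark/test_pointwise_perf.py remains the authoritative source for the numbers above.

# Hedged reproduction sketch; not the repository's benchmark harness.
import torch
import flag_gems

def bench_ms(fn, iters=100, warmup=10):
    # CUDA-event timing with warmup so Triton compilation is excluded.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Assumed shape: bs=1024 rows by one of the benchmarked sizes.
x = torch.randint(-1000, 1000, (1024, 16384), dtype=torch.int32, device="cuda")
y = torch.randint(1, 1000, (1024, 16384), dtype=torch.int32, device="cuda")

torch_ms = bench_ms(lambda: torch.floor_divide(x, y))
flag_gems.enable()  # route supported aten ops to FlagGems' Triton kernels
gems_ms = bench_ms(lambda: torch.floor_divide(x, y))
print(f"torch: {torch_ms:.6f} ms  gems: {gems_ms:.6f} ms")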