ROCm / MIOpen

AMD's Machine Intelligence Library
https://rocm.docs.amd.com/projects/MIOpen/en/latest/

Implement Cumulative reduction (max, min, sum, prod) forward with small last dim #3297

Open long10024070 opened 1 month ago

long10024070 commented 1 month ago

This PR is a continuation of PR #3182. I accidentally closed the older PR and then pushed changes to the working branch, so I can no longer reopen it. There were not many comments on that PR, so I hope this does not interrupt your review process. Again, sorry for the inconvenience.
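For readers unfamiliar with the operations named in the title, here is a minimal reference sketch of what a forward cumulative reduction along the last dimension computes. It is plain Python for illustration only, not the HIP kernel from this PR; the helper name `cumulative_last_dim` is hypothetical, and it returns values only (frameworks such as PyTorch also return indices for cummax and cummin).

```python
from itertools import accumulate
import operator

# Hypothetical reference helper: maps each op name to a binary function.
_OPS = {"sum": operator.add, "prod": operator.mul, "max": max, "min": min}

def cumulative_last_dim(rows, op):
    """Apply a running reduction along the last dimension of a 2-D list."""
    fn = _OPS[op]
    return [list(accumulate(row, fn)) for row in rows]

x = [[3, 1, 4, 1],
     [2, 7, 1, 8]]

cum_sum = cumulative_last_dim(x, "sum")    # running totals per row
cum_prod = cumulative_last_dim(x, "prod")  # running products per row
cum_max = cumulative_last_dim(x, "max")    # running maxima per row
cum_min = cumulative_last_dim(x, "min")    # running minima per row
```

Each row is scanned independently, which is why shapes with many rows and a small last dimension (as in the benchmarks below) expose a lot of parallelism across rows.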

float16

| op_name | dtype | size | dim | contiguous | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|:-------:|:-------:|:----------------:|:---:|------------|:------:|:---------:|:------------:|:----------:|:-----------:|
| CumMax | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 79103622 | 10290800 | 7.69 |
| CumMax | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 39319091 | 2490330 | 15.79 |
| CumMax | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 78599721 | 4982140 | 15.78 |
| CumMax | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 39227767 | 2479240 | 15.82 |
| CumMax | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 78414528 | 4955720 | 15.82 |
| CumMax | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 39164283 | 2427920 | 16.13 |
| CumMax | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 80268168 | 4854160 | 16.54 |
| CumMax | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 39191305 | 2401980 | 16.32 |
| CumMax | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78271414 | 4805250 | 16.29 |
| CumMax | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 11277661 | 1463220 | 7.71 |
| CumMax | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 156666821 | 10686200 | 14.66 |
| CumMax | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 22540060 | 2920460 | 7.72 |
| CumMin | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 79032894 | 10293300 | 7.68 |
| CumMin | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 39290595 | 2491030 | 15.77 |
| CumMin | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 78578550 | 4982730 | 15.77 |
| CumMin | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 39189412 | 2478940 | 15.81 |
| CumMin | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 78419674 | 4956120 | 15.82 |
| CumMin | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 39156197 | 2426850 | 16.13 |
| CumMin | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 78311994 | 4855330 | 16.13 |
| CumMin | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 39105638 | 2400610 | 16.29 |
| CumMin | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78254683 | 4805610 | 16.28 |
| CumMin | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 11269600 | 1461600 | 7.71 |
| CumMin | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 156521111 | 10696300 | 14.63 |
| CumMin | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 22551889 | 5641600 | 4.00 |
| CumSum | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 36839240 | 6739680 | 5.47 |
| CumSum | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 18283694 | 2321070 | 7.88 |
| CumSum | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 36585132 | 4639960 | 7.88 |
| CumSum | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 18230703 | 2307310 | 7.90 |
| CumSum | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 36477501 | 4612030 | 7.91 |
| CumSum | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 18207967 | 2298780 | 7.92 |
| CumSum | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 36433086 | 4594060 | 7.93 |
| CumSum | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 18215727 | 2291620 | 7.95 |
| CumSum | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 36442782 | 4580620 | 7.96 |
| CumSum | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 5255286 | 956699 | 5.49 |
| CumSum | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 72925500 | 9161610 | 7.96 |
| CumSum | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 10510668 | 1924150 | 5.46 |
| CumProd | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 36853144 | 6734100 | 5.47 |
| CumProd | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 18310781 | 2320740 | 7.89 |
| CumProd | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 36623723 | 4640240 | 7.89 |
| CumProd | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 18271694 | 2309390 | 7.91 |
| CumProd | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 36523629 | 4616960 | 7.91 |
| CumProd | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 18221247 | 2301290 | 7.92 |
| CumProd | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 36498109 | 4601770 | 7.93 |
| CumProd | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 18246623 | 2295060 | 7.95 |
| CumProd | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 39393701 | 4588880 | 8.58 |
| CumProd | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 5260982 | 956966 | 5.50 |
| CumProd | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 73033611 | 10276700 | 7.11 |
| CumProd | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 10518988 | 1910770 | 5.51 |


float32

| op_name | dtype | size | dim | contiguous | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|:-------:|:-------:|:----------------:|:---:|------------|:------:|:---------:|:------------:|:----------:|:-----------:|
| CumMax | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 79353556 | 10510300 | 7.55 |
| CumMax | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 39444502 | 2528340 | 15.60 |
| CumMax | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 78924619 | 5057250 | 15.61 |
| CumMax | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 39394950 | 2517180 | 15.65 |
| CumMax | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 78769181 | 5027420 | 15.67 |
| CumMax | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 39279320 | 2443630 | 16.07 |
| CumMax | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 80072059 | 4885720 | 16.39 |
| CumMax | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 39238905 | 2423280 | 16.19 |
| CumMax | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78490449 | 4814360 | 16.30 |
| CumMax | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 11320257 | 1490630 | 7.59 |
| CumMax | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 157026754 | 9636110 | 16.30 |
| CumMax | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 22649554 | 2982970 | 7.59 |
| CumMin | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 79317382 | 10511900 | 7.55 |
| CumMin | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 39419030 | 2529820 | 15.58 |
| CumMin | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 78850445 | 9170440 | 8.60 |
| CumMin | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 39393495 | 2515360 | 15.66 |
| CumMin | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 78737166 | 5027230 | 15.66 |
| CumMin | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 39270408 | 2443900 | 16.07 |
| CumMin | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 79465092 | 4885980 | 16.26 |
| CumMin | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 39264681 | 2410990 | 16.29 |
| CumMin | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78513042 | 4815270 | 16.31 |
| CumMin | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 11321649 | 1490860 | 7.59 |
| CumMin | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 157041875 | 9633820 | 16.30 |
| CumMin | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 22661778 | 5730720 | 3.95 |
| CumSum | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 37420899 | 7051980 | 5.31 |
| CumSum | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 18553115 | 2330090 | 7.96 |
| CumSum | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 37096775 | 4656530 | 7.97 |
| CumSum | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 18498636 | 2312900 | 8.00 |
| CumSum | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 37008008 | 4623340 | 8.00 |
| CumSum | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 18427773 | 2301890 | 8.01 |
| CumSum | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 36886474 | 4601850 | 8.02 |
| CumSum | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 18399326 | 2293910 | 8.02 |
| CumSum | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 36863786 | 4586130 | 8.04 |
| CumSum | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 5337701 | 998352 | 5.35 |
| CumSum | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 75153089 | 9171500 | 8.19 |
| CumSum | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 10686874 | 1993500 | 5.36 |
| CumProd | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 37492178 | 7043960 | 5.32 |
| CumProd | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 18602251 | 2328130 | 7.99 |
| CumProd | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 37180790 | 4653930 | 7.99 |
| CumProd | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 18552732 | 2312610 | 8.02 |
| CumProd | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 37102295 | 4625170 | 8.02 |
| CumProd | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 18471901 | 2303490 | 8.02 |
| CumProd | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 36980297 | 4605450 | 8.03 |
| CumProd | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 18449117 | 2295490 | 8.04 |
| CumProd | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 36929706 | 4589030 | 8.05 |
| CumProd | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 5350325 | 996876 | 5.37 |
| CumProd | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 73828228 | 9180310 | 8.04 |
| CumProd | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 10692522 | 1992210 | 5.37 |


bfloat16

| op_name | dtype | size | dim | contiguous | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|:-------:|:--------:|:----------------:|:---:|------------|:------:|:---------:|:------------:|:----------:|:-----------:|
| CumMax | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 82001795 | 10583800 | 7.75 |
| CumMax | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 40779253 | 2538400 | 16.06 |
| CumMax | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 81604184 | 5080060 | 16.06 |
| CumMax | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 40773765 | 2517520 | 16.20 |
| CumMax | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 81536393 | 5032490 | 16.20 |
| CumMax | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 40759269 | 2462850 | 16.55 |
| CumMax | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 81497354 | 4925450 | 16.55 |
| CumMax | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 41580409 | 2433980 | 17.08 |
| CumMax | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 81399196 | 4872310 | 16.71 |
| CumMax | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 11702748 | 1502110 | 7.79 |
| CumMax | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 162846550 | 9719290 | 16.75 |
| CumMax | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 23391864 | 2999200 | 7.80 |
| CumMin | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 82027507 | 10510000 | 7.80 |
| CumMin | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 40770069 | 2529640 | 16.12 |
| CumMin | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 81606825 | 5061550 | 16.12 |
| CumMin | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 40762245 | 2513820 | 16.22 |
| CumMin | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 81501883 | 5028670 | 16.21 |
| CumMin | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 40744486 | 2462740 | 16.54 |
| CumMin | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 81500475 | 4921650 | 16.56 |
| CumMin | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 40697830 | 2433980 | 16.72 |
| CumMin | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 81402956 | 4870530 | 16.71 |
| CumMin | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 11700876 | 1492480 | 7.84 |
| CumMin | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 162799578 | 9718710 | 16.75 |
| CumMin | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 23387721 | 2980940 | 7.85 |
| CumSum | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 46814849 | 6889390 | 6.80 |
| CumSum | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 23282362 | 2320860 | 10.03 |
| CumSum | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 46555589 | 6526530 | 7.13 |
| CumSum | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 23230827 | 2307880 | 10.07 |
| CumSum | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 46477910 | 4613740 | 10.07 |
| CumSum | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 23206284 | 2299330 | 10.09 |
| CumSum | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 46414775 | 4595450 | 10.10 |
| CumSum | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 23198811 | 2292080 | 10.12 |
| CumSum | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 46428582 | 4581610 | 10.13 |
| CumSum | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 6686083 | 978317 | 6.83 |
| CumSum | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 92810670 | 9164860 | 10.13 |
| CumSum | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 13372485 | 1954060 | 6.84 |
| CumProd | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 46990639 | 6894460 | 6.82 |
| CumProd | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 23371545 | 2323120 | 10.06 |
| CumProd | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 46773218 | 4644510 | 10.07 |
| CumProd | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 23333338 | 2310600 | 10.10 |
| CumProd | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 46674003 | 4619370 | 10.10 |
| CumProd | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 23306058 | 2302690 | 10.12 |
| CumProd | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 46625844 | 4603780 | 10.13 |
| CumProd | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 23304010 | 2295150 | 10.15 |
| CumProd | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 46605092 | 4588370 | 10.16 |
| CumProd | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 6709842 | 979597 | 6.85 |
| CumProd | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 93215385 | 9178690 | 10.16 |
| CumProd | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 13421396 | 1954820 | 6.87 |

Average over all cases:

| type | average |
|:--------:|:-------:|
| float16 | 10.42 |
| float32 | 10.43 |
| bfloat16 | 11.32 |
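
As a reading aid, the Improvement column in the tables above appears to be the ROCm pytorch figure divided by the MIOpen HIP figure for the same case. The sketch below checks that reading on four rows copied from the float16 table. The helper name `improvement` is hypothetical, the units of the timing columns are not stated in this PR, and the use of an arithmetic mean for the summary is an assumption (the PR does not say how the per-type averages were computed).

```python
from statistics import mean

# (op_name, size, ROCm pytorch, MIOpen HIP), copied from the float16 table.
rows = [
    ("CumMax", "[512 64 112 112]", 79103622, 10290800),
    ("CumMin", "[1024 1024 100]", 22551889, 5641600),
    ("CumSum", "[512 1024 100]", 5255286, 956699),
    ("CumProd", "[512 1024 7 7]", 39393701, 4588880),
]

def improvement(baseline, candidate):
    """Speedup of the candidate over the baseline (baseline / candidate)."""
    return baseline / candidate

# Per-case ratios, rounded to two decimals as in the tables.
ratios = [round(improvement(b, c), 2) for _, _, b, c in rows]

# Assumed summary statistic: arithmetic mean over the selected cases only.
subset_average = mean(ratios)

for (op, size, _, _), r in zip(rows, ratios):
    print(f"{op:8s} {size:18s} {r:.2f}x")
```

The four computed ratios match the Improvement values listed for those rows, which supports the baseline-over-candidate interpretation.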