danthe3rd / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

WIP matmul cutlass #1

Open danthe3rd opened 2 years ago

danthe3rd commented 2 years ago

Just the matmul:

[------------------------------------------------------------------------------------------------------------------------- attention -------------------------------------------------------------------------------------------------------------------------]
                          |  cutlass_dotp  |  cutlass_dotp_32x64x8  |  cutlass_dotp_32x128x8  |  cutlass_dotp_64x64x8  |  cutlass_dotp_16x16x16  |  cutlass_dotp_32x32x32  |  cutlass_dotp_32x64x8_32x32x8  |  cutlass_dotp_64x64x8_32x32x8  |  vanilla_matmull
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      B=8, M=128, K=32    |       32.1     |          24.2          |           31.9          |           56.2         |           23.7          |            34.7         |              22.1              |              23.1              |        18.3     
      B=8, M=128, K=64    |       34.5     |          24.4          |           53.7          |          100.2         |           33.0          |            54.7         |              22.0              |              22.4              |        18.5     
      B=8, M=128, K=128   |       61.1     |          35.9          |           98.3          |          188.4         |           57.9          |            94.7         |              22.4              |              26.3              |        18.6     
      B=8, M=1024, K=32   |      264.2     |         302.5          |          346.7          |          491.6         |          262.6          |           525.6         |             278.0              |             293.9              |       100.6     
      B=8, M=1024, K=64   |      440.7     |         381.6          |          509.8          |          750.1         |          411.2          |           758.7         |             339.1              |             365.1              |       161.8     
      B=8, M=1024, K=128  |     1036.9     |         560.7          |          769.2          |         1290.3         |          807.7          |          1254.3         |             496.6              |             520.8              |       286.8 

[------------------------------------------------------------ attention ------------------------------------------------------------]
                          |  cutlass_dotp_32x64x8_32x32x8  |  cutlass_32x64x8_16x16x8  |  cutlass_64x64x8_32x64x8  |  vanilla_matmull
1 threads: --------------------------------------------------------------------------------------------------------------------------
      B=8, M=128, K=32    |              22.1              |            23.6           |            24.0           |        18.2     
      B=8, M=128, K=64    |              22.0              |            23.1           |            23.9           |        18.1     
      B=8, M=128, K=128   |              22.4              |            22.9           |            33.4           |        18.6     
      B=8, M=1024, K=32   |             278.0              |           261.4           |           303.0           |       100.5     
      B=8, M=1024, K=64   |             339.1              |           360.4           |           375.9           |       160.7     
      B=8, M=1024, K=128  |             496.6              |           588.2           |           524.3           |       286.5     

[---------------------------------------------- attention -----------------------------------------------]
                          |  cutlass_dotp_32x64x8_32x32x8  |  cutlass_16x128x8_16x64x8  |  vanilla_matmull
1 threads: -----------------------------------------------------------------------------------------------
      B=8, M=128, K=32    |              22.1              |            32.4            |        28.9     
      B=8, M=128, K=64    |              22.0              |            32.4            |        29.5     
      B=8, M=128, K=128   |              22.4              |            32.5            |        28.7     
      B=8, M=1024, K=32   |             278.0              |           315.5            |       100.6     
      B=8, M=1024, K=64   |             339.1              |           403.8            |       160.7     
      B=8, M=1024, K=128  |             496.6              |           586.8            |       286.5 
danthe3rd commented 2 years ago

85dd034

[--------------------------------------------------------------- attention ---------------------------------------------------------------]
                          |  cutlass_16x64x8  |  cutlass_32x128x8  |  cutlass_16x128x8  |  attention2  |  cutlass_16x128x8_v100  |  vanilla
1 threads: --------------------------------------------------------------------------------------------------------------------------------
      B=8, M=128, K=32    |         89.2      |        145.2       |         81.1       |     233.7    |            63.3         |     88.9
      B=8, M=128, K=64    |        101.9      |        152.7       |         89.4       |     439.4    |            65.1         |     89.2
      B=8, M=128, K=128   |        184.4      |        170.9       |        119.7       |     879.1    |            79.0         |     89.3
      B=8, M=1024, K=32   |       1657.5      |       3226.1       |       2311.6       |    5759.2    |          1406.1         |    332.1
      B=8, M=1024, K=64   |       2322.9      |       3606.6       |       2847.1       |   10336.3    |          1777.7         |    373.5
      B=8, M=1024, K=128  |       5060.0      |       5075.3       |       4560.6       |   24064.9    |          2664.9         |    492.5
      B=8, M=128, K=256   |                   |                    |        221.6       |              |           157.6         |     89.1
      B=8, M=1024, K=256  |                   |                    |       9312.8       |              |          5439.0         |    847.7
      B=8, M=2048, K=32   |                   |                    |       8228.8       |              |          4629.1         |   1025.2
      B=8, M=2048, K=64   |                   |                    |      10399.7       |              |          6111.3         |   1222.6
      B=8, M=2048, K=128  |                   |                    |      17878.4       |              |          9035.3         |   1884.0
      B=8, M=2048, K=256  |                   |                    |      36440.1       |              |         19212.5         |   3281.2
danthe3rd commented 2 years ago

73bd4c1

[-------------------------------------------- attention --------------------------------------------]
                          |  cutlass_16x128x8_siPrecacl2  |  cutlass2  |  cutlass2_bounds  |  vanilla
1 threads: ------------------------------------------------------------------------------------------
      B=8, M=128, K=32    |               59.8            |            |         40.3      |     66.4
      B=8, M=128, K=64    |               66.2            |            |                   |         
      B=8, M=128, K=128   |               84.4            |     50.8   |         52.4      |     66.8
      B=8, M=128, K=256   |              146.9            |     79.4   |         81.3      |     65.8
      B=8, M=1024, K=32   |             1561.3            |            |                   |         
      B=8, M=1024, K=64   |             1898.1            |            |                   |         
      B=8, M=1024, K=128  |             2910.1            |   1619.8   |                   |         
      B=8, M=1024, K=256  |             6118.0            |   2702.9   |                   |         
      B=8, M=2048, K=32   |             5602.2            |            |       5008.1      |   2277.6
      B=8, M=2048, K=64   |             6910.0            |            |                   |         
      B=8, M=2048, K=128  |            10999.5            |   6245.5   |       6448.7      |   3121.8
      B=8, M=2048, K=256  |            23657.9            |  10823.6   |      11197.1      |   5183.5
      B=8, M=4096, K=128  |                               |  23813.0   |      24924.0      |  12123.6
      B=8, M=4096, K=256  |                               |  42068.8   |      44655.8      |  19977.7
      B=8, M=64, K=8      |                               |            |         34.4      |     65.3
      B=8, M=64, K=16     |                               |            |         34.7      |     66.1
      B=8, M=64, K=32     |                               |            |         34.9      |     66.8
      B=8, M=64, K=128    |                               |            |         36.1      |     66.3
      B=8, M=64, K=256    |                               |            |         55.2      |     66.0
      B=8, M=127, K=8     |                               |            |         37.0      |     66.2
      B=8, M=127, K=16    |                               |            |         37.9      |     66.6
      B=8, M=127, K=32    |                               |            |         40.2      |     66.9
      B=8, M=127, K=128   |                               |            |         52.4      |     66.4
      B=8, M=127, K=256   |                               |            |         81.2      |     67.0
      B=8, M=128, K=8     |                               |            |         36.7      |     66.4
      B=8, M=128, K=16    |                               |            |         37.7      |     66.4
      B=8, M=256, K=8     |                               |            |        117.9      |     65.9
      B=8, M=256, K=16    |                               |            |        120.7      |     65.6
      B=8, M=256, K=32    |                               |            |        125.3      |     66.3
      B=8, M=256, K=128   |                               |            |        166.5      |     75.7
      B=8, M=256, K=256   |                               |            |        292.3      |    127.3
      B=8, M=2048, K=8    |                               |            |       4623.5      |   2143.2
      B=8, M=2048, K=16   |                               |            |       4738.0      |   2169.6
      B=8, M=4096, K=8    |                               |            |      17678.3      |   8092.9
      B=8, M=4096, K=16   |                               |            |      18077.2      |   8314.3
      B=8, M=4096, K=32   |                               |            |      18998.2      |   8912.3
danthe3rd commented 2 years ago

21641fa

[------------------------------------------------------ attention -------------------------------------------------------]
                          |  cutlass_16x128x8_siPrecacl2  |  cutlass2  |  cutlass2_bounds  |  cutlass2_bounds2  |  vanilla
1 threads: ---------------------------------------------------------------------------------------------------------------
      B=8, M=128, K=32    |               59.8            |            |         40.2      |         34.8       |     66.7
      B=8, M=128, K=64    |               66.2            |            |                   |                    |         
      B=8, M=128, K=128   |               84.4            |     50.8   |         52.6      |         43.8       |     66.3
      B=8, M=128, K=256   |              146.9            |     79.4   |         81.2      |         75.4       |     67.0
      B=8, M=1024, K=32   |             1561.3            |            |       1266.9      |        949.9       |    520.5
      B=8, M=1024, K=64   |             1898.1            |            |                   |                    |         
      B=8, M=1024, K=128  |             2910.1            |   1619.8   |       1652.8      |       1403.7       |    737.0
      B=8, M=1024, K=256  |             6118.0            |   2702.9   |       2813.2      |       2699.2       |   1330.9
      B=8, M=2048, K=32   |             5602.2            |            |                   |                    |         
      B=8, M=2048, K=64   |             6910.0            |            |                   |                    |         
      B=8, M=2048, K=128  |            10999.5            |   6245.5   |                   |                    |         
      B=8, M=2048, K=256  |            23657.9            |  10823.6   |                   |                    |         
      B=8, M=4096, K=128  |                               |  23813.0   |      24679.0      |      21202.0       |  12016.8
      B=8, M=4096, K=256  |                               |  42068.8   |      44072.9      |      42345.7       |  19873.5
      B=8, M=96, K=31     |                               |            |         36.5      |         34.8       |     65.7
      B=8, M=96, K=32     |                               |            |         36.8      |         34.7       |     65.3
      B=8, M=96, K=128    |                               |            |         38.4      |         35.0       |     66.1
      B=8, M=96, K=256    |                               |            |         59.9      |         53.2       |     66.8
      B=8, M=127, K=31    |                               |            |         40.3      |         34.9       |     66.1
      B=8, M=127, K=32    |                               |            |         40.2      |         34.3       |     65.6
      B=8, M=127, K=128   |                               |            |         52.4      |         43.8       |     66.2
      B=8, M=127, K=256   |                               |            |         81.1      |         75.5       |     66.7
      B=8, M=128, K=31    |                               |            |         40.4      |         35.1       |     66.7
      B=8, M=1024, K=31   |                               |            |       1271.3      |        957.1       |    523.3
      B=8, M=4096, K=31   |                               |            |      18935.7      |      14102.2       |   8957.7
      B=8, M=4096, K=32   |                               |            |      18865.5      |      14000.8       |   8905.5
danthe3rd commented 2 years ago

8f4728a56d02c629393e60438c0945ca60ef3b78

[------------------------------------------------------------------------- attention --------------------------------------------------------------------------]
                          |  cutlass2_bounds_v100  |  cutlass2_bounds_autoreg_v100  |  cutlass2_bounds_reg160_v100  |  cutlass2_bounds_autoreg2_v100  |  vanilla
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------------
      B=8, M=96, K=31     |           35.0         |               35.1             |               36.8            |                34.8             |     88.7
      B=8, M=96, K=32     |           35.3         |               35.0             |               36.8            |                34.3             |     89.1
      B=8, M=96, K=128    |           41.4         |               36.5             |               37.0            |                34.9             |     90.2
      B=8, M=96, K=256    |           67.0         |               51.2             |               49.5            |                48.9             |     89.8
      B=8, M=127, K=31    |           36.1         |               35.6             |               37.2            |                34.6             |     89.7
      B=8, M=127, K=32    |           35.8         |               35.5             |               36.5            |                34.5             |     89.0
      B=8, M=127, K=128   |           46.0         |               35.3             |               37.1            |                34.6             |     91.4
      B=8, M=127, K=256   |           74.4         |               55.8             |               54.0            |                53.2             |     89.9
      B=8, M=128, K=31    |           35.8         |               35.5             |               36.6            |                34.3             |     90.1
      B=8, M=128, K=32    |           36.0         |               35.7             |               36.3            |                34.6             |     89.3
      B=8, M=128, K=128   |           46.1         |               35.5             |               37.1            |                34.8             |     89.2
      B=8, M=128, K=256   |           74.8         |               56.0             |               54.3            |                53.4             |     89.6
      B=8, M=1024, K=31   |         1040.0         |              631.4             |              552.7            |               554.9             |    375.9
      B=8, M=1024, K=32   |         1031.7         |              634.1             |              550.5            |               553.6             |    370.7
      B=8, M=1024, K=128  |         1409.2         |             1006.3             |              911.2            |               897.8             |    517.3
      B=8, M=1024, K=256  |         2376.6         |             1799.2             |             1661.1            |              1660.1             |    899.5
      B=8, M=4096, K=31   |        13739.2         |             8341.8             |             7315.9            |              7313.5             |   3547.7
      B=8, M=4096, K=32   |        13555.8         |             8325.9             |             7287.1            |              7289.8             |   3477.7
      B=8, M=4096, K=128  |        18125.9         |            12822.6             |            11345.4            |             11282.0             |   7810.2
      B=8, M=4096, K=256  |        30331.0         |            22973.0             |            21350.0            |             21329.5             |  13067.7
danthe3rd commented 2 years ago
[------------------- attention -------------------]
                           |  optimized  |  vanilla
1 threads: ----------------------------------------
      B=7680, M=36, K=32   |    1965.7   |    694.4
      B=7680, M=36, K=64   |    2343.0   |    949.1
      B=7680, M=48, K=32   |    1987.7   |    879.4
      B=7680, M=48, K=64   |    2475.8   |   1219.1
      B=7680, M=64, K=32   |    2842.9   |   1299.9
      B=7680, M=64, K=64   |    3523.2   |   1904.5
      B=7680, M=96, K=32   |    4956.6   |   2292.8
      B=7680, M=96, K=64   |    5952.7   |   3430.5
      B=7680, M=128, K=32  |    7563.3   |   3858.8
      B=7680, M=128, K=64  |    8750.9   |   5048.0
      B=7680, M=192, K=32  |   18231.2   |   7653.3
      B=7680, M=192, K=64  |   21890.0   |  11069.7
      B=768, M=36, K=32    |     199.1   |     95.5
      B=768, M=36, K=64    |     241.5   |    126.3
      B=768, M=48, K=32    |     209.3   |    109.7
      B=768, M=48, K=64    |     259.8   |    156.6
      B=768, M=64, K=32    |     297.9   |    152.7
      B=768, M=64, K=64    |     364.6   |    216.1
      B=768, M=96, K=32    |     507.3   |    255.3
      B=768, M=96, K=64    |     606.3   |    367.3
      B=768, M=128, K=32   |     755.1   |    402.3
      B=768, M=128, K=64   |     890.1   |    543.6
      B=768, M=192, K=32   |    1841.1   |    798.7
      B=768, M=192, K=64   |    2202.8   |   1132.5
      B=1536, M=36, K=32   |     394.5   |    161.0
      B=1536, M=36, K=64   |     475.0   |    220.5
      B=1536, M=48, K=32   |     407.6   |    197.4
      B=1536, M=48, K=64   |     506.1   |    275.6
      B=1536, M=64, K=32   |     588.6   |    282.2
      B=1536, M=64, K=64   |     717.4   |    404.1
      B=1536, M=96, K=32   |    1002.8   |    485.3
      B=1536, M=96, K=64   |    1198.1   |    714.5
      B=1536, M=128, K=32  |    1508.2   |    778.0
      B=1536, M=128, K=64  |    1762.5   |   1079.9
      B=1536, M=192, K=32  |    3662.7   |   1557.7
      B=1536, M=192, K=64  |    4386.4   |   2239.0

with kNumWarpsPerBlock=2

[------------------- attention -------------------]
                           |  optimized  |  vanilla
1 threads: ----------------------------------------
      B=7680, M=36, K=64   |    1842.3   |    950.1
      B=768, M=192, K=64   |    2040.4   |   1133.6
      B=1536, M=192, K=64  |    4151.8   |   2241.4
      B=5376, M=36, K=64   |    1288.5   |    675.9
      B=544, M=192, K=64   |    1434.0   |    815.1
      B=512, M=192, K=64   |    1348.8   |    767.2
      B=1088, M=192, K=64  |    2913.7   |   1600.2
      B=1024, M=192, K=64  |    2740.3   |   1500.1
      B=768, M=154, K=64   |    1606.7   |    941.4
      B=1536, M=154, K=64  |    3247.5   |   1838.4
      B=544, M=154, K=64   |    1134.2   |    676.1
      B=512, M=154, K=64   |    1067.8   |    636.9
      B=1088, M=154, K=64  |    2288.8   |   1312.3
      B=1024, M=154, K=64  |    2153.4   |   1240.5

[----------------- attention ------------------]
                        |  optimized  |  vanilla
1 threads: -------------------------------------
      B=8, M=36, K=64   |     37.5    |    95.5
      B=8, M=192, K=64  |     61.3    |   128.7
      B=8, M=154, K=64  |     58.9    |   127.8
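
The tables above are easier to compare as ratios. The following is a minimal sketch that computes the speedup of the `optimized` column over `vanilla` for a few rows copied from the tables above. It assumes both columns of a row are reported in the same time unit, so the ratio is unit-free; the `speedup` helper and the `rows` dictionary are illustrative names, not part of the benchmark script.

```python
# (B, M, K) -> (optimized, vanilla), copied from the tables above
# (kNumWarpsPerBlock=2 run for the large batches, B=8 run for the small ones).
# Assumption: both timings in a row share the same unit.
rows = {
    (7680, 36, 64): (1842.3, 950.1),
    (768, 192, 64): (2040.4, 1133.6),
    (8, 36, 64): (37.5, 95.5),
    (8, 192, 64): (61.3, 128.7),
}


def speedup(optimized: float, vanilla: float) -> float:
    """Return vanilla / optimized; values above 1.0 mean optimized is faster."""
    return vanilla / optimized


for (b, m, k), (opt, van) in rows.items():
    print(f"B={b}, M={m}, K={k}: {speedup(opt, van):.2f}x")
```

On these rows the ratio is below 1 for the large batches and above 2 for B=8, which matches what the raw timings show.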
danthe3rd commented 2 years ago

matmul1 can be improved:

[----------------------------------------------------------------- attention ------------------------------------------------------------------]
                          |  cutlass2_bounds_v100  |  cutlass2_bounds_autoreg_v100  |  cutlass2_bounds_reg160_v100  |  c2_nomm1_v100  |  vanilla
1 threads: -------------------------------------------------------------------------------------------------------------------------------------
      B=8, M=96, K=31     |           35.0         |               35.1             |               36.8            |                 |
      B=8, M=96, K=32     |           35.3         |               35.0             |               36.8            |                 |
      B=8, M=96, K=128    |           41.4         |               36.5             |               37.0            |                 |
      B=8, M=96, K=256    |           67.0         |               51.2             |               49.5            |                 |
      B=8, M=127, K=31    |           36.1         |               35.6             |               37.2            |                 |
      B=8, M=127, K=32    |           35.8         |               35.5             |               36.5            |                 |
      B=8, M=127, K=128   |           46.0         |               35.3             |               37.1            |                 |
      B=8, M=127, K=256   |           74.4         |               55.8             |               54.0            |                 |
      B=8, M=128, K=31    |           35.8         |               35.5             |               36.6            |                 |
      B=8, M=128, K=32    |           36.0         |               35.7             |               36.3            |        34.2     |     85.0
      B=8, M=128, K=128   |           46.1         |               35.5             |               37.1            |        34.5     |     85.2
      B=8, M=128, K=256   |           74.8         |               56.0             |               54.3            |        38.1     |     85.5
      B=8, M=1024, K=31   |         1040.0         |              631.4             |              552.7            |                 |
      B=8, M=1024, K=32   |         1031.7         |              634.1             |              550.5            |       395.5     |    370.5
      B=8, M=1024, K=128  |         1409.2         |             1006.3             |              911.2            |       474.3     |    517.2
      B=8, M=1024, K=256  |         2376.6         |             1799.2             |             1661.1            |       879.1     |    899.7
      B=8, M=4096, K=31   |        13739.2         |             8341.8             |             7315.9            |                 |
      B=8, M=4096, K=32   |        13555.8         |             8325.9             |             7287.1            |      5037.9     |   3459.8
      B=8, M=4096, K=128  |        18125.9         |            12822.6             |            11345.4            |      5503.3     |   7797.6
      B=8, M=4096, K=256  |        30331.0         |            22973.0             |            21350.0            |     10545.0     |  13076.7
danthe3rd commented 2 years ago

base = before
optimized = b08a5c8

[---------------------------------- attention -----------------------------------]
                           |  base  |  nomm1  |  nomm1mi  |  optimized  |  vanilla
1 threads: -----------------------------------------------------------------------
      B=7680, M=36, K=64   |  4.9   |   2.5   |    4.6    |     4.4     |    2.2  
      B=768, M=192, K=64   |  4.4   |   2.5   |    4.1    |     4.0     |    2.1  
      B=1536, M=192, K=64  |  8.7   |   5.0   |    8.1    |     8.0     |    4.1  
      B=5376, M=36, K=64   |  3.5   |   1.8   |    3.2    |     3.1     |    1.6 
danthe3rd commented 2 years ago

new = 9f91e09
new_w2 = new with kNumWarpsPerBlock=2
resetcutlass = without my mods to cutlass to load shared mem

[----------------------------------------------------- attention ------------------------------------------------------]
                           |  base  |  new  |  new_w2  |  new_nomm1  |  new_nomm2_resetcutlass  |  new_nomm2  |  vanilla
1 threads: -------------------------------------------------------------------------------------------------------------
      B=7680, M=36, K=64   |  4.6   |  4.1  |   2.9    |     2.0     |           2.2            |     2.4     |    2.2  
      B=768, M=192, K=64   |  4.2   |  3.9  |   3.2    |     2.2     |           1.6            |     1.7     |    2.1  
      B=1536, M=192, K=64  |  8.2   |  7.7  |   6.4    |     4.4     |           3.3            |     3.5     |    4.1  
      B=5376, M=36, K=64   |  3.1   |  2.9  |   2.1    |     1.4     |           1.5            |     1.7     |    1.6
danthe3rd commented 2 years ago

Potential for improvement by handling si better:

[------------------------------------------------- attention -------------------------------------------------]
                            |  base  |  no_si_ld  |  no_si_ld_nomods  |  no_si_ld_nomods_noldoutput  |  vanilla
1 threads: ----------------------------------------------------------------------------------------------------
      B=7680, M=36, K=64    |   4.1  |     4.2    |         3.6       |              3.4             |     2.2 
      B=768, M=784, K=64    |  54.7  |    56.5    |        50.0       |             46.7             |    31.7 
      B=128, M=2048, K=128  |  77.9  |    79.9    |        70.6       |             66.1             |    46.3
danthe3rd commented 2 years ago

CUTLASS batched GEMM example (05_batched_gemm):


using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;

  int const m = 2048;
  int const n = 2048;
  int const k = 128;
  int const batch_count = 128;

Built target 05_batched_gemm
Running array gemm
cutlass_array_sgemm --> 13758[µs]
Passed.
Running strided batched gemm
cutlass_strided_batched_sgemm --> 13571[µs]
Passed.
Running array gemm
cutlass_array_sgemm --> 13027[µs]
Passed.
Running strided batched gemm
cutlass_strided_batched_sgemm --> 12887[µs]
Passed.
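
To put these timings in perspective, they can be converted into an approximate throughput. The sketch below assumes each batch entry is a single (m x k) by (k x n) product costing 2*m*n*k floating-point operations; the helper name `batched_gemm_tflops` is made up for this illustration, and the "run 1"/"run 2" labels are only there to distinguish the repeated lines of the log above.

```python
def batched_gemm_tflops(m: int, n: int, k: int, batch_count: int, time_us: float) -> float:
    """Approximate throughput in TFLOP/s for a batched GEMM.

    Assumption: 2*m*n*k FLOPs per batch entry (one multiply and one add
    per inner-product term), with time_us given in microseconds.
    """
    flops = 2 * m * n * k * batch_count
    seconds = time_us * 1e-6
    return flops / seconds / 1e12


# Problem size and timings copied from the example output above.
M, N, K, BATCH = 2048, 2048, 128, 128
timings_us = {
    "cutlass_array_sgemm (run 1)": 13758,
    "cutlass_strided_batched_sgemm (run 1)": 13571,
    "cutlass_array_sgemm (run 2)": 13027,
    "cutlass_strided_batched_sgemm (run 2)": 12887,
}

for name, t in timings_us.items():
    print(f"{name}: {batched_gemm_tflops(M, N, K, BATCH, t):.2f} TFLOP/s")
```

Under that FLOP-count assumption, all four runs land at roughly 10 TFLOP/s.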
danthe3rd commented 2 years ago

after: 56c3ef3

[-------------------------- attention --------------------------]
                            |  base_v100  |  cmp_v100  |  vanilla
1 threads: ------------------------------------------------------
      B=1, M=20480, K=128   |   32035.9   |  31483.8   |  23315.1
      B=7680, M=36, K=64    |    1697.0   |   1664.1   |    887.4
      B=768, M=784, K=64    |   25137.3   |  24666.7   |  17861.5
      B=128, M=2048, K=128  |   38158.1   |  37305.4   |  27872.0
danthe3rd commented 2 years ago

tune q

[---------------------------------- attention ----------------------------------]
                             |  q32_v100  |  q16_v100  |  q32_w2_v100  |  vanilla
1 threads: ----------------------------------------------------------------------
      B=64, M=128, K=32      |      61.5  |      55.0  |        59.7   |     93.0
      B=64, M=128, K=64      |      74.8  |      67.1  |        74.4   |     95.1
      B=64, M=128, K=128     |     103.8  |      97.5  |       126.5   |     94.8
      B=64, M=128, K=256     |     199.7  |     193.2  |       255.2   |    143.0
      B=64, M=1024, K=32     |    2624.5  |    2675.6  |      2859.2   |   1808.2
      B=64, M=1024, K=64     |    3150.5  |    3351.8  |      3541.1   |   2236.5
      B=64, M=1024, K=128    |    4278.5  |    4851.4  |      6130.6   |   3453.7
      B=64, M=1024, K=256    |    8177.7  |    9618.6  |     11313.3   |   6076.0
      B=64, M=4096, K=32     |   40869.8  |   41800.3  |     42976.4   |  29175.0
      B=64, M=4096, K=64     |   48617.7  |   52187.0  |     52658.0   |  35570.7
      B=64, M=4096, K=128    |   66354.5  |   73982.8  |     93704.2   |  54904.6
      B=64, M=4096, K=256    |  128790.8  |  147621.5  |    178471.5   |  96703.2
      B=256, M=128, K=16     |     175.5  |     164.1  |       192.6   |    120.8
      B=256, M=128, K=32     |     195.8  |     183.7  |       216.6   |    142.8
      B=256, M=128, K=64     |     246.0  |     235.6  |       291.1   |    198.6
      B=256, M=128, K=128    |     343.9  |     345.5  |       502.9   |    326.5
      B=256, M=1024, K=16    |    9468.1  |    9294.3  |      9837.2   |   6995.3
      B=256, M=1024, K=32    |   10337.5  |   10553.5  |     10889.0   |   7663.4
      B=256, M=1024, K=64    |   12403.2  |   13246.0  |     13361.2   |   8855.1
      B=256, M=1024, K=128   |   16954.5  |   18884.8  |     23965.6   |  14248.1
      B=1024, M=128, K=16    |     669.7  |     632.9  |       676.9   |    438.6
      B=1024, M=128, K=32    |     732.8  |     714.1  |       748.6   |    518.9
      B=1024, M=128, K=64    |     908.9  |     911.4  |       966.5   |    676.6
      B=1024, M=128, K=128   |    1282.2  |    1349.3  |      1744.1   |   1106.4
      B=1024, M=1024, K=16   |   37649.0  |   36837.9  |     39175.8   |  27497.3
      B=1024, M=1024, K=32   |   41135.6  |   41764.0  |     43614.5   |  29859.8
      B=1024, M=1024, K=64   |   49194.7  |   52581.9  |     53170.3   |  36187.4
      B=1024, M=1024, K=128  |   67194.2  |   74965.7  |     97479.2   |  55582.8
danthe3rd commented 2 years ago

f16v0 -> everything in f16 (gives wrong results and NaNs/infs)
f16v1_accf32 -> input/output in f16, everything else in f32 (no NaNs, close to what we expected)
f16v2_tcmm1 -> use tensor cores for the first matmul (f32 = f16*f16 + f32). Had to set kQueriesPerBlock=64 (instead of 32) and kNumWarpsPerBlock=2 (instead of 4, to use less shared memory)
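
One plausible way an all-f16 kernel ends up with infs and NaNs is accumulator overflow: f16 cannot represent values above 65504, so a long dot product can saturate to inf, and a later `inf - inf` yields NaN. The sketch below illustrates that mechanism only. It emulates f16 and f32 rounding on the CPU with Python's `struct` module and is not the kernel's code; the inputs and the helper names `round_to` and `dot` are hypothetical.

```python
import math
import struct


def round_to(fmt: str, x: float) -> float:
    """Round x to the format of a struct code: '<e' is f16, '<f' is f32."""
    if math.isnan(x) or math.isinf(x):
        return x
    try:
        return struct.unpack(fmt, struct.pack(fmt, x))[0]
    except OverflowError:
        # struct refuses out-of-range values; round-to-nearest hardware gives +/-inf.
        return math.copysign(math.inf, x)


def dot(q, k, acc_fmt: str) -> float:
    """Dot product of f16 inputs with the accumulator kept in acc_fmt."""
    acc = 0.0
    for a, b in zip(q, k):
        prod = round_to("<e", a) * round_to("<e", b)
        acc = round_to(acc_fmt, acc + round_to(acc_fmt, prod))
    return acc


# Hypothetical inputs: the exact dot product is 128 * 32 * 32 = 131072 > 65504.
q = k = [32.0] * 128
s_f16 = dot(q, k, "<e")  # accumulator in f16, as in the f16v0 description
s_f32 = dot(q, k, "<f")  # accumulator in f32, as in the f16v1_accf32 description

print("f16 accumulator:", s_f16)
print("f32 accumulator:", s_f32)
# A softmax-style (score - max) step on an overflowed score gives inf - inf.
print("inf - inf:", s_f16 - s_f16)
```

Keeping the inputs in f16 but accumulating in f32 avoids the overflow in this toy case, which is consistent with the f16v1_accf32 note above.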

[-------------------------------------------- attention (use_attn_bias=False) --------------------------------------------]
                                                    |  f16v2_q64w2  |   f16v0   |  vanilla  |  f16v1_accf32  |  f16v2_tcmm1
1 threads: ----------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f32 B=32, M=128, K=16     |               |     73.3  |    111.9  |       73.2     |
                          f16 B=32, M=128, K=16     |               |    188.3  |     74.3  |       57.0     |
                          f32 B=32, M=128, K=32     |               |     78.7  |     74.0  |       78.6     |
                          f16 B=32, M=128, K=32     |               |    191.1  |     74.5  |       61.1     |
                          f32 B=32, M=128, K=128    |               |    136.6  |     83.5  |      136.7     |
                          f16 B=32, M=128, K=128    |               |    215.1  |     75.4  |       98.0     |
                          f32 B=32, M=512, K=16     |               |    833.3  |    504.0  |      833.8     |
                          f16 B=32, M=512, K=16     |               |   2920.8  |    438.2  |      583.2     |
                          f32 B=32, M=512, K=32     |               |    907.5  |    530.5  |      908.3     |
                          f16 B=32, M=512, K=32     |               |   2892.3  |    468.1  |      630.6     |
                          f32 B=32, M=512, K=128    |               |   1394.1  |    759.2  |     1392.6     |
                          f16 B=32, M=512, K=128    |               |   3150.9  |    692.3  |     1006.0     |
                          f32 B=32, M=1024, K=16    |               |   3134.4  |   1828.0  |     3132.4     |
                          f16 B=32, M=1024, K=16    |               |  10324.8  |   1638.7  |     2210.9     |
                          f32 B=32, M=1024, K=32    |               |   3403.3  |   1947.8  |     3402.6     |
                          f16 B=32, M=1024, K=32    |               |  10157.6  |   1772.1  |     2457.0     |
                          f32 B=32, M=1024, K=128   |               |   5134.3  |   2832.5  |     5127.5     |
                          f16 B=32, M=1024, K=128   |               |  11435.0  |   2613.7  |     3900.7     |
                          f32 B=256, M=128, K=16    |               |    450.7  |    261.8  |      450.2     |
                          f16 B=256, M=128, K=16    |               |   1344.8  |    226.3  |      307.1     |
                          f32 B=256, M=128, K=32    |               |    499.9  |    290.7  |      499.2     |
                          f16 B=256, M=128, K=32    |               |   1357.5  |    248.4  |      334.9     |
                          f32 B=256, M=128, K=128   |               |    865.4  |    457.6  |      864.6     |
                          f16 B=256, M=128, K=128   |               |   1515.9  |    387.1  |      574.0     |
                          f32 B=256, M=512, K=16    |               |   6292.5  |   3508.8  |     6281.7     |
                          f16 B=256, M=512, K=16    |               |  18977.6  |   3050.2  |     4291.6     |
                          f32 B=256, M=512, K=32    |               |   6822.5  |   3783.0  |     6818.3     |
                          f16 B=256, M=512, K=32    |               |  18976.3  |   3324.1  |     4746.6     |
                          f32 B=256, M=512, K=128   |               |  10467.9  |   5656.4  |    10456.6     |
                          f16 B=256, M=512, K=128   |               |  21502.2  |   5028.3  |     7758.5     |
                          f32 B=256, M=1024, K=16   |               |  24512.5  |  13606.6  |    24488.6     |
                          f16 B=256, M=1024, K=16   |               |  74444.1  |  12081.0  |    16968.4     |
                          f32 B=256, M=1024, K=32   |               |  26569.0  |  14697.7  |    26532.2     |
                          f16 B=256, M=1024, K=32   |               |  75113.9  |  13117.4  |    18574.7     |
                          f32 B=256, M=1024, K=128  |               |  40343.8  |  21603.1  |    39809.1     |
                          f16 B=256, M=1024, K=128  |               |  85245.3  |  19681.6  |    30244.1     |
  (Tesla_V100_SXM2_16GB)  f16 B=32, M=128, K=16     |       84.8    |    116.9  |    101.5  |       46.5     |       81.7
                          f16 B=32, M=128, K=32     |       91.8    |    117.2  |    105.6  |       47.5     |       84.0
                          f16 B=32, M=128, K=128    |      174.5    |    135.1  |    113.9  |       54.3     |      145.0
                          f16 B=32, M=512, K=16     |      648.0    |   1631.7  |    115.8  |      346.4     |      658.7
                          f16 B=32, M=512, K=32     |      715.4    |   1535.2  |    110.3  |      379.4     |      677.2
                          f16 B=32, M=512, K=128    |     1439.5    |   1943.2  |    153.9  |      589.4     |     1140.9
                          f16 B=32, M=1024, K=16    |     2157.9    |   5831.8  |    460.4  |     1220.9     |     2396.5
                          f16 B=32, M=1024, K=32    |     2427.7    |   5860.0  |    476.3  |     1329.1     |     2468.5
                          f16 B=32, M=1024, K=128   |     5023.6    |   6456.9  |    565.7  |     2080.1     |     4142.7
                          f16 B=256, M=128, K=16    |      283.4    |    793.0  |    101.9  |      154.7     |      331.1
                          f16 B=256, M=128, K=32    |      310.8    |    795.6  |     99.8  |      170.9     |      341.0
                          f16 B=256, M=128, K=128   |      671.5    |    912.9  |    133.9  |      289.3     |      637.2
                          f16 B=256, M=512, K=16    |     3633.4    |  10531.5  |    789.4  |     2208.8     |     4289.5
                          f16 B=256, M=512, K=32    |     4128.1    |  10622.5  |    830.5  |     2401.2     |     4377.2
                          f16 B=256, M=512, K=128   |     8808.2    |  11734.8  |   1142.8  |     3903.4     |     7505.2
                          f16 B=256, M=1024, K=16   |    14273.9    |  41560.8  |   3339.2  |     8536.4     |    16685.1
                          f16 B=256, M=1024, K=32   |    16185.3    |  41815.5  |   3445.3  |     9294.8     |    16968.8
                          f16 B=256, M=1024, K=128  |    34472.9    |  45577.9  |   4216.7  |    15252.3     |    28662.0
                          f32 B=32, M=128, K=16     |               |     51.0  |    128.3  |       47.1     |
                          f32 B=32, M=128, K=32     |               |     44.8  |     99.0  |       47.0     |
                          f32 B=32, M=128, K=128    |               |     55.0  |    107.9  |       54.8     |
                          f32 B=32, M=512, K=16     |               |    369.9  |    238.7  |      369.1     |
                          f32 B=32, M=512, K=32     |               |    387.1  |    252.6  |      387.9     |
                          f32 B=32, M=512, K=128    |               |    673.1  |    485.9  |      672.7     |
                          f32 B=32, M=1024, K=16    |               |   1285.4  |    773.7  |     1285.8     |
                          f32 B=32, M=1024, K=32    |               |   1403.7  |    827.4  |     1405.3     |
                          f32 B=32, M=1024, K=128   |               |   2279.4  |   1778.5  |     2282.1     |
                          f32 B=256, M=128, K=16    |               |    176.3  |    118.7  |      176.4     |
                          f32 B=256, M=128, K=32    |               |    197.0  |    142.7  |      197.7     |
                          f32 B=256, M=128, K=128   |               |    345.6  |    334.1  |      344.7     |
                          f32 B=256, M=512, K=16    |               |   2428.0  |   1760.7  |     2429.0     |
                          f32 B=256, M=512, K=32    |               |   2664.4  |   1914.1  |     2664.3     |
                          f32 B=256, M=512, K=128   |               |   4425.9  |   3584.2  |     4423.0     |
                          f32 B=256, M=1024, K=16   |               |   9490.1  |   7022.7  |     9494.1     |
                          f32 B=256, M=1024, K=32   |               |  10378.7  |   7652.1  |    10378.5     |
                          f32 B=256, M=1024, K=128  |               |  17011.5  |  14334.8  |    16975.2     |

Times are in microseconds (us).

On A100 (f16v1_accf32):

[----------- attention (use_attn_bias=False) ------------]
                                |  f16v1_accf32  |  vanilla
1 threads: -----------------------------------------------
      f16 B=32, M=128, K=16     |       32.6    |     82.6
      f32 B=32, M=128, K=16     |       32.3    |     81.9
      f16 B=32, M=128, K=32     |       34.7    |     67.1
      f32 B=32, M=128, K=32     |       34.3    |     81.4
      f16 B=32, M=128, K=128    |       50.2    |     67.2
      f32 B=32, M=128, K=128    |       49.1    |     80.6
      f16 B=32, M=512, K=16     |      235.4    |     70.1
      f32 B=32, M=512, K=16     |      230.5    |    126.1
      f16 B=32, M=512, K=32     |      254.5    |     71.3
      f32 B=32, M=512, K=32     |      249.6    |    132.1
      f16 B=32, M=512, K=128    |      380.9    |     88.7
      f32 B=32, M=512, K=128    |      368.2    |    166.5
      f16 B=32, M=1024, K=16    |      877.4    |    268.8
      f32 B=32, M=1024, K=16    |      857.1    |    441.7
      f16 B=32, M=1024, K=32    |      948.9    |    269.8
      f32 B=32, M=1024, K=32    |      931.3    |    453.4
      f16 B=32, M=1024, K=128   |     1428.5    |    311.4
      f32 B=32, M=1024, K=128   |     1393.2    |    538.0
      f16 B=256, M=128, K=16    |      122.6    |     67.6
      f32 B=256, M=128, K=16    |      119.2    |     81.6
      f16 B=256, M=128, K=32    |      131.8    |     68.4
      f32 B=256, M=128, K=32    |      128.9    |     82.6
      f16 B=256, M=128, K=128   |      203.3    |     67.7
      f32 B=256, M=128, K=128   |      201.2    |    158.8
      f16 B=256, M=512, K=16    |     1628.9    |    475.4
      f32 B=256, M=512, K=16    |     1586.2    |    857.5
      f16 B=256, M=512, K=32    |     1752.5    |    494.7
      f32 B=256, M=512, K=32    |     1725.1    |    897.4
      f16 B=256, M=512, K=128   |     2700.5    |    614.1
      f32 B=256, M=512, K=128   |     2637.0    |   1179.4
      f16 B=256, M=1024, K=16   |     6410.0    |   1869.8
      f32 B=256, M=1024, K=16   |     6253.6    |   3200.0
      f16 B=256, M=1024, K=32   |     6941.5    |   1904.8
      f32 B=256, M=1024, K=32   |     6811.3    |   3302.3
      f16 B=256, M=1024, K=128  |    10521.9    |   2230.6
      f32 B=256, M=1024, K=128  |    10273.3    |   4033.5
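
For easier reading of the A100 numbers, the times can be turned into a ratio against vanilla. The snippet below is a small sketch using three rows copied from the table above; `speedup_vs_vanilla` is a hypothetical helper name, and a value above 1 means f16v1_accf32 is faster than vanilla.

```python
# (shape, f16v1_accf32 time in us, vanilla time in us), copied from the A100 table
rows = [
    ("f16 B=32, M=128, K=16", 32.6, 82.6),
    ("f16 B=32, M=512, K=16", 235.4, 70.1),
    ("f16 B=256, M=1024, K=128", 10521.9, 2230.6),
]


def speedup_vs_vanilla(kernel_us: float, vanilla_us: float) -> float:
    """Return how many times faster the kernel is than vanilla (>1 is faster)."""
    return vanilla_us / kernel_us


for shape, kernel_us, vanilla_us in rows:
    ratio = speedup_vs_vanilla(kernel_us, vanilla_us)
    print(f"{shape}: {ratio:.2f}x")
```

For these rows the kernel is ahead only at M=128; at M=512 and M=1024 vanilla is several times faster.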