danthe3rd opened 2 years ago
85dd034
[--------------------------------------------------------------- attention ---------------------------------------------------------------]
| cutlass_16x64x8 | cutlass_32x128x8 | cutlass_16x128x8 | attention2 | cutlass_16x128x8_v100 | vanilla
1 threads: --------------------------------------------------------------------------------------------------------------------------------
B=8, M=128, K=32 | 89.2 | 145.2 | 81.1 | 233.7 | 63.3 | 88.9
B=8, M=128, K=64 | 101.9 | 152.7 | 89.4 | 439.4 | 65.1 | 89.2
B=8, M=128, K=128 | 184.4 | 170.9 | 119.7 | 879.1 | 79.0 | 89.3
B=8, M=1024, K=32 | 1657.5 | 3226.1 | 2311.6 | 5759.2 | 1406.1 | 332.1
B=8, M=1024, K=64 | 2322.9 | 3606.6 | 2847.1 | 10336.3 | 1777.7 | 373.5
B=8, M=1024, K=128 | 5060.0 | 5075.3 | 4560.6 | 24064.9 | 2664.9 | 492.5
B=8, M=128, K=256 | | | 221.6 | | 157.6 | 89.1
B=8, M=1024, K=256 | | | 9312.8 | | 5439.0 | 847.7
B=8, M=2048, K=32 | | | 8228.8 | | 4629.1 | 1025.2
B=8, M=2048, K=64 | | | 10399.7 | | 6111.3 | 1222.6
B=8, M=2048, K=128 | | | 17878.4 | | 9035.3 | 1884.0
B=8, M=2048, K=256 | | | 36440.1 | | 19212.5 | 3281.2
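For reference, the "vanilla" baseline in these tables is presumably the straightforward softmax(Q·Kᵀ)·V attention. A minimal CPU sketch under that assumption (the 1/sqrt(K) scaling and the row-major B×M×K layout are guesses, not read off the benchmarked kernels):

```cpp
#include <vector>
#include <cmath>
#include <algorithm>

// CPU reference for one attention pass, assuming out = softmax(Q K^T / sqrt(K)) V.
// Q, K, V: B batches of M x Kdim row-major matrices (layout is an assumption).
std::vector<float> attention_ref(const std::vector<float>& Q,
                                 const std::vector<float>& K,
                                 const std::vector<float>& V,
                                 int B, int M, int Kdim) {
    std::vector<float> out(B * M * Kdim, 0.f);
    std::vector<float> row(M);
    const float scale = 1.f / std::sqrt((float)Kdim);
    for (int b = 0; b < B; ++b) {
        const float* q = Q.data() + b * M * Kdim;
        const float* k = K.data() + b * M * Kdim;
        const float* v = V.data() + b * M * Kdim;
        float* o = out.data() + b * M * Kdim;
        for (int i = 0; i < M; ++i) {
            // Scores for query row i against all keys, with the usual
            // max-subtraction for a numerically stable softmax.
            float mx = -1e30f, sum = 0.f;
            for (int j = 0; j < M; ++j) {
                float s = 0.f;
                for (int d = 0; d < Kdim; ++d)
                    s += q[i * Kdim + d] * k[j * Kdim + d];
                row[j] = s * scale;
                mx = std::max(mx, row[j]);
            }
            for (int j = 0; j < M; ++j) {
                row[j] = std::exp(row[j] - mx);
                sum += row[j];
            }
            // Output row i = softmax-weighted combination of value rows.
            for (int j = 0; j < M; ++j)
                for (int d = 0; d < Kdim; ++d)
                    o[i * Kdim + d] += (row[j] / sum) * v[j * Kdim + d];
        }
    }
    return out;
}
```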
73bd4c1
[-------------------------------------------- attention --------------------------------------------]
| cutlass_16x128x8_siPrecacl2 | cutlass2 | cutlass2_bounds | vanilla
1 threads: ------------------------------------------------------------------------------------------
B=8, M=128, K=32 | 59.8 | | 40.3 | 66.4
B=8, M=128, K=64 | 66.2 | | |
B=8, M=128, K=128 | 84.4 | 50.8 | 52.4 | 66.8
B=8, M=128, K=256 | 146.9 | 79.4 | 81.3 | 65.8
B=8, M=1024, K=32 | 1561.3 | | |
B=8, M=1024, K=64 | 1898.1 | | |
B=8, M=1024, K=128 | 2910.1 | 1619.8 | |
B=8, M=1024, K=256 | 6118.0 | 2702.9 | |
B=8, M=2048, K=32 | 5602.2 | | 5008.1 | 2277.6
B=8, M=2048, K=64 | 6910.0 | | |
B=8, M=2048, K=128 | 10999.5 | 6245.5 | 6448.7 | 3121.8
B=8, M=2048, K=256 | 23657.9 | 10823.6 | 11197.1 | 5183.5
B=8, M=4096, K=128 | | 23813.0 | 24924.0 | 12123.6
B=8, M=4096, K=256 | | 42068.8 | 44655.8 | 19977.7
B=8, M=64, K=8 | | | 34.4 | 65.3
B=8, M=64, K=16 | | | 34.7 | 66.1
B=8, M=64, K=32 | | | 34.9 | 66.8
B=8, M=64, K=128 | | | 36.1 | 66.3
B=8, M=64, K=256 | | | 55.2 | 66.0
B=8, M=127, K=8 | | | 37.0 | 66.2
B=8, M=127, K=16 | | | 37.9 | 66.6
B=8, M=127, K=32 | | | 40.2 | 66.9
B=8, M=127, K=128 | | | 52.4 | 66.4
B=8, M=127, K=256 | | | 81.2 | 67.0
B=8, M=128, K=8 | | | 36.7 | 66.4
B=8, M=128, K=16 | | | 37.7 | 66.4
B=8, M=256, K=8 | | | 117.9 | 65.9
B=8, M=256, K=16 | | | 120.7 | 65.6
B=8, M=256, K=32 | | | 125.3 | 66.3
B=8, M=256, K=128 | | | 166.5 | 75.7
B=8, M=256, K=256 | | | 292.3 | 127.3
B=8, M=2048, K=8 | | | 4623.5 | 2143.2
B=8, M=2048, K=16 | | | 4738.0 | 2169.6
B=8, M=4096, K=8 | | | 17678.3 | 8092.9
B=8, M=4096, K=16 | | | 18077.2 | 8314.3
B=8, M=4096, K=32 | | | 18998.2 | 8912.3
21641fa
[------------------------------------------------------ attention -------------------------------------------------------]
| cutlass_16x128x8_siPrecacl2 | cutlass2 | cutlass2_bounds | cutlass2_bounds2 | vanilla
1 threads: ---------------------------------------------------------------------------------------------------------------
B=8, M=128, K=32 | 59.8 | | 40.2 | 34.8 | 66.7
B=8, M=128, K=64 | 66.2 | | | |
B=8, M=128, K=128 | 84.4 | 50.8 | 52.6 | 43.8 | 66.3
B=8, M=128, K=256 | 146.9 | 79.4 | 81.2 | 75.4 | 67.0
B=8, M=1024, K=32 | 1561.3 | | 1266.9 | 949.9 | 520.5
B=8, M=1024, K=64 | 1898.1 | | | |
B=8, M=1024, K=128 | 2910.1 | 1619.8 | 1652.8 | 1403.7 | 737.0
B=8, M=1024, K=256 | 6118.0 | 2702.9 | 2813.2 | 2699.2 | 1330.9
B=8, M=2048, K=32 | 5602.2 | | | |
B=8, M=2048, K=64 | 6910.0 | | | |
B=8, M=2048, K=128 | 10999.5 | 6245.5 | | |
B=8, M=2048, K=256 | 23657.9 | 10823.6 | | |
B=8, M=4096, K=128 | | 23813.0 | 24679.0 | 21202.0 | 12016.8
B=8, M=4096, K=256 | | 42068.8 | 44072.9 | 42345.7 | 19873.5
B=8, M=96, K=31 | | | 36.5 | 34.8 | 65.7
B=8, M=96, K=32 | | | 36.8 | 34.7 | 65.3
B=8, M=96, K=128 | | | 38.4 | 35.0 | 66.1
B=8, M=96, K=256 | | | 59.9 | 53.2 | 66.8
B=8, M=127, K=31 | | | 40.3 | 34.9 | 66.1
B=8, M=127, K=32 | | | 40.2 | 34.3 | 65.6
B=8, M=127, K=128 | | | 52.4 | 43.8 | 66.2
B=8, M=127, K=256 | | | 81.1 | 75.5 | 66.7
B=8, M=128, K=31 | | | 40.4 | 35.1 | 66.7
B=8, M=1024, K=31 | | | 1271.3 | 957.1 | 523.3
B=8, M=4096, K=31 | | | 18935.7 | 14102.2 | 8957.7
B=8, M=4096, K=32 | | | 18865.5 | 14000.8 | 8905.5
8f4728a56d02c629393e60438c0945ca60ef3b78
[------------------------------------------------------------------------- attention --------------------------------------------------------------------------]
| cutlass2_bounds_v100 | cutlass2_bounds_autoreg_v100 | cutlass2_bounds_reg160_v100 | cutlass2_bounds_autoreg2_v100 | vanilla
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------------
B=8, M=96, K=31 | 35.0 | 35.1 | 36.8 | 34.8 | 88.7
B=8, M=96, K=32 | 35.3 | 35.0 | 36.8 | 34.3 | 89.1
B=8, M=96, K=128 | 41.4 | 36.5 | 37.0 | 34.9 | 90.2
B=8, M=96, K=256 | 67.0 | 51.2 | 49.5 | 48.9 | 89.8
B=8, M=127, K=31 | 36.1 | 35.6 | 37.2 | 34.6 | 89.7
B=8, M=127, K=32 | 35.8 | 35.5 | 36.5 | 34.5 | 89.0
B=8, M=127, K=128 | 46.0 | 35.3 | 37.1 | 34.6 | 91.4
B=8, M=127, K=256 | 74.4 | 55.8 | 54.0 | 53.2 | 89.9
B=8, M=128, K=31 | 35.8 | 35.5 | 36.6 | 34.3 | 90.1
B=8, M=128, K=32 | 36.0 | 35.7 | 36.3 | 34.6 | 89.3
B=8, M=128, K=128 | 46.1 | 35.5 | 37.1 | 34.8 | 89.2
B=8, M=128, K=256 | 74.8 | 56.0 | 54.3 | 53.4 | 89.6
B=8, M=1024, K=31 | 1040.0 | 631.4 | 552.7 | 554.9 | 375.9
B=8, M=1024, K=32 | 1031.7 | 634.1 | 550.5 | 553.6 | 370.7
B=8, M=1024, K=128 | 1409.2 | 1006.3 | 911.2 | 897.8 | 517.3
B=8, M=1024, K=256 | 2376.6 | 1799.2 | 1661.1 | 1660.1 | 899.5
B=8, M=4096, K=31 | 13739.2 | 8341.8 | 7315.9 | 7313.5 | 3547.7
B=8, M=4096, K=32 | 13555.8 | 8325.9 | 7287.1 | 7289.8 | 3477.7
B=8, M=4096, K=128 | 18125.9 | 12822.6 | 11345.4 | 11282.0 | 7810.2
B=8, M=4096, K=256 | 30331.0 | 22973.0 | 21350.0 | 21329.5 | 13067.7
[------------------- attention -------------------]
| optimized | vanilla
1 threads: ----------------------------------------
B=7680, M=36, K=32 | 1965.7 | 694.4
B=7680, M=36, K=64 | 2343.0 | 949.1
B=7680, M=48, K=32 | 1987.7 | 879.4
B=7680, M=48, K=64 | 2475.8 | 1219.1
B=7680, M=64, K=32 | 2842.9 | 1299.9
B=7680, M=64, K=64 | 3523.2 | 1904.5
B=7680, M=96, K=32 | 4956.6 | 2292.8
B=7680, M=96, K=64 | 5952.7 | 3430.5
B=7680, M=128, K=32 | 7563.3 | 3858.8
B=7680, M=128, K=64 | 8750.9 | 5048.0
B=7680, M=192, K=32 | 18231.2 | 7653.3
B=7680, M=192, K=64 | 21890.0 | 11069.7
B=768, M=36, K=32 | 199.1 | 95.5
B=768, M=36, K=64 | 241.5 | 126.3
B=768, M=48, K=32 | 209.3 | 109.7
B=768, M=48, K=64 | 259.8 | 156.6
B=768, M=64, K=32 | 297.9 | 152.7
B=768, M=64, K=64 | 364.6 | 216.1
B=768, M=96, K=32 | 507.3 | 255.3
B=768, M=96, K=64 | 606.3 | 367.3
B=768, M=128, K=32 | 755.1 | 402.3
B=768, M=128, K=64 | 890.1 | 543.6
B=768, M=192, K=32 | 1841.1 | 798.7
B=768, M=192, K=64 | 2202.8 | 1132.5
B=1536, M=36, K=32 | 394.5 | 161.0
B=1536, M=36, K=64 | 475.0 | 220.5
B=1536, M=48, K=32 | 407.6 | 197.4
B=1536, M=48, K=64 | 506.1 | 275.6
B=1536, M=64, K=32 | 588.6 | 282.2
B=1536, M=64, K=64 | 717.4 | 404.1
B=1536, M=96, K=32 | 1002.8 | 485.3
B=1536, M=96, K=64 | 1198.1 | 714.5
B=1536, M=128, K=32 | 1508.2 | 778.0
B=1536, M=128, K=64 | 1762.5 | 1079.9
B=1536, M=192, K=32 | 3662.7 | 1557.7
B=1536, M=192, K=64 | 4386.4 | 2239.0
with kNumWarpsPerBlock=2
[------------------- attention -------------------]
| optimized | vanilla
1 threads: ----------------------------------------
B=7680, M=36, K=64 | 1842.3 | 950.1
B=768, M=192, K=64 | 2040.4 | 1133.6
B=1536, M=192, K=64 | 4151.8 | 2241.4
B=5376, M=36, K=64 | 1288.5 | 675.9
B=544, M=192, K=64 | 1434.0 | 815.1
B=512, M=192, K=64 | 1348.8 | 767.2
B=1088, M=192, K=64 | 2913.7 | 1600.2
B=1024, M=192, K=64 | 2740.3 | 1500.1
B=768, M=154, K=64 | 1606.7 | 941.4
B=1536, M=154, K=64 | 3247.5 | 1838.4
B=544, M=154, K=64 | 1134.2 | 676.1
B=512, M=154, K=64 | 1067.8 | 636.9
B=1088, M=154, K=64 | 2288.8 | 1312.3
B=1024, M=154, K=64 | 2153.4 | 1240.5
[----------------- attention ------------------]
| optimized | vanilla
1 threads: -------------------------------------
B=8, M=36, K=64 | 37.5 | 95.5
B=8, M=192, K=64 | 61.3 | 128.7
B=8, M=154, K=64 | 58.9 | 127.8
matmul1 can be improved:
[----------------------------------------------------------------- attention ------------------------------------------------------------------]
| cutlass2_bounds_v100 | cutlass2_bounds_autoreg_v100 | cutlass2_bounds_reg160_v100 | c2_nomm1_v100 | vanilla
1 threads: -------------------------------------------------------------------------------------------------------------------------------------
B=8, M=96, K=31 | 35.0 | 35.1 | 36.8 | |
B=8, M=96, K=32 | 35.3 | 35.0 | 36.8 | |
B=8, M=96, K=128 | 41.4 | 36.5 | 37.0 | |
B=8, M=96, K=256 | 67.0 | 51.2 | 49.5 | |
B=8, M=127, K=31 | 36.1 | 35.6 | 37.2 | |
B=8, M=127, K=32 | 35.8 | 35.5 | 36.5 | |
B=8, M=127, K=128 | 46.0 | 35.3 | 37.1 | |
B=8, M=127, K=256 | 74.4 | 55.8 | 54.0 | |
B=8, M=128, K=31 | 35.8 | 35.5 | 36.6 | |
B=8, M=128, K=32 | 36.0 | 35.7 | 36.3 | 34.2 | 85.0
B=8, M=128, K=128 | 46.1 | 35.5 | 37.1 | 34.5 | 85.2
B=8, M=128, K=256 | 74.8 | 56.0 | 54.3 | 38.1 | 85.5
B=8, M=1024, K=31 | 1040.0 | 631.4 | 552.7 | |
B=8, M=1024, K=32 | 1031.7 | 634.1 | 550.5 | 395.5 | 370.5
B=8, M=1024, K=128 | 1409.2 | 1006.3 | 911.2 | 474.3 | 517.2
B=8, M=1024, K=256 | 2376.6 | 1799.2 | 1661.1 | 879.1 | 899.7
B=8, M=4096, K=31 | 13739.2 | 8341.8 | 7315.9 | |
B=8, M=4096, K=32 | 13555.8 | 8325.9 | 7287.1 | 5037.9 | 3459.8
B=8, M=4096, K=128 | 18125.9 | 12822.6 | 11345.4 | 5503.3 | 7797.6
B=8, M=4096, K=256 | 30331.0 | 22973.0 | 21350.0 | 10545.0 | 13076.7
base = before optimized = b08a5c8
[---------------------------------- attention -----------------------------------]
| base | nomm1 | nomm1mi | optimized | vanilla
1 threads: -----------------------------------------------------------------------
B=7680, M=36, K=64 | 4.9 | 2.5 | 4.6 | 4.4 | 2.2
B=768, M=192, K=64 | 4.4 | 2.5 | 4.1 | 4.0 | 2.1
B=1536, M=192, K=64 | 8.7 | 5.0 | 8.1 | 8.0 | 4.1
B=5376, M=36, K=64 | 3.5 | 1.8 | 3.2 | 3.1 | 1.6
new = 9f91e09
new_w2 = new with kNumWarpsPerBlock=2
resetcutlass = without my mods to cutlass to load shared mem
[----------------------------------------------------- attention ------------------------------------------------------]
| base | new | new_w2 | new_nomm1 | new_nomm2_resetcutlass | new_nomm2 | vanilla
1 threads: -------------------------------------------------------------------------------------------------------------
B=7680, M=36, K=64 | 4.6 | 4.1 | 2.9 | 2.0 | 2.2 | 2.4 | 2.2
B=768, M=192, K=64 | 4.2 | 3.9 | 3.2 | 2.2 | 1.6 | 1.7 | 2.1
B=1536, M=192, K=64 | 8.2 | 7.7 | 6.4 | 4.4 | 3.3 | 3.5 | 4.1
B=5376, M=36, K=64 | 3.1 | 2.9 | 2.1 | 1.4 | 1.5 | 1.7 | 1.6
Potential for improvement by handling si better:
[------------------------------------------------- attention -------------------------------------------------]
| base | no_si_ld | no_si_ld_nomods | no_si_ld_nomods_noldoutput | vanilla
1 threads: ----------------------------------------------------------------------------------------------------
B=7680, M=36, K=64 | 4.1 | 4.2 | 3.6 | 3.4 | 2.2
B=768, M=784, K=64 | 54.7 | 56.5 | 50.0 | 46.7 | 31.7
B=128, M=2048, K=128 | 77.9 | 79.9 | 70.6 | 66.1 | 46.3
cutlass batched gemm example (5):
using LayoutA = cutlass::layout::RowMajor;     // A: m x k, row-major
using LayoutB = cutlass::layout::ColumnMajor;  // B: k x n, column-major
using LayoutC = cutlass::layout::RowMajor;     // C: m x n, row-major
int const m = 2048;
int const n = 2048;
int const k = 128;
int const batch_count = 128;
Built target 05_batched_gemm
Running array gemm
cutlass_array_sgemm --> 13758[µs]
Passed.
Running strided batched gemm
cutlass_strided_batched_sgemm --> 13571[µs]
Passed.
Running array gemm
cutlass_array_sgemm --> 13027[µs]
Passed.
Running strided batched gemm
cutlass_strided_batched_sgemm --> 12887[µs]
Passed.
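Per batch, both example-05 kernels compute C = alpha·A·B + beta·C with the layouts above. A CPU sketch of those semantics (the packed batch strides here are a simplifying assumption; the CUTLASS example picks its own):

```cpp
// CPU reference for a strided batched SGEMM: for each batch b,
// C[b] = alpha * A[b] * B[b] + beta * C[b], where A is row-major (m x k),
// B is column-major (k x n) and C is row-major (m x n), matching the
// layouts in CUTLASS example 05. Batch strides are passed in explicitly.
void strided_batched_sgemm_ref(int m, int n, int k, int batch_count,
                               float alpha,
                               const float* A, long long strideA,
                               const float* B, long long strideB,
                               float beta,
                               float* C, long long strideC) {
    for (int b = 0; b < batch_count; ++b) {
        const float* Ab = A + b * strideA;
        const float* Bb = B + b * strideB;
        float* Cb = C + b * strideC;
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float acc = 0.f;
                for (int p = 0; p < k; ++p)
                    // A(i,p) at i*k+p (row-major); B(p,j) at j*k+p (column-major)
                    acc += Ab[i * k + p] * Bb[j * k + p];
                Cb[i * n + j] = alpha * acc + beta * Cb[i * n + j];
            }
        }
    }
}
```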
after: 56c3ef3
[-------------------------- attention --------------------------]
| base_v100 | cmp_v100 | vanilla
1 threads: ------------------------------------------------------
B=1, M=20480, K=128 | 32035.9 | 31483.8 | 23315.1
B=7680, M=36, K=64 | 1697.0 | 1664.1 | 887.4
B=768, M=784, K=64 | 25137.3 | 24666.7 | 17861.5
B=128, M=2048, K=128 | 38158.1 | 37305.4 | 27872.0
tune q (kQueriesPerBlock)
[---------------------------------- attention ----------------------------------]
| q32_v100 | q16_v100 | q32_w2_v100 | vanilla
1 threads: ----------------------------------------------------------------------
B=64, M=128, K=32 | 61.5 | 55.0 | 59.7 | 93.0
B=64, M=128, K=64 | 74.8 | 67.1 | 74.4 | 95.1
B=64, M=128, K=128 | 103.8 | 97.5 | 126.5 | 94.8
B=64, M=128, K=256 | 199.7 | 193.2 | 255.2 | 143.0
B=64, M=1024, K=32 | 2624.5 | 2675.6 | 2859.2 | 1808.2
B=64, M=1024, K=64 | 3150.5 | 3351.8 | 3541.1 | 2236.5
B=64, M=1024, K=128 | 4278.5 | 4851.4 | 6130.6 | 3453.7
B=64, M=1024, K=256 | 8177.7 | 9618.6 | 11313.3 | 6076.0
B=64, M=4096, K=32 | 40869.8 | 41800.3 | 42976.4 | 29175.0
B=64, M=4096, K=64 | 48617.7 | 52187.0 | 52658.0 | 35570.7
B=64, M=4096, K=128 | 66354.5 | 73982.8 | 93704.2 | 54904.6
B=64, M=4096, K=256 | 128790.8 | 147621.5 | 178471.5 | 96703.2
B=256, M=128, K=16 | 175.5 | 164.1 | 192.6 | 120.8
B=256, M=128, K=32 | 195.8 | 183.7 | 216.6 | 142.8
B=256, M=128, K=64 | 246.0 | 235.6 | 291.1 | 198.6
B=256, M=128, K=128 | 343.9 | 345.5 | 502.9 | 326.5
B=256, M=1024, K=16 | 9468.1 | 9294.3 | 9837.2 | 6995.3
B=256, M=1024, K=32 | 10337.5 | 10553.5 | 10889.0 | 7663.4
B=256, M=1024, K=64 | 12403.2 | 13246.0 | 13361.2 | 8855.1
B=256, M=1024, K=128 | 16954.5 | 18884.8 | 23965.6 | 14248.1
B=1024, M=128, K=16 | 669.7 | 632.9 | 676.9 | 438.6
B=1024, M=128, K=32 | 732.8 | 714.1 | 748.6 | 518.9
B=1024, M=128, K=64 | 908.9 | 911.4 | 966.5 | 676.6
B=1024, M=128, K=128 | 1282.2 | 1349.3 | 1744.1 | 1106.4
B=1024, M=1024, K=16 | 37649.0 | 36837.9 | 39175.8 | 27497.3
B=1024, M=1024, K=32 | 41135.6 | 41764.0 | 43614.5 | 29859.8
B=1024, M=1024, K=64 | 49194.7 | 52581.9 | 53170.3 | 36187.4
B=1024, M=1024, K=128 | 67194.2 | 74965.7 | 97479.2 | 55582.8
f16v0 -> everything in f16 (gives wrong results and NaNs/Infs)
f16v1_accf32 -> input/output in f16, everything else in f32 (no NaN, close to what we expected)
f16v2_tcmm1 -> use tensor cores for the first matmul (f32 = f16*f16 + f32). Had to change kQueriesPerBlock=64 (instead of 32) and kNumWarpsPerBlock=2 (instead of 4) to use less shared memory
[-------------------------------------------- attention (use_attn_bias=False) --------------------------------------------]
| f16v2_q64w2 | f16v0 | vanilla | f16v1_accf32 | f16v2_tcmm1
1 threads: ----------------------------------------------------------------------------------------------------------------
(Quadro_GP100) f32 B=32, M=128, K=16 | | 73.3 | 111.9 | 73.2 |
f16 B=32, M=128, K=16 | | 188.3 | 74.3 | 57.0 |
f32 B=32, M=128, K=32 | | 78.7 | 74.0 | 78.6 |
f16 B=32, M=128, K=32 | | 191.1 | 74.5 | 61.1 |
f32 B=32, M=128, K=128 | | 136.6 | 83.5 | 136.7 |
f16 B=32, M=128, K=128 | | 215.1 | 75.4 | 98.0 |
f32 B=32, M=512, K=16 | | 833.3 | 504.0 | 833.8 |
f16 B=32, M=512, K=16 | | 2920.8 | 438.2 | 583.2 |
f32 B=32, M=512, K=32 | | 907.5 | 530.5 | 908.3 |
f16 B=32, M=512, K=32 | | 2892.3 | 468.1 | 630.6 |
f32 B=32, M=512, K=128 | | 1394.1 | 759.2 | 1392.6 |
f16 B=32, M=512, K=128 | | 3150.9 | 692.3 | 1006.0 |
f32 B=32, M=1024, K=16 | | 3134.4 | 1828.0 | 3132.4 |
f16 B=32, M=1024, K=16 | | 10324.8 | 1638.7 | 2210.9 |
f32 B=32, M=1024, K=32 | | 3403.3 | 1947.8 | 3402.6 |
f16 B=32, M=1024, K=32 | | 10157.6 | 1772.1 | 2457.0 |
f32 B=32, M=1024, K=128 | | 5134.3 | 2832.5 | 5127.5 |
f16 B=32, M=1024, K=128 | | 11435.0 | 2613.7 | 3900.7 |
f32 B=256, M=128, K=16 | | 450.7 | 261.8 | 450.2 |
f16 B=256, M=128, K=16 | | 1344.8 | 226.3 | 307.1 |
f32 B=256, M=128, K=32 | | 499.9 | 290.7 | 499.2 |
f16 B=256, M=128, K=32 | | 1357.5 | 248.4 | 334.9 |
f32 B=256, M=128, K=128 | | 865.4 | 457.6 | 864.6 |
f16 B=256, M=128, K=128 | | 1515.9 | 387.1 | 574.0 |
f32 B=256, M=512, K=16 | | 6292.5 | 3508.8 | 6281.7 |
f16 B=256, M=512, K=16 | | 18977.6 | 3050.2 | 4291.6 |
f32 B=256, M=512, K=32 | | 6822.5 | 3783.0 | 6818.3 |
f16 B=256, M=512, K=32 | | 18976.3 | 3324.1 | 4746.6 |
f32 B=256, M=512, K=128 | | 10467.9 | 5656.4 | 10456.6 |
f16 B=256, M=512, K=128 | | 21502.2 | 5028.3 | 7758.5 |
f32 B=256, M=1024, K=16 | | 24512.5 | 13606.6 | 24488.6 |
f16 B=256, M=1024, K=16 | | 74444.1 | 12081.0 | 16968.4 |
f32 B=256, M=1024, K=32 | | 26569.0 | 14697.7 | 26532.2 |
f16 B=256, M=1024, K=32 | | 75113.9 | 13117.4 | 18574.7 |
f32 B=256, M=1024, K=128 | | 40343.8 | 21603.1 | 39809.1 |
f16 B=256, M=1024, K=128 | | 85245.3 | 19681.6 | 30244.1 |
(Tesla_V100_SXM2_16GB) f16 B=32, M=128, K=16 | 84.8 | 116.9 | 101.5 | 46.5 | 81.7
f16 B=32, M=128, K=32 | 91.8 | 117.2 | 105.6 | 47.5 | 84.0
f16 B=32, M=128, K=128 | 174.5 | 135.1 | 113.9 | 54.3 | 145.0
f16 B=32, M=512, K=16 | 648.0 | 1631.7 | 115.8 | 346.4 | 658.7
f16 B=32, M=512, K=32 | 715.4 | 1535.2 | 110.3 | 379.4 | 677.2
f16 B=32, M=512, K=128 | 1439.5 | 1943.2 | 153.9 | 589.4 | 1140.9
f16 B=32, M=1024, K=16 | 2157.9 | 5831.8 | 460.4 | 1220.9 | 2396.5
f16 B=32, M=1024, K=32 | 2427.7 | 5860.0 | 476.3 | 1329.1 | 2468.5
f16 B=32, M=1024, K=128 | 5023.6 | 6456.9 | 565.7 | 2080.1 | 4142.7
f16 B=256, M=128, K=16 | 283.4 | 793.0 | 101.9 | 154.7 | 331.1
f16 B=256, M=128, K=32 | 310.8 | 795.6 | 99.8 | 170.9 | 341.0
f16 B=256, M=128, K=128 | 671.5 | 912.9 | 133.9 | 289.3 | 637.2
f16 B=256, M=512, K=16 | 3633.4 | 10531.5 | 789.4 | 2208.8 | 4289.5
f16 B=256, M=512, K=32 | 4128.1 | 10622.5 | 830.5 | 2401.2 | 4377.2
f16 B=256, M=512, K=128 | 8808.2 | 11734.8 | 1142.8 | 3903.4 | 7505.2
f16 B=256, M=1024, K=16 | 14273.9 | 41560.8 | 3339.2 | 8536.4 | 16685.1
f16 B=256, M=1024, K=32 | 16185.3 | 41815.5 | 3445.3 | 9294.8 | 16968.8
f16 B=256, M=1024, K=128 | 34472.9 | 45577.9 | 4216.7 | 15252.3 | 28662.0
f32 B=32, M=128, K=16 | | 51.0 | 128.3 | 47.1 |
f32 B=32, M=128, K=32 | | 44.8 | 99.0 | 47.0 |
f32 B=32, M=128, K=128 | | 55.0 | 107.9 | 54.8 |
f32 B=32, M=512, K=16 | | 369.9 | 238.7 | 369.1 |
f32 B=32, M=512, K=32 | | 387.1 | 252.6 | 387.9 |
f32 B=32, M=512, K=128 | | 673.1 | 485.9 | 672.7 |
f32 B=32, M=1024, K=16 | | 1285.4 | 773.7 | 1285.8 |
f32 B=32, M=1024, K=32 | | 1403.7 | 827.4 | 1405.3 |
f32 B=32, M=1024, K=128 | | 2279.4 | 1778.5 | 2282.1 |
f32 B=256, M=128, K=16 | | 176.3 | 118.7 | 176.4 |
f32 B=256, M=128, K=32 | | 197.0 | 142.7 | 197.7 |
f32 B=256, M=128, K=128 | | 345.6 | 334.1 | 344.7 |
f32 B=256, M=512, K=16 | | 2428.0 | 1760.7 | 2429.0 |
f32 B=256, M=512, K=32 | | 2664.4 | 1914.1 | 2664.3 |
f32 B=256, M=512, K=128 | | 4425.9 | 3584.2 | 4423.0 |
f32 B=256, M=1024, K=16 | | 9490.1 | 7022.7 | 9494.1 |
f32 B=256, M=1024, K=32 | | 10378.7 | 7652.1 | 10378.5 |
f32 B=256, M=1024, K=128 | | 17011.5 | 14334.8 | 16975.2 |
Times are in microseconds (us).
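The wrong results and NaNs/Infs from f16v0 follow from the accumulator: f16 runs out of integer precision at 2^11 = 2048 and out of range at 65504, so long reduction rows saturate or overflow, which is why accumulating in f32 (f16v1_accf32) fixes it. The same effect can be demonstrated on CPU with f32 vs f64 accumulators, where the precision cliff sits at 2^24 instead:

```cpp
// A float accumulator stalls once it reaches 2^24 = 16777216, because
// 16777216.0f + 1.0f rounds back to 16777216.0f. The analogous cliff for
// an f16 accumulator is at 2^11 = 2048, far below typical reduction lengths,
// hence the need to accumulate the f16 matmuls in f32.
float sum_ones_f32(int n) {
    float acc = 0.f;
    for (int i = 0; i < n; ++i) acc += 1.0f;
    return acc;
}

double sum_ones_f64(int n) {
    double acc = 0.0;
    for (int i = 0; i < n; ++i) acc += 1.0;
    return acc;
}
```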
On A100 (f16v1_accf32):
[----------- attention (use_attn_bias=False) ------------]
| f16v1_accf32 | vanilla
1 threads: -----------------------------------------------
f16 B=32, M=128, K=16 | 32.6 | 82.6
f32 B=32, M=128, K=16 | 32.3 | 81.9
f16 B=32, M=128, K=32 | 34.7 | 67.1
f32 B=32, M=128, K=32 | 34.3 | 81.4
f16 B=32, M=128, K=128 | 50.2 | 67.2
f32 B=32, M=128, K=128 | 49.1 | 80.6
f16 B=32, M=512, K=16 | 235.4 | 70.1
f32 B=32, M=512, K=16 | 230.5 | 126.1
f16 B=32, M=512, K=32 | 254.5 | 71.3
f32 B=32, M=512, K=32 | 249.6 | 132.1
f16 B=32, M=512, K=128 | 380.9 | 88.7
f32 B=32, M=512, K=128 | 368.2 | 166.5
f16 B=32, M=1024, K=16 | 877.4 | 268.8
f32 B=32, M=1024, K=16 | 857.1 | 441.7
f16 B=32, M=1024, K=32 | 948.9 | 269.8
f32 B=32, M=1024, K=32 | 931.3 | 453.4
f16 B=32, M=1024, K=128 | 1428.5 | 311.4
f32 B=32, M=1024, K=128 | 1393.2 | 538.0
f16 B=256, M=128, K=16 | 122.6 | 67.6
f32 B=256, M=128, K=16 | 119.2 | 81.6
f16 B=256, M=128, K=32 | 131.8 | 68.4
f32 B=256, M=128, K=32 | 128.9 | 82.6
f16 B=256, M=128, K=128 | 203.3 | 67.7
f32 B=256, M=128, K=128 | 201.2 | 158.8
f16 B=256, M=512, K=16 | 1628.9 | 475.4
f32 B=256, M=512, K=16 | 1586.2 | 857.5
f16 B=256, M=512, K=32 | 1752.5 | 494.7
f32 B=256, M=512, K=32 | 1725.1 | 897.4
f16 B=256, M=512, K=128 | 2700.5 | 614.1
f32 B=256, M=512, K=128 | 2637.0 | 1179.4
f16 B=256, M=1024, K=16 | 6410.0 | 1869.8
f32 B=256, M=1024, K=16 | 6253.6 | 3200.0
f16 B=256, M=1024, K=32 | 6941.5 | 1904.8
f32 B=256, M=1024, K=32 | 6811.3 | 3302.3
f16 B=256, M=1024, K=128 | 10521.9 | 2230.6
f32 B=256, M=1024, K=128 | 10273.3 | 4033.5
just the matmul: