nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

[matmul] Increase L1 cache bandwidth #811

Open kuhar opened 2 months ago

kuhar commented 2 months ago

The L1 data cache on MI300 is divided into 4 sets of up to 64 entries each, with a cache line size of 128 B. To utilize the full bandwidth, the kernel has to engage all four sets. However, with large matrix operands aligned to a 512 B boundary, reads along the parallel dimensions keep hitting the same cache set, resulting in hot-spotting.

To combat this, we can either pad the underlying allocation so that each matrix operand row starts at a different cache set, or opt for a different address-to-set mapping via buffer_load instructions. The former is a graph-level optimization; the latter is intra-kernel.
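To make the hot-spotting concrete, here is a minimal sketch (plain Python, not IREE code, and assuming a simple round-robin line-to-set mapping rather than MI300's exact scheme) of how a row pitch that is a multiple of 512 B keeps landing in one set, while padding the pitch by one cache line cycles through all four:

```python
# Minimal sketch of L1 set hot-spotting; the round-robin address-to-set
# mapping below is an assumption for illustration, not MI300's exact scheme.
CACHE_LINE_B = 128
NUM_SETS = 4

def cache_set(byte_addr: int) -> int:
    # Consecutive 128 B lines round-robin across the 4 sets.
    return (byte_addr // CACHE_LINE_B) % NUM_SETS

def sets_hit_down_a_column(row_pitch_b: int, num_rows: int = 8) -> set[int]:
    # Sets touched when reading the same column offset from consecutive rows.
    return {cache_set(r * row_pitch_b) for r in range(num_rows)}

# K = 4096 f16 -> 8192 B pitch, a multiple of 512 B: every row lands in set 0.
print(sets_hit_down_a_column(4096 * 2))        # {0}
# Padding the pitch by one 128 B cache line engages all four sets.
print(sets_hit_down_a_column(4096 * 2 + 128))  # {0, 1, 2, 3}
```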

Below are the results of applying tensor padding and buffer padding by hand for mmt_4864x4096x4096_f16xf16xf32:

| Padding [elems] | 0 | 32 | 64 | 128 | 192 |
| --- | --- | --- | --- | --- | --- |
| Tile sizes [M, N, K] | [64, 128, 64] | [64, 128, 32] | [64, 128, 64] | [64, 128, 64] | [64, 128, 64] |
| Time [us] | 2180 | 2962 | 2080 | 2126 | 2132 |

| Buffer Padding [B] | 0 | 64 | 128 | 256 | 384 |
| --- | --- | --- | --- | --- | --- |
| Tile sizes [M, N, K] | [64, 128, 64] | [64, 128, 64] | [64, 128, 64] | [64, 128, 64] | [64, 128, 64] |
| Time [us] | 2178 | 2705 | 2040 | 2056 | 2030 |
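As a unit-conversion aside (my note, not from the original comment): f16 elements are 2 B each, so the element paddings in the first table sweep the same byte offsets as the buffer paddings in the second:

```python
# f16 is 2 bytes per element, so the tensor-padding sweep (in elements)
# and the buffer-padding sweep (in bytes) cover the same offsets.
for elems in (0, 32, 64, 128, 192):
    print(f"{elems} f16 elems = {elems * 2} B")  # 0, 64, 128, 256, 384 B
```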

Padding with a single cache line improves performance by 6.5% when the tensor shape remains the same, and by 4.5% if we pad the tensors and perform the extra computation. Based on these results, we would like to implement the former in the IREE compiler.

kuhar commented 2 months ago

cc: @MaheshRavishankar @qedawkins @antiagainst

kuhar commented 1 month ago

I implemented a crude prototype by adding padding just before dispatch region formation. This results in excessive slow_memcpy dispatches but generates the code I'd expect at the kernel level.
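As a rough illustration of what the graph-level rewrite does (my own sketch under the assumptions above; the helper name is hypothetical, not an IREE API):

```python
# Grow the innermost (contiguous) dimension of an f16 operand by one
# 128 B cache line so its row pitch is no longer a multiple of 512 B.
ELEM_BYTES_F16 = 2
CACHE_LINE_B = 128

def pad_inner_dim_by_one_line(inner_dim: int) -> int:
    return inner_dim + CACHE_LINE_B // ELEM_BYTES_F16  # +64 f16 elements

# E.g. a 2048x1280 f16 operand becomes 2048x1344; presumably the padded
# region is zero-filled, which is where the extra slow_memcpy dispatches
# and the wasted K iterations come from.
print(pad_inner_dim_by_one_line(1280))  # 1344
```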

I evaluated this against SDXL fp16 on MI210:

| shape | no padding [us] | 128B padding [us] | Delta [%] |
| --- | --- | --- | --- |
| 2048x10240x1280_f16xf16xf32 | 1172.912717 | 1169.9382 | -0.25 |
| 1024x64x1280_f16xf16xf32 | 117.32075 | 116.3258958 | -0.85 |
| 2048x1280x5120_f16xf16xf32 | 325.3763833 | 303.8777 | -6.61 |
| 2048x1280x1280_f16xf16xf32 | 95.69621138 | 94.84761789 | -0.89 |
| 8192x5120x640_f16xf16xf32 | 581.7872 | 579.3211 | -0.42 |
| 4096x64x640_f16xf16xf32 | 118.28935 | 118.10915 | -0.15 |
| 64x64x2048_f16xf16xf32 | 34.70054167 | 35.67100833 | 2.8 |
| 8192x640x2560_f16xf16xf32 | 278.7596 | 271.9727 | -2.43 |
| 8192x640x640_f16xf16xf32 | 94.70981818 | 95.08172727 | 0.39 |
| 64x64x2048_f16xf16xf32 | 34.8143 | 35.083 | 0.77 |
| 2048x1280x1280_f16xf16xf32 | 92.5295 | 92.529 | 0 |
| 8192x640x640_f16xf16xf32 | 87.0848 | 86.8164 | -0.31 |
| 8192x640x640_f16xf16xf32 | 102.173 | 101.6845 | -0.48 |
| 2048x1280x1280_f16xf16xf32 | 101.4405 | 99.8535 | -1.56 |
| 2x1280x1280_f16xf16xf32 | 19.577 | 19.683625 | 0.54 |
| 2048x1280x1280_f16xf16xf32 | 102.173 | 98.389 | -3.7 |
| 8192x640x640_f16xf16xf32 | 94.605 | 94.726 | 0.13 |
| 2x640x1280_f16xf16xf32 | 16.3084 | 17.4316 | 6.89 |
| 2x320x1280_f16xf16xf32 | 14.3066 | 14.3556 | 0.34 |
| 2x1280x2816_f16xf16xf32 | 69.092 | 68.604 | -0.71 |
| 2x1280x1280_f16xf16xf32 | 21.606 | 21.973 | 1.7 |
| 2x1280x320_f16xf16xf32 | 19.531 | 20.019 | 2.5 |

The maximum improvement is ~6.6%, which is in line with the isolated benchmark results. However, there are some regressions.

kuhar commented 1 month ago

The biggest regressions are with the skinny matmuls, but this is expected and can be filtered out with a smarter heuristic.

kuhar commented 1 month ago

Here are results for other padding amounts.

64 B (half the cache line)

| shape | no padding [us] | 64B padding [us] | Delta [%] |
| --- | --- | --- | --- |
| 2048x10240x1280_f16xf16xf32 | 1172.912717 | 1172.7132 | -0.02 |
| 1024x64x1280_f16xf16xf32 | 117.32075 | 116.7216625 | -0.51 |
| 2048x1280x5120_f16xf16xf32 | 325.3763833 | 303.4423 | -6.74 |
| 2048x1280x1280_f16xf16xf32 | 95.69621138 | 95.02428455 | -0.7 |
| 8192x5120x640_f16xf16xf32 | 581.7872 | 581.2744 | -0.09 |
| 4096x64x640_f16xf16xf32 | 118.28935 | 117.425425 | -0.73 |
| 64x64x2048_f16xf16xf32 | 34.70054167 | 36.38305 | 4.85 |
| 8192x640x2560_f16xf16xf32 | 278.7596 | 272.8028 | -2.14 |
| 8192x640x640_f16xf16xf32 | 94.70981818 | 94.47140909 | -0.25 |
| 64x64x2048_f16xf16xf32 | 34.8143 | 35.66895 | 2.45 |
| 2048x1280x1280_f16xf16xf32 | 92.5295 | 91.75633333 | -0.84 |
| 8192x640x640_f16xf16xf32 | 87.0848 | 86.2304 | -0.98 |
| 8192x640x640_f16xf16xf32 | 102.173 | 103.027 | 0.84 |
| 2048x1280x1280_f16xf16xf32 | 101.4405 | 102.173 | 0.72 |
| 2x1280x1280_f16xf16xf32 | 19.577 | 20.294375 | 3.66 |
| 2048x1280x1280_f16xf16xf32 | 102.173 | 98.632 | -3.47 |
| 8192x640x640_f16xf16xf32 | 94.605 | 93.994 | -0.65 |
| 2x640x1280_f16xf16xf32 | 16.3084 | 16.4548 | 0.9 |
| 2x320x1280_f16xf16xf32 | 14.3066 | 14.2574 | -0.34 |
| 2x1280x2816_f16xf16xf32 | 69.092 | 69.092 | 0 |
| 2x1280x1280_f16xf16xf32 | 21.606 | 21.484 | -0.56 |
| 2x1280x320_f16xf16xf32 | 19.531 | 19.775 | 1.25 |

256 B (two cache lines)

| shape | no padding [us] | 256B padding [us] | Delta [%] |
| --- | --- | --- | --- |
| 2048x10240x1280_f16xf16xf32 | 1172.912717 | 1169.146733 | -0.32 |
| 1024x64x1280_f16xf16xf32 | 117.32075 | 116.0766792 | -1.06 |
| 2048x1280x5120_f16xf16xf32 | 325.3763833 | 305.6316333 | -6.07 |
| 2048x1280x1280_f16xf16xf32 | 95.69621138 | 95.11960976 | -0.6 |
| 8192x5120x640_f16xf16xf32 | 581.7872 | 578.2715 | -0.6 |
| 4096x64x640_f16xf16xf32 | 118.28935 | 117.331025 | -0.81 |
| 64x64x2048_f16xf16xf32 | 34.70054167 | 36.2243 | 4.39 |
| 8192x640x2560_f16xf16xf32 | 278.7596 | 273.6083 | -1.85 |
| 8192x640x640_f16xf16xf32 | 94.70981818 | 93.97736364 | -0.77 |
| 64x64x2048_f16xf16xf32 | 34.8143 | 35.00355 | 0.54 |
| 2048x1280x1280_f16xf16xf32 | 92.5295 | 91.7155 | -0.88 |
| 8192x640x640_f16xf16xf32 | 87.0848 | 86.548 | -0.62 |
| 8192x640x640_f16xf16xf32 | 102.173 | 102.6615 | 0.48 |
| 2048x1280x1280_f16xf16xf32 | 101.4405 | 100.403 | -1.02 |
| 2x1280x1280_f16xf16xf32 | 19.577 | 19.760125 | 0.94 |
| 2048x1280x1280_f16xf16xf32 | 102.173 | 98.877 | -3.23 |
| 8192x640x640_f16xf16xf32 | 94.605 | 94.971 | 0.39 |
| 2x640x1280_f16xf16xf32 | 16.3084 | 16.4306 | 0.75 |
| 2x320x1280_f16xf16xf32 | 14.3066 | 14.2334 | -0.51 |
| 2x1280x2816_f16xf16xf32 | 69.092 | 68.725 | -0.53 |
| 2x1280x1280_f16xf16xf32 | 21.606 | 21.973 | 1.7 |
| 2x1280x320_f16xf16xf32 | 19.531 | 19.531 | 0 |

kuhar commented 1 month ago

I modified the heuristic to apply padding only where it makes sense based on the tensor types (i.e., not too small) and added a column that shows which matmuls were actually padded. This should make it clearer what the difference is compared to no padding (a sketch of such a size filter follows the table):

| shape | M | N | K | no padding [us] | 128B padding [us] | Delta [%] | Padding applied? |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 2048x10240x1280_f16xf16xf32 | 2048 | 10240 | 1280 | 1172.912717 | 1171.736767 | -0.1 | TRUE |
| 1024x64x1280_f16xf16xf32 | 1024 | 64 | 1280 | 117.32075 | 116.4471917 | -0.74 | FALSE |
| 2048x1280x5120_f16xf16xf32 | 2048 | 1280 | 5120 | 325.3763833 | 305.9314667 | -5.98 | TRUE |
| 2048x1280x1280_f16xf16xf32 | 2048 | 1280 | 1280 | 95.69621138 | 94.98953659 | -0.74 | TRUE |
| 8192x5120x640_f16xf16xf32 | 8192 | 5120 | 640 | 581.7872 | 579.0404 | -0.47 | FALSE |
| 4096x64x640_f16xf16xf32 | 4096 | 64 | 640 | 118.28935 | 118.357825 | 0.06 | FALSE |
| 64x64x2048_f16xf16xf32 | 64 | 64 | 2048 | 34.70054167 | 34.84596667 | 0.42 | FALSE |
| 8192x640x2560_f16xf16xf32 | 8192 | 640 | 2560 | 278.7596 | 271.8934 | -2.46 | TRUE |
| 8192x640x640_f16xf16xf32 | 8192 | 640 | 640 | 94.70981818 | 94.99554545 | 0.3 | FALSE |
| 64x64x2048_f16xf16xf32 | 64 | 64 | 2048 | 34.8143 | 35.06775 | 0.73 | FALSE |
| 2048x1280x1280_f16xf16xf32 | 2048 | 1280 | 1280 | 92.5295 | 91.71533333 | -0.88 | TRUE |
| 8192x640x640_f16xf16xf32 | 8192 | 640 | 640 | 87.0848 | 86.6454 | -0.5 | FALSE |
| 8192x640x640_f16xf16xf32 | 8192 | 640 | 640 | 102.173 | 101.379 | -0.78 | FALSE |
| 2048x1280x1280_f16xf16xf32 | 2048 | 1280 | 1280 | 101.4405 | 100.6165 | -0.81 | TRUE |
| 2x1280x1280_f16xf16xf32 | 2 | 1280 | 1280 | 19.577 | 19.5235 | -0.27 | FALSE |
| 2048x1280x1280_f16xf16xf32 | 2048 | 1280 | 1280 | 102.173 | 94.605 | -7.41 | TRUE |
| 8192x640x640_f16xf16xf32 | 8192 | 640 | 640 | 94.605 | 93.262 | -1.42 | FALSE |
| 2x640x1280_f16xf16xf32 | 2 | 640 | 1280 | 16.3084 | 16.3452 | 0.23 | FALSE |
| 2x320x1280_f16xf16xf32 | 2 | 320 | 1280 | 14.3066 | 14.2332 | -0.51 | FALSE |
| 2x1280x2816_f16xf16xf32 | 2 | 1280 | 2816 | 69.092 | 69.458 | 0.53 | FALSE |
| 2x1280x1280_f16xf16xf32 | 2 | 1280 | 1280 | 21.606 | 21.424 | -0.84 | FALSE |
| 2x1280x320_f16xf16xf32 | 2 | 1280 | 320 | 19.531 | 19.714 | 0.94 | FALSE |
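For illustration, a minimal size filter consistent with the "Padding applied?" column above could look like the sketch below (the minimum-dimension threshold is a hypothetical value, not necessarily what the prototype uses):

```python
# Pad only when the row pitch aliases the 512 B set stride (4 sets x 128 B)
# and the matmul is not skinny; the 128 threshold is a guess for illustration.
SET_STRIDE_B = 512
ELEM_BYTES_F16 = 2

def should_pad(m: int, n: int, k: int, min_dim: int = 128) -> bool:
    pitch_aliases_sets = (k * ELEM_BYTES_F16) % SET_STRIDE_B == 0
    is_skinny = min(m, n) < min_dim
    return pitch_aliases_sets and not is_skinny

print(should_pad(2048, 1280, 5120))  # True: padded above, -5.98%
print(should_pad(2, 1280, 1280))     # False: skinny matmul, skipped
print(should_pad(8192, 5120, 640))   # False: pitch not a multiple of 512 B
```
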
kuhar commented 1 month ago

The WIP implementation used to collect these numbers is on my fork: https://github.com/kuhar/iree/tree/pad-to-alloc (it also requires a small hack in the tensor.extract_slice fold in MLIR).

After talking with @MaheshRavishankar, there are two ways to implement this more properly:

  1. Write a pass to add padding after dispatch region formation (across regions).

  2. Use materialize encoding and make it parallel to the data tiling work.

Option 1 is more hacky than option 2, so we can use it the next time we have a fast-approaching deadline; otherwise, we would prefer to go with option 2.

This applies not only to SDXL but also to LLMs and matvec kernels.

Eventually, this should be superseded by full data tiling where we can put the 'column' data in a linear order and not worry about the exact stride across rows.

kuhar commented 2 weeks ago

@hanhanW FYI