kuhar opened 2 months ago
cc: @MaheshRavishankar @qedawkins @antiagainst
I implemented a crude prototype by adding padding just before dispatch region formation. This results in excessive `slow_memcpy` dispatches but generates the code I'd expect at the kernel level.
I evaluated this against SDXL fp16 on MI210:

shape | no padding [us] | 128B padding [us] | Delta [%] |
---|---|---|---|
2048x10240x1280_f16xf16xf32 | 1172.912717 | 1169.9382 | -0.25 |
1024x64x1280_f16xf16xf32 | 117.32075 | 116.3258958 | -0.85 |
2048x1280x5120_f16xf16xf32 | 325.3763833 | 303.8777 | -6.61 |
2048x1280x1280_f16xf16xf32 | 95.69621138 | 94.84761789 | -0.89 |
8192x5120x640_f16xf16xf32 | 581.7872 | 579.3211 | -0.42 |
4096x64x640_f16xf16xf32 | 118.28935 | 118.10915 | -0.15 |
64x64x2048_f16xf16xf32 | 34.70054167 | 35.67100833 | 2.8 |
8192x640x2560_f16xf16xf32 | 278.7596 | 271.9727 | -2.43 |
8192x640x640_f16xf16xf32 | 94.70981818 | 95.08172727 | 0.39 |
64x64x2048_f16xf16xf32 | 34.8143 | 35.083 | 0.77 |
2048x1280x1280_f16xf16xf32 | 92.5295 | 92.529 | 0 |
8192x640x640_f16xf16xf32 | 87.0848 | 86.8164 | -0.31 |
8192x640x640_f16xf16xf32 | 102.173 | 101.6845 | -0.48 |
2048x1280x1280_f16xf16xf32 | 101.4405 | 99.8535 | -1.56 |
2x1280x1280_f16xf16xf32 | 19.577 | 19.683625 | 0.54 |
2048x1280x1280_f16xf16xf32 | 102.173 | 98.389 | -3.7 |
8192x640x640_f16xf16xf32 | 94.605 | 94.726 | 0.13 |
2x640x1280_f16xf16xf32 | 16.3084 | 17.4316 | 6.89 |
2x320x1280_f16xf16xf32 | 14.3066 | 14.3556 | 0.34 |
2x1280x2816_f16xf16xf32 | 69.092 | 68.604 | -0.71 |
2x1280x1280_f16xf16xf32 | 21.606 | 21.973 | 1.7 |
2x1280x320_f16xf16xf32 | 19.531 | 20.019 | 2.5 |
The maximum improvement is ~6.6%, which is in line with the isolated benchmark results. However, there are some regressions.
The biggest regressions are with the skinny matmuls, but this is expected and can be filtered out with a smarter heuristic.
Here are the results for other padding amounts:
shape | no padding [us] | 64B padding [us] | Delta [%] |
---|---|---|---|
2048x10240x1280_f16xf16xf32 | 1172.912717 | 1172.7132 | -0.02 |
1024x64x1280_f16xf16xf32 | 117.32075 | 116.7216625 | -0.51 |
2048x1280x5120_f16xf16xf32 | 325.3763833 | 303.4423 | -6.74 |
2048x1280x1280_f16xf16xf32 | 95.69621138 | 95.02428455 | -0.7 |
8192x5120x640_f16xf16xf32 | 581.7872 | 581.2744 | -0.09 |
4096x64x640_f16xf16xf32 | 118.28935 | 117.425425 | -0.73 |
64x64x2048_f16xf16xf32 | 34.70054167 | 36.38305 | 4.85 |
8192x640x2560_f16xf16xf32 | 278.7596 | 272.8028 | -2.14 |
8192x640x640_f16xf16xf32 | 94.70981818 | 94.47140909 | -0.25 |
64x64x2048_f16xf16xf32 | 34.8143 | 35.66895 | 2.45 |
2048x1280x1280_f16xf16xf32 | 92.5295 | 91.75633333 | -0.84 |
8192x640x640_f16xf16xf32 | 87.0848 | 86.2304 | -0.98 |
8192x640x640_f16xf16xf32 | 102.173 | 103.027 | 0.84 |
2048x1280x1280_f16xf16xf32 | 101.4405 | 102.173 | 0.72 |
2x1280x1280_f16xf16xf32 | 19.577 | 20.294375 | 3.66 |
2048x1280x1280_f16xf16xf32 | 102.173 | 98.632 | -3.47 |
8192x640x640_f16xf16xf32 | 94.605 | 93.994 | -0.65 |
2x640x1280_f16xf16xf32 | 16.3084 | 16.4548 | 0.9 |
2x320x1280_f16xf16xf32 | 14.3066 | 14.2574 | -0.34 |
2x1280x2816_f16xf16xf32 | 69.092 | 69.092 | 0 |
2x1280x1280_f16xf16xf32 | 21.606 | 21.484 | -0.56 |
2x1280x320_f16xf16xf32 | 19.531 | 19.775 | 1.25 |

shape | no padding [us] | 256B padding [us] | Delta [%] |
---|---|---|---|
2048x10240x1280_f16xf16xf32 | 1172.912717 | 1169.146733 | -0.32 |
1024x64x1280_f16xf16xf32 | 117.32075 | 116.0766792 | -1.06 |
2048x1280x5120_f16xf16xf32 | 325.3763833 | 305.6316333 | -6.07 |
2048x1280x1280_f16xf16xf32 | 95.69621138 | 95.11960976 | -0.6 |
8192x5120x640_f16xf16xf32 | 581.7872 | 578.2715 | -0.6 |
4096x64x640_f16xf16xf32 | 118.28935 | 117.331025 | -0.81 |
64x64x2048_f16xf16xf32 | 34.70054167 | 36.2243 | 4.39 |
8192x640x2560_f16xf16xf32 | 278.7596 | 273.6083 | -1.85 |
8192x640x640_f16xf16xf32 | 94.70981818 | 93.97736364 | -0.77 |
64x64x2048_f16xf16xf32 | 34.8143 | 35.00355 | 0.54 |
2048x1280x1280_f16xf16xf32 | 92.5295 | 91.7155 | -0.88 |
8192x640x640_f16xf16xf32 | 87.0848 | 86.548 | -0.62 |
8192x640x640_f16xf16xf32 | 102.173 | 102.6615 | 0.48 |
2048x1280x1280_f16xf16xf32 | 101.4405 | 100.403 | -1.02 |
2x1280x1280_f16xf16xf32 | 19.577 | 19.760125 | 0.94 |
2048x1280x1280_f16xf16xf32 | 102.173 | 98.877 | -3.23 |
8192x640x640_f16xf16xf32 | 94.605 | 94.971 | 0.39 |
2x640x1280_f16xf16xf32 | 16.3084 | 16.4306 | 0.75 |
2x320x1280_f16xf16xf32 | 14.3066 | 14.2334 | -0.51 |
2x1280x2816_f16xf16xf32 | 69.092 | 68.725 | -0.53 |
2x1280x1280_f16xf16xf32 | 21.606 | 21.973 | 1.7 |
2x1280x320_f16xf16xf32 | 19.531 | 19.531 | 0 |
I modified the heuristic to only apply padding where it makes sense based on the tensor types (not too small), and added a column that shows which matmuls were actually padded. This should make it clearer what the difference is compared to no padding (a rough sketch of such a heuristic follows the table):
shape | M | N | K | no padding [us] | 128B padding [us] | Delta [%] | Padding applied? |
---|---|---|---|---|---|---|---|
2048x10240x1280_f16xf16xf32 | 2048 | 10240 | 1280 | 1172.912717 | 1171.736767 | -0.1 | TRUE |
1024x64x1280_f16xf16xf32 | 1024 | 64 | 1280 | 117.32075 | 116.4471917 | -0.74 | FALSE |
2048x1280x5120_f16xf16xf32 | 2048 | 1280 | 5120 | 325.3763833 | 305.9314667 | -5.98 | TRUE |
2048x1280x1280_f16xf16xf32 | 2048 | 1280 | 1280 | 95.69621138 | 94.98953659 | -0.74 | TRUE |
8192x5120x640_f16xf16xf32 | 8192 | 5120 | 640 | 581.7872 | 579.0404 | -0.47 | FALSE |
4096x64x640_f16xf16xf32 | 4096 | 64 | 640 | 118.28935 | 118.357825 | 0.06 | FALSE |
64x64x2048_f16xf16xf32 | 64 | 64 | 2048 | 34.70054167 | 34.84596667 | 0.42 | FALSE |
8192x640x2560_f16xf16xf32 | 8192 | 640 | 2560 | 278.7596 | 271.8934 | -2.46 | TRUE |
8192x640x640_f16xf16xf32 | 8192 | 640 | 640 | 94.70981818 | 94.99554545 | 0.3 | FALSE |
64x64x2048_f16xf16xf32 | 64 | 64 | 2048 | 34.8143 | 35.06775 | 0.73 | FALSE |
2048x1280x1280_f16xf16xf32 | 2048 | 1280 | 1280 | 92.5295 | 91.71533333 | -0.88 | TRUE |
8192x640x640_f16xf16xf32 | 8192 | 640 | 640 | 87.0848 | 86.6454 | -0.5 | FALSE |
8192x640x640_f16xf16xf32 | 8192 | 640 | 640 | 102.173 | 101.379 | -0.78 | FALSE |
2048x1280x1280_f16xf16xf32 | 2048 | 1280 | 1280 | 101.4405 | 100.6165 | -0.81 | TRUE |
2x1280x1280_f16xf16xf32 | 2 | 1280 | 1280 | 19.577 | 19.5235 | -0.27 | FALSE |
2048x1280x1280_f16xf16xf32 | 2048 | 1280 | 1280 | 102.173 | 94.605 | -7.41 | TRUE |
8192x640x640_f16xf16xf32 | 8192 | 640 | 640 | 94.605 | 93.262 | -1.42 | FALSE |
2x640x1280_f16xf16xf32 | 2 | 640 | 1280 | 16.3084 | 16.3452 | 0.23 | FALSE |
2x320x1280_f16xf16xf32 | 2 | 320 | 1280 | 14.3066 | 14.2332 | -0.51 | FALSE |
2x1280x2816_f16xf16xf32 | 2 | 1280 | 2816 | 69.092 | 69.458 | 0.53 | FALSE |
2x1280x1280_f16xf16xf32 | 2 | 1280 | 1280 | 21.606 | 21.424 | -0.84 | FALSE |
2x1280x320_f16xf16xf32 | 2 | 1280 | 320 | 19.531 | 19.714 | 0.94 | FALSE |
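As a rough illustration, here is a minimal Python sketch of such a pad-or-not check. The thresholds (`min_dim`, `min_k_bytes`) are hypothetical and picked purely for illustration; the actual conditions live in the pad-to-alloc branch linked below.

```python
# Hypothetical sketch of a pad-or-not heuristic for matmuls; the real
# conditions live in the pad-to-alloc branch. Thresholds are illustrative.
def should_pad(m, n, k, elem_bytes=2, min_dim=128, min_k_bytes=2048):
    # Skip skinny matmuls (e.g., the M == 2 matvec-like shapes above):
    # the extra padding copies cost more than the cache-set conflicts
    # they avoid.
    if min(m, n) < min_dim:
        return False
    # Only pad when a row of the K-contiguous operand spans enough
    # cache lines for set conflicts to matter.
    return k * elem_bytes >= min_k_bytes

print(should_pad(2048, 1280, 5120))  # True: large in every dimension
print(should_pad(2, 1280, 1280))     # False: skinny in M
```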
The WIP implementation used to collect these numbers is on my fork: https://github.com/kuhar/iree/tree/pad-to-alloc (it also requires a small hack in the `tensor.extract_slice` fold in MLIR).
After talking with @MaheshRavishankar, there are two ways to implement this more properly:

1. Write a pass that adds padding after dispatch region formation (across regions).
2. Use materialize encoding and make it parallel to the data tiling work.

Option 1. is more hacky than 2., so we can reach for it the next time we have a fast-approaching deadline; otherwise, we would prefer to go with 2.
This applies not only to SDXL but also to LLMs and matvec kernels.
Eventually, this should be superseded by full data tiling, where we can lay the 'column' data out in linear order and not worry about the exact stride across rows.
@hanhanW FYI
The L1 data cache on MI300 is divided into 4 sets of up to 64 entries each, with a cache line size of 128 B. To utilize the full bandwidth, a kernel has to engage all four sets. However, because large matrix operands are aligned to a 512 B boundary, reads along the parallel dimensions hit the same cache set, resulting in hot-spotting.
To combat this, we can either pad the underlying allocation so that each matrix operand row starts at a different cache set, or opt for a different address-to-set mapping using `buffer_load` instructions. The former is a graph-level optimization; the latter is intra-kernel.
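To make the hot-spotting concrete, here is a minimal Python sketch. It assumes a simple linear address-to-set mapping, `set = (address / 128) % 4`; the exact MI300 mapping is an assumption here, not documented behavior.

```python
# Sketch of L1 set hot-spotting, assuming the set index is
# (address // line_size) % num_sets with 128 B lines and 4 sets.
LINE_SIZE = 128  # bytes per cache line
NUM_SETS = 4     # sets in the L1 data cache

def row_start_sets(row_bytes, pad_bytes, num_rows=8):
    """L1 set index of the first cache line of each matrix row."""
    pitch = row_bytes + pad_bytes
    return [(r * pitch // LINE_SIZE) % NUM_SETS for r in range(num_rows)]

# A 1280-element f16 row is 2560 B, a multiple of 512 B (4 sets x 128 B),
# so every row start lands in the same set (hot-spotting):
print(row_start_sets(1280 * 2, 0))    # [0, 0, 0, 0, 0, 0, 0, 0]

# Padding each row by one 128 B cache line rotates the starting set,
# engaging all four sets:
print(row_start_sets(1280 * 2, 128))  # [0, 1, 2, 3, 0, 1, 2, 3]
```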
[Plot: `mmt_4864x4096x4096_f16xf16xf32`, x-axis: padding [elems]]

Padding with a single cache line improves performance by 6.5% when the tensor shape remains the same, and by 4.5% if we pad the tensors and do the unnecessary computation. Based on these results, we would like to implement the former in the IREE compiler.