Closed bjacob closed 4 months ago
@bjacob and @lialan, can we discuss this a bit more tomorrow? I know this was filed a couple of weeks ago, but I only got around to looking at this now (following from the PR that was sent out). I just want to clarify some things about the change of `round_dims_to` from an array to a single scalar. In my mind, having an array of `round_dims_to` makes more sense, and I'd rather drop the matmul_narrow_M/N because that is not generic enough IMO.
Superseded by https://github.com/iree-org/iree/issues/17729.
Context: #17545, the rebasing of #16890 past #17077, has been difficult. It comes down to two different fields of Encoding attributes that have partially overlapping and fuzzy semantics. This Issue is about resolving all that, and along the way, completing an aspect of the intended design (to avoid over-allocating certain buffers) that was not yet implemented.
The two fields of the Encoding attribute that we are talking about are:

- The `matmul_narrow_{M,N}` fields, which preexist and which #17545 is concerned with.
- The `round_dims_to` field introduced in #17077.

The current semantics of these fields are that:

- The `matmul_narrow_{M,N}` fields are just hints that this matmul has some narrow dimension, which may affect tile size selection (including for matrix operands where this narrow dimension doesn't participate; for instance, a narrow-M case like `vecmat` can still lead to a different tile choice for the RHS matrix, whose shape does not involve the M dimension).
- The `round_dims_to` field is an array attribute, enumerating the dimensions in the order of the iterators, e.g. [B,] M, N, K. It informs Stream of the maximum tile sizes that this matmul may need padding of its operands for, and it's used to ensure that buffer allocations are large enough to accommodate that padding.

At the moment, `round_dims_to` array entries are all initialized to the same `padFactor` value given as a pass option. So the potential benefit of having this as an array (adjusting this padding amount for narrow dimensions) is not yet reaped, while the cost (having to correctly handle this array in things like what #17545 is doing) is being paid already.

In fact, if we started populating `round_dims_to` with narrower values for narrow dimensions, we would be encoding the information of "this is a narrow dimension" twice, in `round_dims_to` and in the `matmul_narrow_{M,N}` attribute.

Proposal:

1. Rename `round_dims_to` to `max_padding`. This makes it clear what it is used for.
2. Change `max_padding` to be a single integer attribute, not an array. Its meaning is "the general-case padding amount, outside of any narrow-dimension cases".
3. Wherever we currently use the `getRoundDimsToArray()` value, also check the `matmul_narrow_{M,N}` attribute. If either is defined, let that override the `max_padding` value, just rounded up to the next power of two. Example: if `max_padding=16` and `matmul_narrow_M=3`, round `matmul_narrow_M` to the next power of two, which is 4, and use that instead of `max_padding` for the M dimension.
4. [...] `padFactor` option value passed down to there, and use that instead.
5. The `enumerateMatmulTile*` functions return a list of `TileMxNxK` triples where the M values (the first entries in the triple) are powers of two including all smaller powers of two down to 1. With the changes that we are discussing here, this is becoming a hard requirement: this is what ensures that the M dimension never gets rounded to more than just the next power of two. So, in the top-level `enumerateMatmulTileMxNxK` function, before returning the value, we should probably `assert` that this requirement is satisfied, to catch it if we ever forget about it.
6. There is a mismatch between the `round_dims_to` value 16 and the fact that in some cases in `enumerateMatmulTileMxNxK` we return `TileMxNxK` values exceeding 16. These tiles are being discarded at the moment since #17077 was merged: https://github.com/iree-org/iree/blob/14fd6acdd6a942ba7eef0e966a5808f39b34eef2/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp#L384-L393 . Now that (thanks to step 3 above) we are not rounding narrow dimensions by much anymore, it doesn't cost as much anymore to increase that `round_dims_to` (now called `max_padding`) value a bit. So my suggested trade-off would be: increase the padFactor used in SetEncoding from 16 to 32, which is needed for the tiles that we really care about; and in CPUMaterializeEncoding pass, for all tiles enumerated in `enumerateMatmulTile*`, clamp all values to maximum 32 so that these tiles don't get discarded anymore. The above-linked code doing debug-logging and `continue;` can then become an error (propagate to caller).
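
To make the arithmetic in the proposal concrete, here is a minimal Python sketch of three of its pieces: the narrow-dimension override with power-of-two rounding, the power-of-two requirement on the M values of the enumerated tiles, and the clamping of tile sizes to the padding limit. This is only a model of the proposal, not IREE code. The function names (`next_power_of_two`, `effective_padding`, `m_values_are_all_powers_of_two_down_to_one`, `clamp_tiles`) and the tile values are hypothetical, and treating an empty tile list as valid is an assumption.

```python
from typing import List, Optional, Tuple

# (M, N, K) tile sizes, modelled after the TileMxNxK triples discussed above.
Tile = Tuple[int, int, int]


def next_power_of_two(n: int) -> int:
    """Return the smallest power of two that is >= n (for n >= 1)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    return 1 << (n - 1).bit_length()


def effective_padding(max_padding: int, narrow: Optional[int]) -> int:
    """Narrow-dimension override: if a narrow hint is defined, it replaces
    max_padding after being rounded up to the next power of two."""
    if narrow is None:
        return max_padding
    return next_power_of_two(narrow)


def m_values_are_all_powers_of_two_down_to_one(tiles: List[Tile]) -> bool:
    """Check that the M values are exactly 1, 2, 4, ..., largest M.

    Assumption: an empty tile list is treated as trivially valid.
    """
    m_values = {m for m, _, _ in tiles}
    if not m_values:
        return True
    expected = set()
    p = 1
    while p <= max(m_values):
        expected.add(p)
        p *= 2
    return m_values == expected


def clamp_tiles(tiles: List[Tile], limit: int) -> List[Tile]:
    """Clamp every tile size to `limit` rather than discarding large tiles."""
    return [(min(m, limit), min(n, limit), min(k, limit)) for m, n, k in tiles]


# Example from the proposal: max_padding=16 and matmul_narrow_M=3 give 4 for M.
padding_m = effective_padding(16, 3)     # 4
padding_n = effective_padding(16, None)  # 16, no narrow hint for N

# Made-up tile list, used only to exercise the check and the clamping.
example_tiles = [(4, 64, 1), (2, 64, 1), (1, 128, 1)]
assert m_values_are_all_powers_of_two_down_to_one(example_tiles)
clamped = clamp_tiles(example_tiles, 32)  # [(4, 32, 1), (2, 32, 1), (1, 32, 1)]
```

Because every smaller power of two is present among the M values, a narrow M such as 3 can always be served by a tile whose M is the next power of two (4), which is why the check matters once `max_padding` no longer covers narrow dimensions.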