iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Support for data tiling on other (GPU/LLVMGPU/SPIR-V) backends #16933

Open qedawkins opened 3 months ago

qedawkins commented 3 months ago

Current overview

This is intended to be a tracking issue for adding support for data tiling on GPU backends. "Data tiling" here is used to describe strategies for reorganizing the layout of tensors in memory to allow for better access patterns. So far this has been built out only for LLVMCPU; however, the same principles should apply to GPU backends. The general flow for data tiling (targeting mmt4d and optionally ukernels) on CPU is summarized as:

  1. Introduce encoding operations in GlobalOptimization.
     a. Encodings capture how a particular tensor is used. This allows materializing a concrete layout during codegen, when we know the target explicitly.
     b. Currently MaterializeHomogeneousEncodings will materialize these encodings as an explicit pack at the global optimization level to enable fusion. This breaks point (a) for multi-device.
  2. Encodings are materialized in codegen. This explicitly packs all of the operations in codegen to the layout specified by the encoding for that target. From there codegen carries on as normal.
     a. Currently this includes a step in the HAL conversion to materialize the encoded tensors on the host side.

Materializing encoded tensors on the host side won't work for GPU without forcing a dependency between codegen backends, which is something we should avoid.

Hand picked cases of interest

This is a list of a few examples of operations we're interested in focusing on for encodings.

Matmul

The first obvious case is a simple matmul.

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @pack_gemm_fill_dynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<role = LHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>
  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<role = RHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>
  %2 = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>
  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>)
      -> tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>
  %4 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<role = LHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>, tensor<?x?xf32, #iree_linalg_ext.encoding<role = RHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>)
      outs(%3 : tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>
  %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>> -> tensor<?x?xf32>
  return %5 : tensor<?x?xf32>
}

For CPU backends, the set-encoding ops will turn into tensor.pack ops, the GEMM will turn into a linalg.mmt4d op, and the unset-encoding op will turn into tensor.unpack. This is essentially packing to the size of the tile processed by a single iteration of the inner hot loop of the GEMM. For certain GPU architectures, we might want to go a step further and match an internal layout of a target intrinsic. Take MFMA on the CDNA3 architecture as an example (see section 7 of this architecture document). For the F16_16x16x16_F32 intrinsic, we might want three levels of packing:

%rhs = tensor<?x?xf16>

// Pack to the tile of the RHS processed by a single iteration of the hot loop by one workgroup
%wkg_pack = tensor.pack %rhs inner_tile_sizes = [32, 64]

// Pack to the size of the intrinsic processed by a single subgroup
%sgp_pack = tensor.pack %wkg_pack inner_tile_sizes = [0, 0, 16, 16]

// Pack to the layout needed for a single MFMA instruction
%instr_pack = tensor.pack %sgp_pack inner_tile_sizes = [0, 0, 0, 0, 4, 16] inner_dims_pos = [6, 5]
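
To make the shape bookkeeping concrete, here is a rough Python sketch of what these three packing levels do to the RHS shape (one interpretation of the pseudo-IR above; the packed-dim positions and the K=1000, N=4096 sizes are illustrative, not a tuned configuration):

import math

def pack_shape(shape, inner_dims_pos, inner_tiles):
    """Shape bookkeeping for a tensor.pack-style op: each packed dim is
    ceil-divided by its tile size and the tiles are appended as trailing dims."""
    outer = list(shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        outer[pos] = math.ceil(outer[pos] / tile)
    return outer + list(inner_tiles)

k, n = 1000, 4096                            # illustrative RHS sizes
wkg = pack_shape([k, n], [0, 1], [32, 64])   # [32, 64, 32, 64]
sgp = pack_shape(wkg, [2, 3], [16, 16])      # [32, 64, 2, 4, 16, 16]
inst = pack_shape(sgp, [4, 5], [4, 16])      # [32, 64, 2, 4, 4, 1, 4, 16]
print(inst)  # an 8-D layout; a single linalg.mmt4d only takes 4-D operands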

For such cases, a single pack + mmt4d op would not work. This poses a problem for kernel config, as the CPU approach of materialize encoding -> kernel config requires kernel config to relearn the meaning of the multi-level packed generic. Additionally, the approach of materializing packs at the GlobalOptimization level by way of MaterializeHomogeneousEncodings makes this even more difficult, as there's no guarantee that some later transformation doesn't mess up the packed generic (also, we might want to materialize more than a single pack, which could cause further problems with fusion).

On GPU backends the cost of encoding ops is much more likely to outweigh the gains from setting encodings, making fusion of encoding setting operations a requirement for good performance, so materializing multi-level packs like this at the GlobalOptimization level is unlikely to work.

Quantized Matmul

We've seen quantization be extremely important to model performance recently; in particular, fusing dequantization (quantization) operations with their consumers (producers) is essential for good performance.

        %33 = linalg.fill ins(%cst : f32) outs(%31 : tensor<1x?x4096xf32>) -> tensor<1x?x4096xf32>
        %34 = linalg.generic {
            indexing_maps = [
              affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
              affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
              affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
              affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
            iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%27, %28, %29 : tensor<4096x32x128xi4>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%32 : tensor<4096x32x128xf32>) {
        ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
          %36 = arith.extui %in : i4 to i32
          %37 = arith.uitofp %36 : i32 to f32
          %38 = arith.subf %37, %in_1 : f32
          %39 = arith.mulf %38, %in_0 : f32
          linalg.yield %39 : f32
        } -> tensor<4096x32x128xf32>
        %35 = linalg.generic {
            indexing_maps = [
              affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
              affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
              affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
            iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
        ins(%30, %34 : tensor<1x?x32x128xf32>, tensor<4096x32x128xf32>) outs(%33 : tensor<1x?x4096xf32>) {
        ^bb0(%in: f32, %in_0: f32, %out: f32):
          %36 = arith.mulf %in, %in_0 : f32
          %37 = arith.addf %36, %out : f32
          linalg.yield %37 : f32
        } -> tensor<1x?x4096xf32>

To set encodings for the consumer matmul (%35), and actually benefit from them, we must propagate the encoding to the dispatch boundary, and thus to the fused dequantization operation. The way encodings work today, however, requires padding to align the data to the target tile size. For a simple matmul this works, because the padding value is simply the value that leaves an FMA invariant, which is zero. That does not work here, because we would need to pick a constant padding value for the inputs to the dequantization operation, namely this math

          %36 = arith.extui %in : i4 to i32
          %37 = arith.uitofp %36 : i32 to f32
          %38 = arith.subf %37, %in_1 : f32
          %39 = arith.mulf %38, %in_0 : f32

Which, when taking into account the way that scales and zero points are broadcasted, is not possible. Importantly, however, we don't actually care about doing the correct math to produce f32 = 0.0 from the dequantization, but rather:

  1. That the layout of the quantized tensor in memory matches the desired packed layout.
  2. That the inputs to the FMA of the matmul are padded with 0.

This could instead be codegen'd by masking the execution of the innermost tile (and not even bothering to initialize the padding of the pack). This could be done by adding some kind of "poison" bit to an encoding/pack to indicate that the final padded tile gives no guarantee as to the value of the out-of-bounds data, e.g.

tensor.pack %0 inner_tile_sizes = [4, 4] : tensor<3x3>
// Gives:
// [0, 1, 2, _]
// [3, 4, 5, _]
// [6, 7, 8, _]
// [_, _, _, _]

Where _ indicates uninitialized or arbitrarily filled data.
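
As a sanity check of the masking idea, here is a hedged NumPy sketch (not IREE code; the sizes and the 1e9 "poison" fill are arbitrary) showing that garbage padding is harmless as long as the reduction is masked to the valid K range:

import numpy as np

m = k = n = 3  # logical sizes
tile = 4       # padded tile size

lhs = np.random.rand(m, k).astype(np.float32)
rhs = np.random.rand(k, n).astype(np.float32)

# Pad with arbitrary ("poison") values instead of a carefully chosen constant.
lhs_pad = np.full((tile, tile), 1e9, np.float32)
lhs_pad[:m, :k] = lhs
rhs_pad = np.full((tile, tile), -1e9, np.float32)
rhs_pad[:k, :n] = rhs

# Masked execution of the innermost tile: only accumulate over the valid K
# range, so the poison values never feed the FMAs.
acc = np.zeros((tile, tile), np.float32)
for kk in range(k):
    acc += np.outer(lhs_pad[:, kk], rhs_pad[kk, :])

assert np.allclose(acc[:m, :n], lhs @ rhs)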

Fusion of pack with producers

As noted earlier, pack fusion is likely required to see significant performance gains on GPU backends, given the often comparatively marginal gains from data tiling on desktop/server grade GPUs and the additional overhead of allocating memory for, and launching, the encoding kernel. Fusing the pack with its consumer defeats the purpose of data tiling, so we need to be able to fuse a pack with its producers. Let's think through what that might look like for a case like this:

  %4 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<role = LHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>, tensor<?x?xf32, #iree_linalg_ext.encoding<role = RHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>)
      outs(%3 : tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>
  %5 = linalg_ext.change_encoding %4 :
    tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>
    to tensor<?x?xf32, #iree_linalg_ext.encoding<role = LHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>
  %8 = linalg.matmul ins(%5, %6 : tensor<?x?xf32, #iree_linalg_ext.encoding<role = LHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>, tensor<?x?xf32, #iree_linalg_ext.encoding<role = RHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>)
      outs(%7 : tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<role = RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>

In other words, we need to write back a different encoding from the producer matmul, which could include padding. There are three options for this fusion:

  1. Fuse the full set_encoding with the producer matmul. There are two ways to do this.
     a. Dedicate some threads of that matmul kernel to writing back the padding value.
     b. Use a few threads within subgroups writing data "near" the padded values to also write the padding as well as the value computed with that thread.
  2. Split the set_encoding into an "insert-slice + transpose" and a separate smaller padding kernel.

For architectures like RDNA3 and CDNA1/2/3, which rely on subgroup operations to achieve good performance on matmuls, dedicating some threads of the matmul to writing the padding value is a significant mismatch in resource costs (dedicating threads of a resource-intensive matmul kernel to just writing zeros is unlikely to perform well), so 1a is a bad option. Option 2 is more reasonable and composes nicely with graph-based execution (e.g. HIP graphs/CUDA graphs); however, it requires an extra kernel launch, plus a way to elide that kernel in cases where padding is not required. Option 1b is probably the hardest to implement but makes the most sense, assuming the amount of padding is smaller than the number of threads launched (reasonable in most real-world cases, except for completely arbitrary dynamic shapes that happen to end up being small at runtime, but those should be handled with specialization).

Attention

Attention internally does matrix operations and thus is another good target for data tiling. The current implementation of the attention operation, however, requires exactly 3d inputs and thus does not allow for setting encodings. One option could be to materialize encodings for attention after decomposing the operation, but the different options have not really been explored.

Convolution

Data tiling of convolution seems to primarily be a subset of what is required for matmul, as tiling the image dimensions of the convolution is not possible. TODO: Add more details here as necessary.

Task List

TODO

MaheshRavishankar commented 3 months ago

Thanks @qedawkins for the detailed issue! Nice read! There is a lot to unpack here, and I will read through this a few times, but here are some priors of mine that clash with some points above:

1) When introduced, set_encoding was meant more to capture the information of "how is this tensor used", i.e. is this the LHS of a matmul, or the RHS of a matmul. This was generalized to handle linalg.generics using indexing maps. When we talk about "propagating" encodings, it seems like we are trying to capture "how is the user of the user of this tensor used". That seems like too abstract a piece of information to hold in an encoding. I don't think "propagation of encodings" really works.

2) On the CPU side, at least, the starting assumption for me was "even if we pay the overhead of doing the pack by itself, the benefit of having a better matmul kernel is worth the overhead". As stated above, that explicitly is not the starting assumption here. We need to make sure that a) the dequantize operation still fuses with its consumer, i.e. the encoding "propagates" past the dequantize operation, and b) the pack fuses with its producers. All of these HAVE to fall into place for this to work as expected. FWIW, we do have these happening already on the CPU side (the encoding propagation, and pack fusion with producers), but I am not sure how well it works.

All of the above could be made to work, and we can ensure it works as expected, but I think there is a difference on GPUs. We can use packing as a way to move data from global memory to shared memory (that's the first level of packing you describe) and another pack to move data from shared memory to registers (we know we already have the transformations for this because this is what we are using on the AIE backend, and those work as expected, including with bufferization). I am not saying that strategy is the most efficient (it is unclear to me how effective this use of shared memory will be and how it affects overall performance). This approach ensures that a) the dequantize operation fuses with the matmul (much easier), b) there is no encoding to propagate, and c) within a dispatch we can use tile and fuse + pack propagation to make sure the producers and consumers all see a consistent view of the data.

qedawkins commented 3 months ago

When introduced the set_encoding was more to capture the information of "how is this tensor used", i.e. is this the LHS of a matmul, or RHS of a matmul. This was generalized to handle linalg.generics using indexing maps. When we talk about "propagating" encodings, that seems like we are trying to capture "how is the user of the user of the tensor used". That seems like too much of an abstract information to hold in an encoding. I dont think "propagating of encoding" really works.

Makes sense; I was using "encoding" both as a general term for any attribute describing the layout of a tensor in memory, and as the concrete thing it is today. It would help to decouple those usages of the term, but I agree that in its current state, propagation of an encoding is not possible.

All the above could be made to work and we can ensure those work as expected, but I think there is a difference on GPUs. We can use packing as a way to move data from global memory to shared memory (thats first level of packing you describe) and another pack to move data from shared memory to registers ... a) We can ensure the dequantize operation fuses with the matmul (much easier) b) There is no encoding to propagate, c) within a dispatch we can use tile and fuse + pack propagation to make sure the producer and consumers all see a consistent view of the data.

+1 to this approach, but I see it as orthogonal. What we do at rest without data tiling should be largely the same as what we do with data tiling, should we choose to push on it. I see data tiling as a "last 10%" kind of strategy (especially for the most powerful GPUs) rather than the absolute requirement that it is on CPU. I'm hoping that this issue can be more of a starting point for brainstorming on that last 10% rather than an excuse not to do the work in the default case.

c) within a dispatch we can use tile and fuse + pack propagation to make sure the producer and consumers all see a consistent view of the data.

Big +1 to this approach. We should discuss this idea more in a separate issue. It would be a big win to me if the codegen approaches for AIE and GPU end up overlapping a lot too :)

hanhanW commented 3 months ago

I took a look at the different options for the data-tiling path. The constant computation is still hoisted to the initialization stage if we disable const-eval. However, it does not work if we defer the materialization to the very end. [1] We can teach the compiler about that case, but there are other problems: we're facing an allocation issue. The iree_linalg_ext.upper_bound_tile_size ops are floating around stream.resource.alloca ops. The fact is that we always need a place to propagate the tile sizes to align to, so the host can allocate big enough buffers. See the snippet below for an example:

// -----// IR Dump Before CPUMaterializeUpperBoundTileSize (iree-codegen-cpu-materialize-upper-bound-tile-size) //----- //
  // ...
  %25:2 = iree_linalg_ext.upper_bound_tile_size tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> index, index
  %26 = affine.apply affine_map<()[s0, s1] -> ((s1 ceildiv s0) * s0)>()[%25#0, %10]
  %27 = affine.apply affine_map<()[s0, s1] -> ((s1 ceildiv s0) * s0)>()[%25#1, %11]
  %28 = arith.muli %26, %c4 : index
  %29 = arith.muli %28, %27 : index
  %30 = util.align %19, %c64 : index
  %31 = util.align %24, %c64 : index
  %32 = arith.addi %30, %31 : index
  %33 = util.align %29, %c64 : index
  %34 = arith.addi %32, %33 : index
  %result_0, %result_timepoint_1 = stream.resource.alloca uninitialized : !stream.resource<transient>{%34} => !stream.timepoint
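
In plain arithmetic, the host-side size computation in this snippet amounts to the following (a Python paraphrase for readability, not compiler code; the tile sizes are the SSA results of upper_bound_tile_size, which is exactly the dependency floating around the alloca):

def align(x: int, a: int) -> int:
    # The (s1 ceildiv s0) * s0 / util.align pattern above.
    return ((x + a - 1) // a) * a

def result_buffer_size(d0: int, d1: int, tile0: int, tile1: int) -> int:
    # %26/%27: pad each dim up to the upper-bound tile size for the RESULT.
    padded0, padded1 = align(d0, tile0), align(d1, tile1)
    # %28/%29: element count times 4 bytes (f32); %33: align the slab to 64.
    return align(padded0 * 4 * padded1, 64)

# The alloca size (%34) is the sum of such 64-byte-aligned sizes, one per
# transient buffer in the dispatch region.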

The workaround is running an early materialization pass (i.e., MaterializeHomogeneousEncodingsPass), which converts upper_bound_tile_size ops into constants (i.e., inner_tile_size) or index computations. The end goal is to get rid of that pass. So I took another look at the VMVX dynamic path; I found that we unconditionally make the sizes aligned to 16. 16 was chosen because it's a number that works in favor of most CPUs/GPUs [2]. Allocating much larger buffers is not a problem, because the kernels can just ignore the extra space. This is how the VMVX dynamic path works today. For the next steps, we can either [a] keep the assumption (i.e., all the sizes are aligned to 16), or [b] explore how to connect host and device properly + [1]. The tricky part in [b] is how to propagate tile sizes from device to host. I'm wondering if we can write the tile sizes to some variables in the initialization stage. If so, the device can initialize the variables and the host can query the upper_bound_tile_sizes from those variables during allocation.

https://github.com/openxla/iree/blob/41e312a26cefef299976f76c8882b1d894d5734b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp#L281-L288

[1] So one of the work items is to teach the compiler to hoist "set encodings on constants" to an initializer. [2] It was added a long time ago, so I might be outdated.

benvanik commented 3 months ago

Rounding up for now is easiest. The real solution is going to be to use tensor encodings properly - stream is designed to support those even with heterogeneous devices (stream.tensor.sizeof, etc). We cannot have independent ops (like upper_bound_tile_size) in any world where there are heterogeneous devices - we must have tensor SSA values be able to have their size calculated based on the tensor (shape, element type, encoding) and the device that is using it (defined by the ops using it).

qedawkins commented 3 months ago

Rounding up for now is easiest. The real solution is going to be to use tensor encodings properly - stream is designed to support those even with heterogeneous devices (stream.tensor.sizeof, etc). We cannot have independent ops (like upper_bound_tile_size) in any world where there are heterogeneous devices - we must have tensor SSA values be able to have their size calculated based on the tensor (shape, element type, encoding) and the device that is using it (defined by the ops using it).

That is probably step 1 here then. After discussion offline, early materialization isn't going to work very well on GPU given that we want multi-level packs. If we have the chance, it would be nice to do it the right way.

benvanik commented 3 months ago

The first place to start sniffing around would be calculateStorageElementCountInBytes and making it use the encoding attr on the tensor type it's provided - the place where we compute sizes is EncodeTensors and that calls that routine to build the IR to get the size as needed. At that point we do know the device (well, once my multi-device stuff lands) and can emit target-specific queries if we need to. The important thing is that at that point we've placed all dispatches and know then where tensors are used so with the encoding indicating what a tensor is used for, the affinity indicating where it's used, and the tensor metadata (shape/element type) we should have everything we need to calculate byte sizes (even if not statically). So if we can move to encoding attrs of some kind via an interface (so we can change them over time/etc) we can avoid needing to have a perfect solution and instead allow tunneling from whatever is inserting the encodings all the way through to whatever is able to service the size requests without needing to change any other part of the system (like we have had to by inserting the upper bound ops/materializing homogeneous encodings during global opt/etc). I think we'll need a few iterations to get to the exact shape of things but to start even a #iree_codegen.encoding.round_innermost_dims_to<16> (or whatever) kind of thing would be fine and allow progress on the decoupling. After that we can emit device queries/etc.

What I'd do, if doing it all incrementally:

phase 1:

  • start setting encoding attrs on tensors earlier on, define some in codegen/hal/upstream/wherever
  • hardcode calculateStorageElementCountInBytes to use those encoding attrs to do something simple (round up, etc)
  • delete the old passes/hacks/ops

phase 2:

phase 3:

qedawkins commented 3 months ago

Thanks for the tips! I'll start picking this up after my break.

So if we can move to encoding attrs of some kind via an interface (so we can change them over time/etc)

This will be the key I think, because encoding attrs could be op specific (and will likely need to be revisited if we want to do any kind of propagation).

attach that attr interface to encoding attrs to produce stream.context.resolve ops (which take affinity and get you a !hal.device) and any runtime queries against the device you may want (if you even need it) - we can also add some ops to make that cleaner, but such an approach should work to start

My current thought process is that we will want to plug any kind of external tuning process into the tile sizes selected for an encoding. Going even a step further, on a target like SPIR-V where we have a JIT compiler + specialization constants, the tile sizes we pick for an encoding could be resolved + tuned directly on the vmfb. Or we could pass encodings as a push constant and provide some parameterized lookup table for resolving them. That's probably far away though, and maybe it will never beat baking the tile sizes in at compile time (without another compiler backing it like SPIR-V).

benvanik commented 3 months ago

Yeah, having the baked out stuff will always be best for performance, then it's a tradeoff of compilation time/deployment size as to whether we bake out a lot of variants or leave things dynamic. Classic optimize-for-speed or optimize-for-size style stuff, and something we can use PGO for (knowing the actual values used for any executable/push constant will let us turn those into compile-time constants). If we write things using the queries to start it'll be push constants, but we have the option of going and turning those queries (for fixed devices) or related arithmetic (the arith on both the query result and dynamic shapes, etc) to constants and then propagating those into dispatches.

I'll also have a pass that hoists const exprs derived from device queries/arith into executable constant blocks which will become specialization constants in SPIR-V or just constant device-side buffers in CUDA/ROCM (effectively then just push constants but without the per-dispatch overhead of pushing them).

Fun things that are all unlocked by making some progress here on getting more IR and less hardcoding :)

qedawkins commented 3 months ago

Oooh nice, the constant buffer thing sounds cool. Yeah need to make progress here first.

hanhanW commented 3 months ago

Thanks for all the details! Yes, phase 1 sounds good to me, and I need to think more about phase 2. I will start picking up the work and make progress towards phase 2.

bjacob commented 3 months ago

Replying to this part of the original issue description above:

For the F16_16x16x16_F32 intrinsic, we might want three levels of packing

%rhs = tensor<?x?xf16>

// Pack to the tile of the RHS processed by a single iteration of the hot loop by one workgroup
%wkg_pack = tensor.pack %rhs inner_tile_sizes = [32, 64]

// Pack to the size of the intrinsic processed by a single subgroup
%sgp_pack = tensor.pack %wkg_pack inner_tile_sizes = [0, 0, 16, 16]

// Pack to the layout needed for a single MFMA instruction
%instr_pack = tensor.pack %sgp_pack inner_tile_sizes = [0, 0, 0, 0, 4, 16] inner_dims_pos = [6, 5]

For such cases, a single pack + mmt4d op would not work.

I still think that a single pack + mmt4d, as on CPU, should work:

  1. The first pack above, %wkg_pack = tensor.pack %rhs inner_tile_sizes = [32, 64], seems to try to pack a workgroup tile contiguously. I don't think that needs to be a goal. Some workgroup-level data locality is desirable, but it doesn't have to go all the way to requiring the tile to be completely contiguous. Packing just to the next level of tile size (the "intrinsic size") would give half the locality: assuming 16x16 intrinsics and a row-major placement of these tiles, it would just make the workgroup tile split into 2 contiguous halves --- not too bad.
    • More generally, a theme in data-tiling design is: do not conflate traversal and layout. Data tiling is about layout and doesn't need to be influenced by higher-level (e.g. workgroup) traversal considerations. With matrix multiplication it is impossible anyway to have perfect access locality across the entire matmul, so you have to accept that there will be some jumps in the accesses at some point. That point just needs to be high enough (you wouldn't want a jump at every inner-loop iteration), but that's it.
  2. The second pack above, %sgp_pack, essentially coalesces into the third, %instr_pack: the only layout difference between doing %sgp_pack and skipping straight to %instr_pack is padding of that dimension to the next multiple of 16. Indeed, here is the tile layout if we only do %instr_pack (annotating each 4x16 matrix entry with its offset, written in hex):
  0   4   8   c  10  14  18  1c  20  24  28  2c  30  34  38  3c 
  1   5   9   d  11  15  19  1d  21  25  29  2d  31  35  39  3d 
  2   6   a   e  12  16  1a  1e  22  26  2a  2e  32  36  3a  3e 
  3   7   b   f  13  17  1b  1f  23  27  2b  2f  33  37  3b  3f 

But this is just one tile in a larger tiled matrix, so let's carry on spelling out the layout of subsequent tiles. Each tile comes below the preceding tile and continues the offset increments from where it left off, so if we write out the first 4 tiles, we get:

  0   4   8   c  10  14  18  1c  20  24  28  2c  30  34  38  3c 
  1   5   9   d  11  15  19  1d  21  25  29  2d  31  35  39  3d 
  2   6   a   e  12  16  1a  1e  22  26  2a  2e  32  36  3a  3e 
  3   7   b   f  13  17  1b  1f  23  27  2b  2f  33  37  3b  3f 
 40  44  48  4c  50  54  58  5c  60  64  68  6c  70  74  78  7c 
 41  45  49  4d  51  55  59  5d  61  65  69  6d  71  75  79  7d 
 42  46  4a  4e  52  56  5a  5e  62  66  6a  6e  72  76  7a  7e 
 43  47  4b  4f  53  57  5b  5f  63  67  6b  6f  73  77  7b  7f 
 80  84  88  8c  90  94  98  9c  a0  a4  a8  ac  b0  b4  b8  bc 
 81  85  89  8d  91  95  99  9d  a1  a5  a9  ad  b1  b5  b9  bd 
 82  86  8a  8e  92  96  9a  9e  a2  a6  aa  ae  b2  b6  ba  be 
 83  87  8b  8f  93  97  9b  9f  a3  a7  ab  af  b3  b7  bb  bf 
 c0  c4  c8  cc  d0  d4  d8  dc  e0  e4  e8  ec  f0  f4  f8  fc 
 c1  c5  c9  cd  d1  d5  d9  dd  e1  e5  e9  ed  f1  f5  f9  fd 
 c2  c6  ca  ce  d2  d6  da  de  e2  e6  ea  ee  f2  f6  fa  fe 
 c3  c7  cb  cf  d3  d7  db  df  e3  e7  eb  ef  f3  f7  fb  ff

So we see that as long as the number of rows is a multiple of 16, the %instr_pack alone produces the same layout as the %sgp_pack + %instr_pack.

So we can drop the %sgp_pack and instead pad the input to the next multiple of 16 rows.

This is not a coincidence. mmt4d / pack were designed this way because sane ISAs for matrix multiplication always tend to operate in this way, with data laid out so that some small (here, 4-element) dot-products consume contiguous vectors, resulting in this (row-major LHS, column-major RHS) layout where this coalescing happens when you look at multiple tiles. This gives the ability to combine multiple tiles into a larger tile without increasing the intrinsic dimensionality of the layout, which is basically the reason why this is the right choice for matrix multiplication ISAs.
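
For concreteness, here is a small Python sketch (purely illustrative, not from the original comment; offsets are in elements for a 16-column block) that reproduces the offset table above and checks that the 4x16 instruction-level pack alone produces the same offsets as the 16x16 subgroup pack followed by the instruction pack:

# Offsets of element (r, c) in a (rows x 16) block under the two layouts.
def offset_instr_only(r, c):
    return (r // 4) * 64 + c * 4 + (r % 4)

def offset_sgp_then_instr(r, c):
    sgp_tile, r_in_sgp = divmod(r, 16)
    return sgp_tile * 256 + (r_in_sgp // 4) * 64 + c * 4 + (r_in_sgp % 4)

rows, cols = 64, 16  # rows padded to a multiple of 16
assert all(offset_instr_only(r, c) == offset_sgp_then_instr(r, c)
           for r in range(rows) for c in range(cols))

# First few rows/columns of the first tile, in hex, matching the table above.
print([[format(offset_instr_only(r, c), 'x') for c in range(4)] for r in range(4)])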

hanhanW commented 3 months ago

The first place to start sniffing around would be calculateStorageElementCountInBytes and making it use the encoding attr on the tensor type it's provided - the place where we compute sizes is EncodeTensors and that calls that routine to build the IR to get the size as needed. At that point we do know the device (well, once my multi-device stuff lands) and can emit target-specific queries if we need to. The important thing is that at that point we've placed all dispatches and know then where tensors are used so with the encoding indicating what a tensor is used for, the affinity indicating where it's used, and the tensor metadata (shape/element type) we should have everything we need to calculate byte sizes (even if not statically). So if we can move to encoding attrs of some kind via an interface (so we can change them over time/etc) we can avoid needing to have a perfect solution and instead allow tunneling from whatever is inserting the encodings all the way through to whatever is able to service the size requests without needing to change any other part of the system (like we have had to by inserting the upper bound ops/materializing homogeneous encodings during global opt/etc). I think we'll need a few iterations to get to the exact shape of things but to start even a #iree_codegen.encoding.round_innermost_dims_to<16> (or whatever) kind of thing would be fine and allow progress on the decoupling. After that we can emit device queries/etc.

What I'd do, if doing it all incrementally:

phase 1:

  • start setting encoding attrs on tensors earlier on, define some in codegen/hal/upstream/wherever
  • hardcode calculateStorageElementCountInBytes to use those encoding attrs to do something simple (round up, etc)
  • delete the old passes/hacks/ops

I'm making some progress. In my prototype, I created a SetEncodingHintOnDispatches pass which introduces #iree_codegen.encoding.round_dims_to<16> on dispatches; it removes the size calculation (i.e., iree_linalg_ext.upper_bound_tile_size -> affine.apply).

E.g.,

  %9:2 = iree_linalg_ext.upper_bound_tile_size tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> index, index
  %10 = affine.apply affine_map<()[s0, s1] -> ((s1 ceildiv s0) * s0)>()[%9#1, %1]
  %11 = affine.apply affine_map<()[s0, s1] -> ((s1 ceildiv s0) * s0)>()[%9#0, %0]
  %12 = flow.dispatch.workgroups[%9#1, %9#0, %0, %1, %11, %10](%9#1, %9#0, %2, %0, %1, %11, %10) : (index, index, tensor<?x?xf32>{%0, %1}, index, index, index, index) -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%11, %10} =
      (%arg3: index, %arg4: index, %arg5: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>>) {
    %cst = arith.constant 0.000000e+00 : f32
    %24 = flow.dispatch.workload.ordinal %arg6, 2 : index
    %25 = flow.dispatch.workload.ordinal %arg7, 3 : index
    %26 = flow.dispatch.workload.ordinal %arg8, 4 : index
    %27 = flow.dispatch.workload.ordinal %arg9, 5 : index
    %28 = flow.dispatch.tie_shape %arg5 : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%24, %25}
    %29 = flow.dispatch.tie_shape %arg10 : !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>>{%26, %27}
    %30 = flow.dispatch.workload.ordinal %arg3, 0 : index
    %31 = flow.dispatch.workload.ordinal %arg4, 1 : index
    %32 = flow.dispatch.tensor.load %28, offsets = [0, 0], sizes = [%24, %25], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%24, %25} -> tensor<?x?xf32>
    %33 = affine.apply affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>()[%30, %25]
    %34 = affine.apply affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>()[%31, %24]
    %padded = tensor.pad %32 low[0, 0] high[%34, %33] {
    ^bb0(%arg11: index, %arg12: index):
      tensor.yield %cst : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
    %35 = iree_linalg_ext.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.dispatch.tensor.store %35, %29, offsets = [0, 0], sizes = [%26, %27], strides = [1, 1] : tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>>{%26, %27}
    flow.return
  } count(%arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg3, %arg4, %arg5, %arg6, %arg7, %arg8
    flow.return %x, %y, %z : index, index, index
  }

is converted to

  %9 = flow.dispatch.workgroups[%c16, %0, %1](%c16, %2, %0, %1) : (index, tensor<?x?xf32>{%0, %1}, index, index) -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%0, %1}
    // -------------- This is the new attribute --------------- //
    attributes {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} =
      (%arg3: index, %arg4: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>, %arg5: index, %arg6: index, %arg7: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>>) {
    %cst = arith.constant 0.000000e+00 : f32
    %15 = flow.dispatch.workload.ordinal %arg5, 1 : index
    %16 = flow.dispatch.workload.ordinal %arg6, 2 : index
    %17 = flow.dispatch.workload.ordinal %arg3, 0 : index
    %18 = flow.dispatch.tie_shape %arg4 : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%15, %16}
    %19 = flow.dispatch.tie_shape %arg7 : !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>>{%15, %16}
    %20 = flow.dispatch.tensor.load %18, offsets = [0, 0], sizes = [%15, %16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%15, %16} -> tensor<?x?xf32>
    %21 = affine.apply affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>()[%17, %16]
    %22 = affine.apply affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>()[%17, %15]
    %padded = tensor.pad %20 low[0, 0] high[%22, %21] {
    ^bb0(%arg8: index, %arg9: index):
      tensor.yield %cst : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
    %23 = iree_linalg_ext.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.dispatch.tensor.store %23, %19, offsets = [0, 0], sizes = [%15, %16], strides = [1, 1] : tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>>{%15, %16}
    flow.return
  } count(%arg3: index, %arg4: index, %arg5: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg3, %arg4, %arg5
    flow.return %x, %y, %z : index, index, index
  }

(%c16 is created because the pass is added after all the dispatch passes. I can try to move the pass right before FormDispatchWorkgroups, so the flow.dispatch.workgroups does not need to capture the result of iree_linalg_ext.upper_bound_tile_size. I'll try that tomorrow.)

This allows us to propagate the information to the backends; we can also use the new attribute in the stream pipeline. In the ConvertToStream pass, I teach the ConvertDispatchOp pattern to add the attribute on stream.tensor.sizeof ops. E.g.,

// -----// IR Dump Before ConvertToStreamPass (iree-stream-conversion) //----- //
  util.func public @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32(%input0: tensor<?x?xf32>, %input1: tensor<?x?xf32>, %input2: tensor<?x?xf32>) -> (%output0: tensor<?x?xf32>)"}} {
    %c16 = arith.constant 16 : index
    %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
    %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
    %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?xf32>{%0, %1}
    %3 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
    %4 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
    %5 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<?x?xf32>{%3, %4}
    %6 = hal.buffer_view.dim<%arg2 : !hal.buffer_view>[0] : index
    %7 = hal.buffer_view.dim<%arg2 : !hal.buffer_view>[1] : index
    %8 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<?x?xf32>{%6, %7}

    // All the below dispatches have `{encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>}`
    %9 = flow.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_0::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_0_set_encoding_LHS_DxD[%c16, %0, %1](%c16, %2, %0, %1) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (index, tensor<?x?xf32>{%0, %1}, index, index) -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%0, %1}
    %10 = flow.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_1::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_1_set_encoding_RHS_DxD[%c16, %3, %4](%c16, %5, %3, %4) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (index, tensor<?x?xf32>{%3, %4}, index, index) -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%3, %4}
    %11 = flow.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_2::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_2_set_encoding_RESULT_DxD[%c16, %6, %7](%c16, %8, %6, %7) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (index, tensor<?x?xf32>{%6, %7}, index, index) -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%6, %7}
    %12 = flow.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_3::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_3_matmul_DxDxD_f32[%0, %1, %3, %4, %6, %7](%9, %10, %11, %0, %1, %3, %4, %6, %7) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%0, %1}, tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%3, %4}, tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%6, %7}, index, index, index, index, index, index) -> %11{%6, %7}
    %13 = flow.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_4::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_4_unset_encoding_RESULT_DxD[%6, %7](%12, %6, %7) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%6, %7}, index, index) -> tensor<?x?xf32>{%6, %7}
    %14 = hal.tensor.export %13 "output0" : tensor<?x?xf32>{%6, %7} -> !hal.buffer_view
    util.return %14 : !hal.buffer_view
  }
// -----// IR Dump After ConvertToStreamPass (iree-stream-conversion) //----- //
  util.func public @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32(%input0: tensor<?x?xf32>, %input1: tensor<?x?xf32>, %input2: tensor<?x?xf32>) -> (%output0: tensor<?x?xf32>)"}} {
    %c16 = arith.constant 16 : index
    %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
    %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
    %element_type_f32 = hal.element_type<f32> : i32
    %dense_row_major = hal.encoding_type<dense_row_major> : i32
    hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input0") shape([%0, %1]) type(%element_type_f32) encoding(%dense_row_major)
    %2 = stream.tensor.sizeof tensor<?x?xf32>{%0, %1} : index
    %3 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x?xf32>{%0, %1} in !stream.resource<external>{%2}
    %4 = stream.async.transfer %3 : !stream.resource<external>{%2} -> !stream.resource<*>{%2}
    %5 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
    %6 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
    %element_type_f32_0 = hal.element_type<f32> : i32
    %dense_row_major_1 = hal.encoding_type<dense_row_major> : i32
    hal.buffer_view.assert<%arg1 : !hal.buffer_view> message("input1") shape([%5, %6]) type(%element_type_f32_0) encoding(%dense_row_major_1)
    %7 = stream.tensor.sizeof tensor<?x?xf32>{%5, %6} : index
    %8 = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<?x?xf32>{%5, %6} in !stream.resource<external>{%7}
    %9 = stream.async.transfer %8 : !stream.resource<external>{%7} -> !stream.resource<*>{%7}
    %10 = hal.buffer_view.dim<%arg2 : !hal.buffer_view>[0] : index
    %11 = hal.buffer_view.dim<%arg2 : !hal.buffer_view>[1] : index
    %element_type_f32_2 = hal.element_type<f32> : i32
    %dense_row_major_3 = hal.encoding_type<dense_row_major> : i32
    hal.buffer_view.assert<%arg2 : !hal.buffer_view> message("input2") shape([%10, %11]) type(%element_type_f32_2) encoding(%dense_row_major_3)
    %12 = stream.tensor.sizeof tensor<?x?xf32>{%10, %11} : index
    %13 = stream.tensor.import %arg2 : !hal.buffer_view -> tensor<?x?xf32>{%10, %11} in !stream.resource<external>{%12}
    %14 = stream.async.transfer %13 : !stream.resource<external>{%12} -> !stream.resource<*>{%12}
    %c0 = arith.constant 0 : index
    // note: the below sizeof ops have {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>}.
    %15 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%0, %1} {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : index
    %16 = stream.async.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_0::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_0_set_encoding_LHS_DxD[%c16, %0, %1](%c16, %4[%c0 to %2 for %2], %0, %1) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (index, !stream.resource<*>{%2}, index, index) -> !stream.resource<*>{%15}
    %c0_4 = arith.constant 0 : index
    %17 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%5, %6} {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : index
    %18 = stream.async.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_1::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_1_set_encoding_RHS_DxD[%c16, %5, %6](%c16, %9[%c0_4 to %7 for %7], %5, %6) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (index, !stream.resource<*>{%7}, index, index) -> !stream.resource<*>{%17}
    %c0_5 = arith.constant 0 : index
    %19 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%10, %11} {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : index
    %20 = stream.async.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_2::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_2_set_encoding_RESULT_DxD[%c16, %10, %11](%c16, %14[%c0_5 to %12 for %12], %10, %11) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (index, !stream.resource<*>{%12}, index, index) -> !stream.resource<*>{%19}
    %c0_6 = arith.constant 0 : index
    %21 = stream.async.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_3::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_3_matmul_DxDxD_f32[%0, %1, %5, %6, %10, %11](%16[%c0_6 to %15 for %15], %18[%c0_6 to %17 for %17], %20[%c0_6 to %19 for %19], %0, %1, %5, %6, %10, %11) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (!stream.resource<*>{%15}, !stream.resource<*>{%17}, !stream.resource<*>{%19}, index, index, index, index, index, index) -> %20{%19}
    %c0_7 = arith.constant 0 : index
    %22 = stream.tensor.sizeof tensor<?x?xf32>{%10, %11} {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : index
    %23 = stream.async.dispatch @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_4::@matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32_dispatch_4_unset_encoding_RESULT_DxD[%10, %11](%21[%c0_7 to %19 for %19], %10, %11) {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : (!stream.resource<*>{%19}, index, index) -> !stream.resource<*>{%22}
    %24 = stream.async.transfer %23 : !stream.resource<*>{%22} -> !stream.resource<external>{%22}
    %25 = stream.tensor.export %24 : tensor<?x?xf32>{%10, %11} in !stream.resource<external>{%22} -> !hal.buffer_view
    util.return %25 : !hal.buffer_view
  }

Now I can start teaching calculateStorageElementCountInBytes to use the encoding attrs to do the round-up calculation. Then we can entirely remove the CPUMaterializeUpperBoundTileSize pass from the HAL pipeline. @benvanik @MaheshRavishankar does the layering look okay? We can look at more IR together on a VC if we need more information.

(The selected IR dumps can be found at https://gist.github.com/hanhanW/8d4f77c6903ca773c6f60098b5e541b1)

benvanik commented 3 months ago

Nice progress on decoupling!

I'll need to look more tomorrow, but I don't think the layering is quite right here - we don't want encoding attributes on ops but the encoding attribute on tensors (there's literally an encoding attr). Free-floating attrs on ops get into inconsistent state, don't properly handle different encodings per operand/result, and don't compose with ops that aren't aware of the semantics.

We also need to decouple the exact padding and any per-dispatch information from the head of the pipeline - what we're trying to introduce is a notion that something "may be" padded without explicitly specifying the padding. That's easiest if we can avoid needing to know the padding for as long as possible. We just need to know something may have padding, and then we can ask what that padding is once we know where the tensor is used (in EncodeHostTensors, once affinities are assigned).

(in the multi-device world, all stream ops now have affinities assigned so ops are like stream.tensor.sizeof on(#hal.device.affinity<@some_device>) ... - EncodeHostTensors then has both the tensor type (with its encoding attr, if any) and the device on which it is associated - in this first version, though, we can just have EncodeHostTensors fallback to 16 as a hardcoded value)

benvanik commented 3 months ago

(I'll have to think more in the morning about the heterogeneous cases and how we resolve those, but it's likely to be an SSA value that lets us get the encoding of a tensor and pass that around - so instead of tensor.dim there'd be stream.tensor.encoding that returns some opaque !stream.tensor.encoding we can pass around - then stream.tensor.sizeof could optionally take that and use it instead of the encoding attr on the tensor, letting us move encodings across devices or something.... complexity for later once things are decoupled)

hanhanW commented 3 months ago

we don't want encoding attributes on ops but the encoding attribute on tensors (there's literally an encoding attr).

This is very tricky with what we have today. An encoding is an attribute, not an op, and we want to somehow have the encoding show up in the IR.

IR without my changes after converting to stream:

  %15:2 = iree_linalg_ext.upper_bound_tile_size tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> index, index
  %16 = affine.apply affine_map<()[s0, s1] -> ((s1 ceildiv s0) * s0)>()[%15#1, %1]
  %17 = affine.apply affine_map<()[s0, s1] -> ((s1 ceildiv s0) * s0)>()[%15#0, %0]
  %c0 = arith.constant 0 : index
  %18 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%17, %16} : index

With my changes, the size calculation logic is dropped. IR dump:

%15 = stream.tensor.sizeof
  tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>>{%0, %1}
  {encoding.round_dims_to = #iree_codegen.encoding.round_dims_to<16>} : index

We can only have one encoding on a tensor type today, so I need to attach the information as an attribute. Do you think we should teach #iree_linalg_ext.encoding to carry the round_dims_to information? In that case, we won't have the free-floating #iree_codegen.encoding.round_dims_to<16> attribute. The caveat is that there would be more fields in the encoding, though.

benvanik commented 3 months ago

you can either teach it to support it (which I think is the easiest/best option - move it out of iree_linalg_ext and into util or something else and let's standardize it for ourselves) or teach the rounding encoding to nest - so #iree_codegen.encoding.round_dims_to<16, #iree_linalg_ext.encoding<....>>. I think us having one encoding attr we consistently use and can extend as we want is useful and going to make things much easier to modify going forward.

hanhanW commented 3 months ago

I'll try to teach #iree_linalg_ext.encoding to carry the information. I also think that it is the easiest/best option.

hanhanW commented 3 months ago

I tried a couple of things and realized that we can just set the round_dims_to field in the SetEncoding pass. Otherwise, I'd need to deal with all the type conversion issues. The new encoding looks like:

#iree_linalg_ext.encoding<
  role =  LHS,
  element_types = [f32, f32, f32],
  original_type = tensor<?x?xf32>,
  user_indexing_maps = [#map, #map1, #map2],
  // Below is the new parameter! It can be a list because `role` and `maps` provide enough information
  round_dims_to = 16 : index>

The IR after SetEncoding is cleaner -- no iree_linalg_ext.upper_bound_tile_size or tensor.pad ops anymore:

  util.func public @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_accumulate_DYNxDYNxf32_times_DYNxDYNxf32_into_DYNxDYNxf32(%input0: tensor<?x?xf32>, %input1: tensor<?x?xf32>, %input2: tensor<?x?xf32>) -> (%output0: tensor<?x?xf32>)"}} {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
    %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
    %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?xf32>{%0, %1}
    %3 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
    %4 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
    %5 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<?x?xf32>{%3, %4}
    %6 = hal.buffer_view.dim<%arg2 : !hal.buffer_view>[0] : index
    %7 = hal.buffer_view.dim<%arg2 : !hal.buffer_view>[1] : index
    %8 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<?x?xf32>{%6, %7}
    %9 = iree_linalg_ext.set_encoding %2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16 : index>>
    %10 = iree_linalg_ext.set_encoding %5 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RHS, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16 : index>>
    %11 = iree_linalg_ext.set_encoding %8 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16 : index>>
    %12 = linalg.matmul ins(%9, %10 : tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16 : index>>, tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RHS, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16 : index>>) outs(%11 : tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16 : index>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16 : index>>
    %dim = tensor.dim %8, %c0 : tensor<?x?xf32>
    %dim_0 = tensor.dim %8, %c1 : tensor<?x?xf32>
    %13 = iree_linalg_ext.unset_encoding %12 : tensor<?x?xf32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16 : index>> -> tensor<?x?xf32>
    %extracted_slice = tensor.extract_slice %13[0, 0] [%dim, %dim_0] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
    %14 = hal.tensor.export %extracted_slice "output0" : tensor<?x?xf32>{%6, %7} -> !hal.buffer_view
    util.return %14 : !hal.buffer_view
  }

On the stream side, we have stream.tensor.sizeof on a tensor with such an encoding, so calculateStorageElementCountInBytes can compute the actual size based on the encoding.

%15 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<role =  LHS, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16 : index>>{%0, %1} : index
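
Roughly, the size resolution then becomes simple arithmetic (a sketch of the computation calculateStorageElementCountInBytes would emit for this ?x?xf32 case with round_dims_to = 16, not the actual compiler code):

def sizeof_encoded_f32(d0: int, d1: int, round_dims_to: int = 16) -> int:
    # Round each dim up to round_dims_to, then multiply by 4 bytes per f32.
    pad = lambda d: ((d + round_dims_to - 1) // round_dims_to) * round_dims_to
    return pad(d0) * pad(d1) * 4

# e.g. a 100x200 LHS is stored as a 112x208 f32 buffer.
assert sizeof_encoded_f32(100, 200) == 112 * 208 * 4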

On the codegen side, we need to teach the materialization patterns to add padding values to the pack op, which should be easy. I'm going to:

  1. Put all the pieces together
  2. Run some tests to make sure I don't miss anything
  3. Slice PRs out for review.

benvanik commented 3 months ago

Great! That'll be a good improvement and get us to a place where doing the more sophisticated stuff is just an optimization - I suspect we'll change how things are done so that we can refine this based on device placement, to only pad as much as required by the devices that particular tensors are produced/consumed on, but that's much easier to reason about when starting from this point.

Thanks for puzzling through this :)

hanhanW commented 3 months ago

I have a chain of PRs that makes this happen (https://github.com/openxla/iree/pull/17055), which depends on some refactoring PRs (i.e., https://github.com/openxla/iree/pull/17040 and https://github.com/openxla/iree/pull/17053). In the prototype:

  1. I introduced an integer array (i.e., round_dims_to) field to the encoding. https://github.com/openxla/iree/pull/17055/commits/5f2d212df0f846491b3a1c50fd2fba907c54c7ad
  2. On codegen side, we'd check if the materialization is valid. I.e., all the inner tile sizes are less than or equal to the padding hint. https://github.com/openxla/iree/pull/17055/commits/9b045bb5557ce40e95d284319823f13185b866eb
  3. Then we teach the calculateStorageElementCountInBytes to compute sizes based on round_dims_to, so we can resolve stream.tensor.sizeof with encodings in EncodeHostTensor pass. https://github.com/openxla/iree/pull/17055/commits/6c129244f1d5ec128ece683e88ebc0df4ef3833d

Some PRs are out for review. If people want to take a look at further changes, see https://github.com/openxla/iree/pull/17055

However, I hit a cyclic dependency issue between LinalgExt/IR and LinalgExt/Utils while doing the refactoring. I think eventually we want to move the encoding attribute and the set/unset_encoding ops to the util dialect, so I plan to do that migration first and then slice more PRs out of the prototype.

hanhanW commented 3 months ago

Here are the lit tests for stream.tensor.sizeof in the prototype, posted here for visibility:

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
util.func public @sizeoflhsencoding(%arg0: index, %arg1: index) -> index {
  %0 = stream.tensor.sizeof tensor<?x?xf32, #iree_linalg_ext.encoding<role = LHS, element_types = [f32, f32, f32], original_type = tensor<?x?xf32>, user_indexing_maps = [#map, #map1, #map2], round_dims_to = 16, 16, 16>>{%arg0, %arg1} : index
  util.return %0 : index
}
// CHECK-LABEL: @sizeoflhsencoding
// CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG:     %[[C16:.+]] = arith.constant 16 : index
// CHECK:         %[[CEIL_DIV_D0:.+]] = arith.ceildivui %arg0, %[[C16]]
// CHECK:         %[[PAD_D0:.+]] = arith.muli %[[CEIL_DIV_D0]], %[[C16]]
// CHECK:         %[[CEIL_DIV_D1:.+]] = arith.ceildivui %arg1, %[[C16]]
// CHECK:         %[[PAD_D1:.+]] = arith.muli %[[CEIL_DIV_D1]], %[[C16]]
// CHECK:         %[[T0:.+]] = arith.muli %[[PAD_D0]], %[[C4]]
// CHECK:         %[[T1:.+]] = arith.muli %[[T0]], %[[PAD_D1]]
// CHECK:         return %[[T1]]