Closed hanhanW closed 2 years ago
Could you post the IR with the alloca as well? (And yes, we need to fix this issue.)
Oops, I forgot to attach it.
func @dot_33x16x49_exp_dispatch_0() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%c33 = arith.constant 33 : index
%c49 = arith.constant 49 : index
%c8 = arith.constant 8 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : memref<33x16xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : memref<16x49xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : memref<33x49xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
%4 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
%6 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
scf.for %arg0 = %3 to %c33 step %4 {
%7 = affine.min affine_map<(d0) -> (16, -d0 + 33)>(%arg0)
%8 = affine.min affine_map<(d0) -> (-d0 + 33, 16)>(%arg0)
%9 = memref.subview %0[%arg0, 0] [%8, 16] [1, 1] : memref<33x16xf32> to memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>>
scf.for %arg1 = %5 to %c49 step %6 {
%10 = affine.min affine_map<(d0) -> (16, -d0 + 49)>(%arg1)
%11 = memref.subview %2[%arg0, %arg1] [%7, %10] [1, 1] : memref<33x49xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>
%12 = affine.min affine_map<(d0) -> (-d0 + 49, 16)>(%arg1)
%13 = memref.subview %1[0, %arg1] [16, %12] [1, 1] : memref<16x49xf32> to memref<16x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>
%14 = memref.subview %13[0, 0] [16, %12] [1, 1] : memref<16x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>> to memref<16x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>
scf.for %arg2 = %c0 to %7 step %c8 {
%15 = affine.min affine_map<(d0, d1, d2) -> (-d0 + d1, 8, -d0 + d2)>(%arg2, %8, %7)
%16 = affine.min affine_map<(d0, d1, d2) -> (-d0 + d1, -d0 + d2, 8)>(%arg2, %7, %8)
%17 = memref.alloca(%16, %10) {alignment = 128 : i64} : memref<?x?xf32>
%18 = memref.subview %11[%arg2, 0] [%16, %10] [1, 1] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%18 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>) outs(%17 : memref<?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
}
linalg.fill(%cst, %17) : f32, memref<?x?xf32>
%19 = memref.subview %9[%arg2, 0] [%15, 16] [1, 1] : memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>> to memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>>
%20 = memref.subview %17[0, 0] [%16, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
linalg.matmul {lowering.config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 0, 0], [0, 0, 16]], native_vector_size = []>} ins(%19, %14 : memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>>, memref<16x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>) outs(%20 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>)
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%20 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) outs(%20 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
}
%21 = affine.min affine_map<(d0, d1) -> (8, -d0 + d1)>(%arg2, %7)
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17 : memref<?x?xf32>) outs(%17 : memref<?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32):
%23 = math.exp %arg3 : f32
linalg.yield %23 : f32
}
%22 = memref.subview %11[%arg2, 0] [%21, %10] [1, 1] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17 : memref<?x?xf32>) outs(%22 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
}
}
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 49 + s0 + d1)>>) {
^bb0(%arg2: f32, %arg3: f32):
linalg.yield %arg2 : f32
}
}
}
return
}
Note: this also happens in IREE old bufferization, but I think it is still a good thing to fix.
This is likely caused by the linalg.init_tensor preprocessing in IREE.
This line:
%2 = linalg.init_tensor [33, 49] : tensor<33x49xf32>
is replaced by:
%18 = tensor.extract_slice %11[%arg2, 0] [%17, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
When it should really be replaced by:
%18 = tensor.extract_slice %arg3[%arg2, 0] [%17, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
I'll take a closer look tomorrow.
@matthias-springer did you get a chance to take a closer look at this issue?
AFAIK, IREE's preprocessing does not insert tensor.extract_slice. It only inserts flow.dispatch.tensor.load. I think this tensor.extract_slice is coming from fusion.
You're right. I think we looked at a similar case before.
There are multiple issues at play here:
First, if you are inside of a loop, you should not use the original iter_arg (%11) but the corresponding bbArg (%arg3). Otherwise, the bufferization usually treats this as a RaW conflict, which causes an alloc+copy. The reason is that, to the bufferization, it looks like you're reading the old value of the buffer, when in fact you're doing an extract_slice and only reading the part that was not modified yet. But the bufferization is not smart enough to figure this out. We would need a more powerful analysis. At the moment we only consider "buffer is read" or "buffer is written", but not "this tiny part of the buffer is read" etc. That by itself is already a larger extension to the bufferization (which we should definitely do at some point). But then there are also those affine.min ops that make it harder to figure out what part of the buffer is actually read. We would probably have to combine bufferization with FlatAffineConstraints.
Is there a way to use %arg3 instead of %11 during fusion? This would likely be the easier fix. (Or if somebody wants to extend the bufferization, I'd be happy to help, but I can't do it by myself at the moment.)
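The difference between the coarse "buffer is read / buffer is written" view and a per-region view can be illustrated with a small Python model. This is only a sketch of the idea, not the bufferization's actual algorithm: the function names are made up for illustration, and each loop iteration is assumed to read and then write the same row range of one tensor, with the range size given by an affine.min-style bound.

```python
def tile_extent(offset, upper_bound, tile):
    """Mirror affine.min(tile, upper_bound - offset)."""
    return min(tile, upper_bound - offset)


def rows_touched(upper_bound, tile):
    """Yield (read_rows, write_rows) for each iteration of the tiled loop.

    Each iteration extracts rows [offset, offset + size) and later inserts
    its result into the same rows.
    """
    for offset in range(0, upper_bound, tile):
        size = tile_extent(offset, upper_bound, tile)
        rows = range(offset, offset + size)
        yield rows, rows


def whole_buffer_conflict(upper_bound, tile):
    """Coarse view: any read after an earlier write to the buffer conflicts."""
    written = False
    for _read, _write in rows_touched(upper_bound, tile):
        if written:
            return True
        written = True
    return False


def per_region_conflict(upper_bound, tile):
    """Fine view: a read conflicts only if it overlaps rows written earlier."""
    written_rows = set()
    for read, write in rows_touched(upper_bound, tile):
        if written_rows.intersection(read):
            return True
        written_rows.update(write)
    return False


if __name__ == "__main__":
    # 13 rows tiled by 8: the tiles cover rows [0, 8) and [8, 13).
    print("coarse analysis reports conflict:", whole_buffer_conflict(13, 8))
    print("per-region analysis reports conflict:", per_region_conflict(13, 8))
```

In this toy model the coarse check flags the second iteration, while the per-region check sees that the tiles are disjoint. Doing the same on real IR would require reasoning symbolically about the affine.min bounds rather than enumerating rows.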
The second issue is that even if you change to %arg3, that tensor would be extracted from and written to multiple times: first by the linalg.fill, then by the tensor.insert_slice at the end of the loop. One of those two requires a buffer copy. Is it possible to generate code such that you extract from %arg3 once, do something with it, then insert the result back into %arg3? That's the pattern that bufferization works with best. (Note: Extracting from %11 doesn't make it better, because %11 and %arg3 are essentially the same thing.)
Can you generate something like this and pass it to the bufferization? (Not 100% sure if this is doing the same computation.) Just showing the inner loop here:
%15 = scf.for %arg2 = %c0 to %7 step %c8 iter_args(%arg3 = %11) -> (tensor<?x?xf32>) {
%16 = affine.min affine_map<(d0, d1, d2) -> (-d0 + d1, 8, -d0 + d2)>(%arg2, %8, %7)
%17 = affine.min affine_map<(d0, d1, d2) -> (-d0 + d1, -d0 + d2, 8)>(%arg2, %7, %8)
%18 = tensor.extract_slice %arg3[%arg2, 0] [%17, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%19 = linalg.fill(%cst, %18) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%20 = tensor.extract_slice %9[%arg2, 0] [%16, 16] [1, 1] : tensor<?x16xf32> to tensor<?x16xf32>
%21 = tensor.extract_slice %19[0, 0] [%17, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%22 = linalg.matmul {lowering.config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 0, 0], [0, 0, 16]], native_vector_size = []>} ins(%20, %14 : tensor<?x16xf32>, tensor<16x?xf32>) outs(%21 : tensor<?x?xf32>) -> tensor<?x?xf32>
%23 = tensor.insert_slice %22 into %19[0, 0] [%17, %10] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
%26 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%23 : tensor<?x?xf32>) {
^bb0(%arg4: f32):
%28 = math.exp %arg4 : f32
linalg.yield %28 : f32
} -> tensor<?x?xf32>
%27 = tensor.insert_slice %26 into %arg3[%arg2, 0] [%17, %10] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %27 : tensor<?x?xf32>
}
If you can use %arg3 instead of %11, we may be able to simplify the IR and get to something like the above IR with a rewrite pattern that looks at which tensors are extracted/inserted where, etc.
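As a sanity check on the "extract once, compute, insert once" shape, the following Python sketch mimics the loop on nested lists and compares it with an untiled exp(lhs @ rhs). It is an illustration under assumptions, not a translation of the IR: the function names are hypothetical, only the row dimension is tiled, and the comments map each step to the op it stands in for.

```python
import math


def matmul_exp_reference(lhs, rhs):
    """Untiled exp(lhs @ rhs) on nested lists."""
    k_dim, n_dim = len(rhs), len(rhs[0])
    return [[math.exp(sum(row[k] * rhs[k][j] for k in range(k_dim)))
             for j in range(n_dim)] for row in lhs]


def matmul_exp_tiled(lhs, rhs, out, tile_rows=8):
    """Tiled version: each iteration extracts once, computes, inserts once."""
    m_dim, k_dim, n_dim = len(lhs), len(rhs), len(rhs[0])
    for offset in range(0, m_dim, tile_rows):
        size = min(tile_rows, m_dim - offset)              # affine.min
        tile = [out[offset + i][:] for i in range(size)]   # tensor.extract_slice
        for row in tile:                                   # linalg.fill
            for j in range(n_dim):
                row[j] = 0.0
        for i in range(size):                              # linalg.matmul
            for j in range(n_dim):
                for k in range(k_dim):
                    tile[i][j] += lhs[offset + i][k] * rhs[k][j]
        for row in tile:                                   # linalg.generic (exp)
            for j in range(n_dim):
                row[j] = math.exp(row[j])
        out[offset:offset + size] = tile                   # tensor.insert_slice
    return out


if __name__ == "__main__":
    lhs = [[0.01 * (i + k) for k in range(4)] for i in range(13)]
    rhs = [[0.02 * (k - j) for j in range(5)] for k in range(4)]
    out = [[-1.0] * 5 for _ in range(13)]
    tiled = matmul_exp_tiled(lhs, rhs, out)
    reference = matmul_exp_reference(lhs, rhs)
    print(all(math.isclose(a, b)
              for row_a, row_b in zip(tiled, reference)
              for a, b in zip(row_a, row_b)))
```

Because every tile is read from and written back to the same location exactly once per iteration, nothing in this shape forces a second copy of the output.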
Yes, the extract_slice ops are generated by fusion. https://gist.githubusercontent.com/hanhanW/3e1bb8d3ecfe627c0a505ebc929242c5/raw
The pass generates the pattern:
%17 = scf.for %arg2 = %c0 to %15 step %c8 iter_args(%arg3 = %12) -> (tensor<?x?xf32>) {
%18 = affine.min affine_map<(d0, d1, d2) -> (-d0 + d1, 8, -d0 + d2)>(%arg2, %10, %15)
%19 = tensor.extract_slice %9[%arg2, 0] [%18, 16] [1, 1] : tensor<?x16xf32> to tensor<?x16xf32>
%20 = affine.min affine_map<(d0, d1, d2, d3) -> (-d0 + d1, -d0 + d2, 8, -d0 + d3)>(%arg2, %15, %10, %15)
%21 = tensor.extract_slice %12[%arg2, 0] [%20, %16] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%22 = linalg.fill(%cst, %21) {__internal_linalg_transform__ = "1"} : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%23 = linalg.matmul {__internal_linalg_transform__ = "1", lowering.config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 0, 0], [0, 0, 16]], native_vector_size = []>} ins(%19, %14 : tensor<?x16xf32>, tensor<16x?xf32>) outs(%22 : tensor<?x?xf32>) -> tensor<?x?xf32>
We are not able to reproduce it with Sandbox because the Sandbox version has an extra temp input buffer for the result of matmul. @gysit Would you like to take a look at this since you have more context about linalg::fuseProducerOfTensor? In the meantime, I can try to fix the Sandbox version to have this input IR.
Wouldn't the fusion example in sandbox allow you to reproduce the issue?
No, it wouldn't. Hopefully I have a quick fix for doing it in Sandbox. With https://github.com/google/iree-llvm-sandbox/pull/290 and the following snippet, we are able to reproduce the issue.
def fill_matmul_bias_add_fusion():
  fun_name = 'matmul_bias_add'
  op_name = 'linalg.matmul'
  expert = Fuse(fun_name, 'linalg.generic', tile_sizes=[8, 16, 0]).then(
      Tile(
          fun_name,
          'linalg.matmul',
          tile_sizes=[0, 0, 24],
          # pad=True,
          # pack_paddings=[1, 1, 0],
      )).then(Vectorize(fun_name, '', vectorize_paddings=True)).then(
          Bufferize()).then(LowerVectors()).then(LowerToLLVM())
  keys = ['M', 'N', 'K']
  n_iters = 1
  problem_size_list = [[27, 37, 43]]
  test_harness(lambda s, t: MatmulBiasAddProblem(), [[np.float32] * 4],
               test_sizes(keys, problem_size_list),
               [expert.print_ir(after_all=True, at_begin=False, llvm=False)],
               n_iters=n_iters,
               function_name=fun_name,
               zero_at_each_iteration=True)

# CHECK-NOT: FAILURE
def main():
  # fill_matmul_fusion()
  fill_matmul_bias_add_fusion()
It generates the IR after Fuse:
%0 = scf.for %arg4 = %c0 to %c27 step %c8 iter_args(%arg5 = %arg3) -> (tensor<27x37xf32>) {
%1 = affine.min affine_map<(d0) -> (-d0 + 27, 8)>(%arg4)
%2 = tensor.extract_slice %arg0[%arg4, 0] [%1, 43] [1, 1] : tensor<27x43xf32> to tensor<?x43xf32>
%3 = affine.min affine_map<(d0) -> (8, -d0 + 27)>(%arg4)
%4 = scf.for %arg6 = %c0 to %c37 step %c16 iter_args(%arg7 = %arg5) -> (tensor<27x37xf32>) {
%5 = affine.min affine_map<(d0) -> (-d0 + 37, 16)>(%arg6)
%6 = tensor.extract_slice %arg1[0, %arg6] [43, %5] [1, 1] : tensor<43x37xf32> to tensor<43x?xf32>
%7 = tensor.extract_slice %arg3[%arg4, %arg6] [%1, %5] [1, 1] : tensor<27x37xf32> to tensor<?x?xf32>
%8 = linalg.fill(%cst, %7) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%9 = linalg.matmul ins(%2, %6 : tensor<?x43xf32>, tensor<43x?xf32>) outs(%8 : tensor<?x?xf32>) -> tensor<?x?xf32>
...
where
%7 = tensor.extract_slice %arg3[%arg4, %arg6] [%1, %5] [1, 1] : tensor<27x37xf32> to tensor<?x?xf32>
should be replaced with
%7 = tensor.extract_slice %arg7[%arg4, %arg6] [%1, %5] [1, 1] : tensor<27x37xf32> to tensor<?x?xf32>
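That swapping %arg3 for %arg7 does not change the computed values can be checked with a small value-semantics model of the two nested loops. The sketch below is an assumption-laden illustration in Python: tensors are immutable tuples, the helper names are hypothetical, and a simple "+ 1.0" stands in for the fused fill, matmul and bias computation.

```python
def extract_slice(tensor, row, col, rows, cols):
    """Return the rows x cols tile of `tensor` starting at (row, col)."""
    return tuple(tuple(tensor[row + i][col:col + cols]) for i in range(rows))


def insert_slice(tile, tensor, row, col):
    """Return a new tensor equal to `tensor` with `tile` written at (row, col)."""
    result = [list(r) for r in tensor]
    for i, tile_row in enumerate(tile):
        result[row + i][col:col + len(tile_row)] = tile_row
    return tuple(tuple(r) for r in result)


def run_tiled(init, tile_m, tile_n, read_from_iter_arg):
    """Two nested loops with loop-carried tensors, like scf.for with iter_args."""
    m_dim, n_dim = len(init), len(init[0])
    outer = init                                   # plays the role of %arg5
    for i in range(0, m_dim, tile_m):
        rows = min(tile_m, m_dim - i)
        inner = outer                              # plays the role of %arg7
        for j in range(0, n_dim, tile_n):
            cols = min(tile_n, n_dim - j)
            # Either read the loop-carried value or the original init tensor.
            source = inner if read_from_iter_arg else init
            tile = extract_slice(source, i, j, rows, cols)
            # Stand-in for the per-tile computation.
            tile = tuple(tuple(v + 1.0 for v in r) for r in tile)
            inner = insert_slice(tile, inner, i, j)
        outer = inner
    return outer


if __name__ == "__main__":
    init = tuple(tuple(float(i * 7 + j) for j in range(7)) for i in range(5))
    from_iter_arg = run_tiled(init, 2, 3, read_from_iter_arg=True)
    from_init = run_tiled(init, 2, 3, read_from_iter_arg=False)
    print(from_iter_arg == from_init)
```

Each tile is visited exactly once, so at the time it is extracted it still holds its initial values in both variants. The two forms differ only in what the bufferization analysis can prove, not in the result.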
@matthias-springer and I had a look at the example and implemented an alternative way to express the bias computation that enables in-place bufferization: https://github.com/google/iree-llvm-sandbox/pull/292.
tileAndFuse determines the output arguments based on the root operation that is tiled first. In the example, the root operation is the generic implementing the bias. Its output is made an iteration argument. If we want to make the outputs of fill and matmul an iteration argument, the easiest option is to pass the matmul result to the output parameter of the bias generic. The pull request implements this change. I am unsure if this solution is an option in the context of IREE, but it is the cleanest solution from my point of view.
At some point we may want to extend fusion to add additional iteration arguments for outputs of fused producers. This is relevant if a temporary value of the fused loop nest is used outside of the tile loops by a different consumer. The change is not trivial and it may be better to introduce a fake generic operation at the end of the computation to control what values should be iteration arguments.
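The idea of passing the matmul result as the output operand of the bias generic can be pictured with plain Python lists. This is a hypothetical sketch with made-up function names, not code from the pull request: the first function models a generic with ins(gemm, bias) and a separate outs buffer, the second models ins(bias) with outs(gemm), where the output is both read and written.

```python
def bias_add_separate_out(gemm, bias, init):
    """Model of generic ins(gemm, bias) outs(init): result goes to a second buffer."""
    for i, row in enumerate(gemm):
        for j, value in enumerate(row):
            init[i][j] = value + bias[j]
    return init


def bias_add_in_place(gemm, bias):
    """Model of generic ins(bias) outs(gemm): the matmul result is updated in place."""
    for row in gemm:
        for j in range(len(row)):
            row[j] += bias[j]
    return gemm


if __name__ == "__main__":
    bias = [10.0, 20.0]
    separate = bias_add_separate_out([[1.0, 2.0], [3.0, 4.0]], bias,
                                     [[0.0, 0.0], [0.0, 0.0]])
    in_place = bias_add_in_place([[1.0, 2.0], [3.0, 4.0]], bias)
    print(separate == in_place)
```

Both forms compute the same values. The in-place form needs only the buffer that already holds the matmul result, which is what lets the fill, matmul and bias share one iteration argument.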
Getting the input to bufferization in this form in the presence of fusion is not going to be easy. At this point on this path, the entire code is generated by using the codegen strategy. So it's a matter of what the capabilities of fusion + in-place bufferization are. AFAICS the input to the fusion is:
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%fill = linalg.fill(%c0, %init) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%gemm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
%bias_add = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%gemm, %bias : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32) :
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
The overall question is whether we can tile + fuse + bufferize this without additional stack usage. Whatever transformations are needed to do that must be part of what is within the codegen driver. I can write the final memref IR where that is the case; we might be missing transformations that get us there, but all of that seems like it should be within the scope of the codegen driver itself.
Thanks for the clarifications. I think the example is indeed interesting since it shows the dependency between fusion and bufferization. @hanhanW and I discussed this in a bit more detail. I think we may want some pre-bufferization pass before fusion to transform the IR into destination-passing style. Another option may be to avoid the stack allocation by vectorizing fill and matmul before bufferizing them.
I think that works today. The problem comes in cases where they don't vectorize because of dynamic dimensions. On those paths, it would be good to be able to bufferize without stack allocations. I know sandbox is exploring padding and peeling options to always vectorize. Those don't work with fusion (at least not to the extent they need to for what is seen in IREE). So they aren't directly usable yet in IREE.
Agreed we may not always be able to rely on vectorization. I think for fusion there are two options then. 1) Transform the IR in advance and make all operands that shall bufferize in place outputs, or 2) add some sort of annotation analysis that tells fusion and tiling that a specific operand needs to be treated like an output (i.e. do not fuse in reduction loops and add iteration arguments).
Follow-ups from today's codegen meeting. I have a WIP commit that adapts Linalg ops inputs to outputs. I verified that the allocations disappear with the commit.
https://github.com/google/iree/commit/f47aa9f14463154e7eb9cc2de18fd45142bd32e6
The commit does fix the allocation issue in the attached IR (i.e., fill -> matmul -> exp). But it does not work well in matmul_bias_add. I verified that we have the same pattern as @gysit showed in https://github.com/google/iree-llvm-sandbox/pull/292, but we still get a stack allocation. The sandbox does not create an alloca op, but IREE does. It looks like there are other issues in IREEComprehensiveBufferize.
This is the input IR:
module {
func @matmul_bias_add(%arg0: tensor<27x43xf32>, %arg1: tensor<43x37xf32>, %arg2: tensor<37xf32>, %arg3: tensor<27x37xf32>) -> tensor<27x37xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg3) : f32, tensor<27x37xf32> -> tensor<27x37xf32>
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<27x43xf32>, tensor<43x37xf32>) outs(%0 : tensor<27x37xf32>) -> tensor<27x37xf32>
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%1, %arg2 : tensor<27x37xf32>, tensor<37xf32>)
outs(%arg3 : tensor<27x37xf32>) {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
%3 = arith.addf %arg4, %arg5 : f32
linalg.yield %3 : f32
} -> tensor<27x37xf32>
return %2 : tensor<27x37xf32>
}
}
The following is the IR after destination passing style (before fusion) with the patch. The generic op takes the result of matmul as one of its outs.
// -----// IR Dump After ConvertToDestinationPassingStyle //----- //
func @matmul_bias_add_dispatch_0() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c27 = arith.constant 27 : index
%c37 = arith.constant 37 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:27x37xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:27x43xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:43x37xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:37xf32>
%4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:27x37xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
%6 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_y]
scf.for %arg0 = %5 to %c27 step %6 {
%7 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
%8 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
scf.for %arg1 = %7 to %c37 step %8 {
%9 = affine.min affine_map<(d0) -> (16, -d0 + 37)>(%arg1)
%10 = flow.dispatch.tensor.load %3, offsets = [%arg1], sizes = [%9], strides = [1] : !flow.dispatch.tensor<readonly:37xf32> -> tensor<?xf32>
%11 = affine.min affine_map<(d0) -> (16, -d0 + 27)>(%arg0)
%12 = affine.min affine_map<(d0) -> (-d0 + 27, 16)>(%arg0)
%13 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%12, 43], strides = [1, 1] : !flow.dispatch.tensor<readonly:27x43xf32> -> tensor<?x43xf32>
%14 = affine.min affine_map<(d0) -> (-d0 + 37, 16)>(%arg1)
%15 = flow.dispatch.tensor.load %2, offsets = [0, %arg1], sizes = [43, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:43x37xf32> -> tensor<43x?xf32>
%16 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%12, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:27x37xf32> -> tensor<?x?xf32>
%17 = linalg.fill(%cst, %16) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%18 = linalg.matmul {lowering.config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 0, 0], [0, 0, 16]], native_vector_size = []>} ins(%13, %15 : tensor<?x43xf32>, tensor<43x?xf32>) outs(%17 : tensor<?x?xf32>) -> tensor<?x?xf32>
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
ins(%10 : tensor<?xf32>)
outs(%18 : tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%20 = arith.addf %arg3, %arg2 : f32
linalg.yield %20 : f32
} -> tensor<?x?xf32>
flow.dispatch.tensor.store %19, %4, offsets = [%arg0, %arg1], sizes = [%11, %9], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:27x37xf32>
}
}
return
}
IR before bufferization:
// -----// IR Dump After CSE //----- //
func @matmul_bias_add_dispatch_0() {
%c1 = arith.constant 1 : index
%c43 = arith.constant 43 : index
%c16 = arith.constant 16 : index
%c8 = arith.constant 8 : index
%c37 = arith.constant 37 : index
%c27 = arith.constant 27 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:27x37xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:27x43xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:43x37xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:37xf32>
%4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:27x37xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
%6 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_y]
%7 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
%8 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
scf.for %arg0 = %5 to %c27 step %6 {
%9 = affine.min affine_map<(d0) -> (16, -d0 + 27)>(%arg0)
%10 = affine.min affine_map<(d0) -> (-d0 + 27, 16)>(%arg0)
%11 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%10, 43], strides = [1, 1] : !flow.dispatch.tensor<readonly:27x43xf32> -> tensor<?x43xf32>
scf.for %arg1 = %7 to %c37 step %8 {
%12 = affine.min affine_map<(d0) -> (16, -d0 + 37)>(%arg1)
%13 = flow.dispatch.tensor.load %3, offsets = [%arg1], sizes = [%12], strides = [1] : !flow.dispatch.tensor<readonly:37xf32> -> tensor<?xf32>
%14 = affine.min affine_map<(d0) -> (-d0 + 37, 16)>(%arg1)
%15 = flow.dispatch.tensor.load %2, offsets = [0, %arg1], sizes = [43, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:43x37xf32> -> tensor<43x?xf32>
%16 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%10, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:27x37xf32> -> tensor<?x?xf32>
%17 = scf.for %arg2 = %c0 to %10 step %c8 iter_args(%arg3 = %16) -> (tensor<?x?xf32>) {
%18 = affine.min affine_map<(d0, d1) -> (8, -d0 + d1)>(%arg2, %10)
%19 = tensor.extract_slice %arg3[%arg2, 0] [%18, %14] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%20 = linalg.fill(%cst, %19) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%21 = affine.min affine_map<(d0, d1) -> (-d0 + d1, 8)>(%arg2, %10)
%22 = tensor.dim %20, %c0 : tensor<?x?xf32>
%23 = tensor.dim %20, %c1 : tensor<?x?xf32>
%24 = scf.for %arg4 = %c0 to %c43 step %c16 iter_args(%arg5 = %20) -> (tensor<?x?xf32>) {
%27 = affine.min affine_map<(d0) -> (16, -d0 + 43)>(%arg4)
%28 = tensor.extract_slice %11[%arg2, %arg4] [%21, %27] [1, 1] : tensor<?x43xf32> to tensor<?x?xf32>
%29 = tensor.extract_slice %15[%arg4, 0] [%27, %14] [1, 1] : tensor<43x?xf32> to tensor<?x?xf32>
%30 = tensor.extract_slice %arg5[0, 0] [%22, %23] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%31 = linalg.matmul {lowering.config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 0, 0], [0, 0, 16]], native_vector_size = []>} ins(%28, %29 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%30 : tensor<?x?xf32>) -> tensor<?x?xf32>
%32 = tensor.insert_slice %31 into %arg5[0, 0] [%22, %23] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %32 : tensor<?x?xf32>
}
%25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13 : tensor<?xf32>) outs(%24 : tensor<?x?xf32>) {
^bb0(%arg4: f32, %arg5: f32):
%27 = arith.addf %arg5, %arg4 : f32
linalg.yield %27 : f32
} -> tensor<?x?xf32>
%26 = tensor.insert_slice %25 into %arg3[%arg2, 0] [%18, %14] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %26 : tensor<?x?xf32>
}
flow.dispatch.tensor.store %17, %4, offsets = [%arg0, %arg1], sizes = [%9, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:27x37xf32>
}
}
return
}
IR after bufferization: https://gist.githubusercontent.com/hanhanW/c303b7dfde4d2d14da826ba443da39ce/raw
IR after bufferization and some cleanup:
// -----// IR Dump After CleanupBufferAllocView //----- //
func @matmul_bias_add_dispatch_0() {
%c43 = arith.constant 43 : index
%c16 = arith.constant 16 : index
%c8 = arith.constant 8 : index
%c37 = arith.constant 37 : index
%c27 = arith.constant 27 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : memref<27x37xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : memref<27x43xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : memref<43x37xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(32) : memref<37xf32>
%4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) offset(%c0) alignment(32) : memref<27x37xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
%6 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_y]
%7 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
%8 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
scf.for %arg0 = %5 to %c27 step %6 {
%9 = affine.min affine_map<(d0) -> (16, -d0 + 27)>(%arg0)
%10 = affine.min affine_map<(d0) -> (-d0 + 27, 16)>(%arg0)
%11 = memref.subview %1[%arg0, 0] [%10, 43] [1, 1] : memref<27x43xf32> to memref<?x43xf32, affine_map<(d0, d1)[s0] -> (d0 * 43 + s0 + d1)>>
scf.for %arg1 = %7 to %c37 step %8 {
%12 = affine.min affine_map<(d0) -> (16, -d0 + 37)>(%arg1)
%13 = memref.subview %3[%arg1] [%12] [1] : memref<37xf32> to memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
%14 = affine.min affine_map<(d0) -> (-d0 + 37, 16)>(%arg1)
%15 = memref.subview %2[0, %arg1] [43, %14] [1, 1] : memref<43x37xf32> to memref<43x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 37 + s0 + d1)>>
%16 = memref.subview %0[%arg0, %arg1] [%10, %14] [1, 1] : memref<27x37xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 37 + s0 + d1)>>
%17 = memref.alloca(%10, %14) {alignment = 128 : i64} : memref<?x?xf32>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%16 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 37 + s0 + d1)>>) outs(%17 : memref<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32):
linalg.yield %arg2 : f32
}
scf.for %arg2 = %c0 to %10 step %c8 {
%19 = affine.min affine_map<(d0, d1) -> (8, -d0 + d1)>(%arg2, %10)
%20 = memref.subview %17[%arg2, 0] [%19, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
linalg.fill(%cst, %20) : f32, memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
%21 = affine.min affine_map<(d0, d1) -> (-d0 + d1, 8)>(%arg2, %10)
scf.for %arg3 = %c0 to %c43 step %c16 {
%22 = affine.min affine_map<(d0) -> (16, -d0 + 43)>(%arg3)
%23 = memref.subview %11[%arg2, %arg3] [%21, %22] [1, 1] : memref<?x43xf32, affine_map<(d0, d1)[s0] -> (d0 * 43 + s0 + d1)>> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 43 + s0 + d1)>>
%24 = memref.subview %15[%arg3, 0] [%22, %14] [1, 1] : memref<43x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 37 + s0 + d1)>> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 37 + s0 + d1)>>
%25 = memref.subview %20[0, 0] [%19, %14] [1, 1] : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
linalg.matmul {lowering.config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 0, 0], [0, 0, 16]], native_vector_size = []>} ins(%23, %24 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 43 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 37 + s0 + d1)>>) outs(%25 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>)
}
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13 : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>) outs(%20 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) {
^bb0(%arg3: f32, %arg4: f32):
%22 = arith.addf %arg4, %arg3 : f32
linalg.yield %22 : f32
}
}
%18 = memref.subview %4[%arg0, %arg1] [%9, %12] [1, 1] : memref<27x37xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 37 + s0 + d1)>>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17 : memref<?x?xf32>) outs(%18 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 37 + s0 + d1)>>) {
^bb0(%arg2: f32, %arg3: f32):
linalg.yield %arg2 : f32
}
}
}
return
}
A memref.alloca op is created before going into computations (i.e., the first scf loop after distribution). @matthias-springer I think the IR before bufferization is the form you're looking for. Can you help look into it?
The problem here is that you are loading from / storing into different tensors:
%16 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%10, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:27x37xf32> -> tensor<?x?xf32>
%17 = scf.for %arg2 = %c0 to %10 step %c8 iter_args(%arg3 = %16) -> (tensor<?x?xf32>) {
...
}
flow.dispatch.tensor.store %17, %4, offsets = [%arg0, %arg1], sizes = [%9, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:27x37xf32>
If you store the result into %0 at the same offsets/sizes/strides and make it a flow.dispatch.tensor<readwrite:...>, then there should be no copy.
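To make the condition above concrete, here is a minimal sketch in Python. It is a toy model, not IREE code: the `Slice` class and `can_update_in_place` are hypothetical names, and the comparison is purely textual on SSA names, whereas a real analysis would have to prove that the sizes are equal values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Slice:
    """Hypothetical description of one tensor load or store on a binding."""
    binding: str
    offsets: tuple
    sizes: tuple
    strides: tuple


def can_update_in_place(load: Slice, store: Slice) -> bool:
    """True when the store writes back exactly the region that was loaded."""
    return load == store


# The load from the thread, a store into a different binding, and a store
# back into the same binding at the same offsets/sizes/strides.
load = Slice("%0", ("%arg0", "%arg1"), ("%10", "%14"), (1, 1))
store_other = Slice("%4", ("%arg0", "%arg1"), ("%9", "%12"), (1, 1))
store_same = Slice("%0", ("%arg0", "%arg1"), ("%10", "%14"), (1, 1))

print(can_update_in_place(load, store_other))  # False
print(can_update_in_place(load, store_same))   # True
```

In this model the loop result can only alias the loaded buffer in the second case; in the first case the loaded tile has to live somewhere else, which is where the extra buffer and copy come from.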
@hanhanW I didn't catch it during review, but I think the generic op should drop the operand altogether. That op has weird semantics because it is all parallel, but it reads and writes the outs operand, which is a dependence. It's better not to do that and instead change the outs operand to be the result of the matmul.
Also, the input IR in the post above confuses me. There is no init_tensor there, so it's hard for me to see how destination-passing style can be achieved.
@hanhanW I didn't catch it during review, but I think the generic op should drop the operand altogether. That op has weird semantics because it is all parallel, but it reads and writes the outs operand, which is a dependence. It's better not to do that and instead change the outs operand to be the result of the matmul.
If I understand correctly, this may lead to problems with fusion, since the generic would consume the matmul result twice (once as an input and once as an output). Fusion will then fuse twice, and the two copies will not CSE because one of them is fused via the iteration argument of the for loop. An easy fix may be that fusion should never fuse output operands that have no use in the body of the operation (note that for named operations this may not be obvious).
Changing the outs operand leads to problems in fusion: there are two matmul ops after fusion. We might need the fix on the fusion side?
for named operations this may not be obvious
I think there is an interface method to get the block argument, and we can check something like op.getRegionOutputArgs()[0].use_empty().
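A rough Python model of that check, for illustration only. `ToyOp` and `shape_only_outputs` are hypothetical stand-ins and not MLIR APIs; the model just records which values the op body reads and asks whether an output's value is among them.

```python
from dataclasses import dataclass, field


@dataclass
class ToyOp:
    """Hypothetical stand-in for a linalg op: its outputs and what its body reads."""
    name: str
    outputs: list
    body_reads: set = field(default_factory=set)


def shape_only_outputs(op: ToyOp) -> list:
    """Outputs whose value is never read in the body, so only the shape matters."""
    return [out for out in op.outputs if out not in op.body_reads]


# An elementwise generic that adds its two inputs and never reads outs.
bias_add = ToyOp("linalg.generic", outputs=["%arg3"], body_reads={"%1", "%arg2"})
# A matmul accumulates into outs, so the output value is read in the body.
matmul = ToyOp("linalg.matmul", outputs=["%0"], body_reads={"%arg0", "%arg1", "%0"})

print(shape_only_outputs(bias_add))  # ['%arg3']
print(shape_only_outputs(matmul))    # []
```

Under this model, fusion would skip producers that feed only the outputs returned by `shape_only_outputs`.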
Regarding the other allocation issue, I think we can't simply modify it to readwrite. That's an ABI change. I think the input is not a general use case for IREE. We don't fill values into an input tensor; instead, we should fill values into an init_tensor. It was my mistake to copy IR from sandbox without checking whether it is legal for IREE. After modifying the IR, we don't see the allocation. Everything works as expected.
The changed IR:
module {
func @matmul_bias_add(%arg0: tensor<27x43xf32>, %arg1: tensor<43x37xf32>, %arg2: tensor<37xf32>, %arg3: tensor<27x37xf32>) -> tensor<27x37xf32> {
%cst = arith.constant 0.000000e+00 : f32
%init = linalg.init_tensor [27, 37] : tensor<27x37xf32>
%0 = linalg.fill(%cst, %init) : f32, tensor<27x37xf32> -> tensor<27x37xf32>
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<27x43xf32>, tensor<43x37xf32>) outs(%0 : tensor<27x37xf32>) -> tensor<27x37xf32>
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%1, %arg2 : tensor<27x37xf32>, tensor<37xf32>)
outs(%arg3 : tensor<27x37xf32>) {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
%3 = arith.addf %arg4, %arg5 : f32
linalg.yield %3 : f32
} -> tensor<27x37xf32>
return %2 : tensor<27x37xf32>
}
}
In the IR above the output of the generic would be %1 as well right? Meaning the generic would have two operands set to %1?
And yes, you can look at the uses of the region argument. At the moment all operations, including the named ones, have a body; it is just hidden. I think not fusing outputs that have no use may be a useful change.
https://reviews.llvm.org/D120981 should stop fusing shape-only operands. Can you try if this helps? I would expect the original matmul outside of the tile loops remains alive since we still access its shape. Meaning follow up steps may be needed. Additionally, I am also a bit unsure if bufferization will actually work.
Not fusing shape-only tensors still seems like a good thing to do...
An alternative fusion patch would be to check for every input operand if it is an output as well. In these cases, we would then fuse the producer via iteration argument... Overall, I do not have a clear understanding yet what is really needed to solve the general problem. I think rewiring the generic output to an input is more of a hack to circumvent the problem for the current example and not really a solution for the general problem.
Not fusing shape-only tensors is a good thing in general. I don't think reusing the input for the output is necessarily a hack. It's the way for the input code to tell bufferization to use the same memory for the result as one of the input operands. This is on par with what was discussed in a meeting about breaking def-use DAGs into def-use chains as a pre-processing step to bufferization. The issue now is where to do this. Right now this is being done early because that is the easiest place in terms of analyzing def-use DAGs, instead of trying to do it after, say, fusion, with more complex code paths.
I think not fusing operands that are used for shape only makes sense in general, and should help here hopefully.
I agree that not fusing shape-only tensors makes sense anyways. I am just unsure if this will be sufficient on its own. Let me know if https://reviews.llvm.org/D120981 helps!
Yes, I think the patch makes sense generally. I'll give it a shot and let you know!
@gysit there are issues with the patch.
In this example (which is what we're seeing):
func @shape_only(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.0 : f32
%0 = linalg.fill(%cst, %arg1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
ins(%1 : tensor<?x?xf32>) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}
After tileAndFuse, we would see two matmul ops. One stays outside the loops and is passed in as iter_args, and the other is fused as expected. We are not able to remove the outer fill and matmul ops.
#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
#map1 = affine_map<(d0, d1)[s0] -> (d0, -d1 + s0)>
module {
func @shape_only(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = tensor.dim %1, %c0 : tensor<?x?xf32>
%3 = tensor.dim %1, %c1 : tensor<?x?xf32>
%4 = scf.for %arg2 = %c0 to %2 step %c32 iter_args(%arg3 = %1) -> (tensor<?x?xf32>) {
%5 = scf.for %arg4 = %c0 to %3 step %c32 iter_args(%arg5 = %arg3) -> (tensor<?x?xf32>) {
%6 = affine.min #map0(%arg2)[%2]
%7 = affine.min #map0(%arg4)[%3]
%8 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%9 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%10 = affine.min #map1(%6, %arg2)[%8]
%11 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%12 = tensor.extract_slice %arg0[%arg2, 0] [%10, %11] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%13 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%14 = affine.min #map1(%7, %arg4)[%9]
%15 = tensor.extract_slice %arg1[0, %arg4] [%13, %14] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%16 = affine.min #map1(%6, %arg2)[%8]
%17 = affine.min #map1(%7, %arg4)[%9]
%18 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%19 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%20 = affine.min #map1(%16, %arg2)[%18]
%21 = affine.min #map1(%17, %arg4)[%19]
%22 = tensor.extract_slice %arg1[%arg2, %arg4] [%20, %21] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%23 = linalg.fill(%cst, %22) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%24 = linalg.matmul ins(%12, %15 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%23 : tensor<?x?xf32>) -> tensor<?x?xf32>
%25 = affine.min #map0(%arg2)[%2]
%26 = affine.min #map0(%arg4)[%3]
%27 = tensor.extract_slice %arg5[%arg2, %arg4] [%25, %26] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%28 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%24 : tensor<?x?xf32>) outs(%27 : tensor<?x?xf32>) -> tensor<?x?xf32>
%29 = tensor.insert_slice %28 into %arg5[%arg2, %arg4] [%25, %26] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %29 : tensor<?x?xf32>
}
scf.yield %5 : tensor<?x?xf32>
}
return %4 : tensor<?x?xf32>
}
}
This causes further issues. If we don't specify anchor-op, tileAndFuse will be applied to the outer fill + matmul again. Are we able to prevent this pattern in advance?
That seems like a bug. You are taking the result of matmul and performing matmul again. That's a correctness issue.
This is not a correctness issue. The inner fill op takes a slice of the inputs (i.e., %arg1). After the computations, we store the result into the iter_arg (i.e., %arg5). The outer fill + matmul is more like unnecessary initialization. This looks like a dead end to me. I can't imagine what the initial iter_arg should be when the matmul result is in outs.
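A small pure-Python sketch of this argument. It is a toy model of the tiled loop nest, not the generated IR; the function names and the tile size are made up for illustration. Every tile recomputes its own matmul from the inputs and overwrites the matching region of the loop-carried tensor, so the initial value of that tensor never reaches the result.

```python
import math


def matmul(a, b):
    """Plain matmul on lists of rows, equivalent to accumulating into zeros."""
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def tiled_exp_matmul(a, b, init, tile):
    """Tile over the output; each tile redoes fill + matmul, applies exp, then inserts."""
    m, n = len(a), len(b[0])
    acc = [row[:] for row in init]  # plays the role of the iter_arg
    for i0 in range(0, m, tile):
        for j0 in range(0, n, tile):
            i1, j1 = min(i0 + tile, m), min(j0 + tile, n)
            a_tile = a[i0:i1]
            b_tile = [row[j0:j1] for row in b]
            mm_tile = matmul(a_tile, b_tile)  # the fused (inner) fill + matmul
            for i in range(i0, i1):
                for j in range(j0, j1):
                    # Overwrites the carried value, like insert_slice into the iter_arg.
                    acc[i][j] = math.exp(mm_tile[i - i0][j - j0])
    return acc


a = [[1.0, 2.0], [0.5, -1.0], [0.0, 1.0]]
b = [[1.0, 0.0, -1.0], [0.5, 1.0, 0.25]]
init_from_matmul = matmul(a, b)                 # the outer fill + matmul
init_arbitrary = [[123.0] * 3 for _ in range(3)]  # any tensor of the same shape

out1 = tiled_exp_matmul(a, b, init_from_matmul, tile=2)
out2 = tiled_exp_matmul(a, b, init_arbitrary, tile=2)
print(out1 == out2)  # True
```

The two runs agree because the elementwise op never reads its destination; only the shape of the initial tensor is used.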
=====
Actually, this is not a normal form for fusion and IREE. Maybe we'd like to update the fusion logic or run some analysis during fusion. Filling into function arguments is not common in IREE. The example below is a normal form that is generated from the mhlo level.
func @dot_384x512x128_exp_dispatch_0() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:384x512xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:512x128xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:384x128xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:384x128xf32>
%4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:384x512xf32> -> tensor<384x512xf32>
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x128xf32> -> tensor<512x128xf32>
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [384, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:384x128xf32> -> tensor<384x128xf32>
%7 = linalg.init_tensor [384, 128] : tensor<384x128xf32>
%8 = linalg.fill(%cst, %7) : f32, tensor<384x128xf32> -> tensor<384x128xf32>
%9 = linalg.matmul ins(%4, %5 : tensor<384x512xf32>, tensor<512x128xf32>) outs(%8 : tensor<384x128xf32>) -> tensor<384x128xf32>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %6 : tensor<384x128xf32>, tensor<384x128xf32>) outs(%7 : tensor<384x128xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
%11 = arith.addf %arg0, %arg1 : f32
linalg.yield %11 : f32
} -> tensor<384x128xf32>
flow.dispatch.tensor.store %10, %3, offsets = [0, 0], sizes = [384, 128], strides = [1, 1] : tensor<384x128xf32> -> !flow.dispatch.tensor<writeonly:384x128xf32>
return
}
The fill op and the generic op both take an init_tensor as outs. Both take a shaped op (i.e., init_tensor) before the conversion to destination passing style: https://gist.githubusercontent.com/hanhanW/bca41ac3dcf92c69765cfa714347fbe2/raw
After the destination passing conversion, they basically take the same tensor value (in different variables). IMO, we have enough information during fusion. The fusion could keep a mapping table. Once the scf.for loops are created, some of the operations would take iter_args. If we have a lookup table that maintains the mapping between iter_args and tensors, we are able to reuse the iter_args when tiling other operands. This needs a non-trivial fix in fusion or a better fusion algorithm. @gysit should weigh in if we're going with this approach.
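A minimal sketch of the lookup-table idea in Python. `FusionState` and its methods are hypothetical bookkeeping, not part of any existing fusion driver, and the sketch assumes that values known to be the same tensor have already been canonicalized to one name.

```python
class FusionState:
    """Hypothetical bookkeeping for a tile-and-fuse driver."""

    def __init__(self):
        # Maps an original tensor name to the loop iter_arg that carries it.
        self.iter_arg_for = {}

    def create_loop(self, init_tensor: str, iter_arg: str) -> None:
        """Record that the new loop carries `init_tensor` as `iter_arg`."""
        self.iter_arg_for[init_tensor] = iter_arg

    def destination_for(self, outs_operand: str) -> str:
        """Pick the value a tiled op should slice its destination from."""
        return self.iter_arg_for.get(outs_operand, outs_operand)


state = FusionState()
# Assumption: fill and generic share the destination %7 after conversion.
state.create_loop(init_tensor="%7", iter_arg="%arg3")

print(state.destination_for("%7"))  # %arg3
print(state.destination_for("%6"))  # %6
```

With such a table, a producer fused later would slice its destination from the existing iter_arg instead of from the original tensor, which is the reuse described above.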
Another approach is to make bufferization smarter. It's more like recovering the missing information to me. I'm in favor of doing it right in the first place.
For both approaches, it would be good if we can write down the IR that satisfies bufferization's requirements.
I think the lookup table approach may be a viable solution. I will talk to @matthias-springer on Monday about whether the resulting IR would work with bufferization!
The approach actually introduces other issues.
Input IR:
func @add_dispatch_0() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:32x48xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:32x48xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:32x48xf32>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [32, 48], strides = [1, 1] : !flow.dispatch.tensor<readonly:32x48xf32> -> tensor<32x48xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32, 48], strides = [1, 1] : !flow.dispatch.tensor<readonly:32x48xf32> -> tensor<32x48xf32>
%5 = linalg.init_tensor [32, 48] : tensor<32x48xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3, %4 : tensor<32x48xf32>, tensor<32x48xf32>) outs(%5 : tensor<32x48xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
%7 = arith.addf %arg0, %arg1 : f32
linalg.yield %7 : f32
} -> tensor<32x48xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [32, 48], strides = [1, 1] : tensor<32x48xf32> -> !flow.dispatch.tensor<writeonly:32x48xf32>
return
}
After moving input operand to output operand, the tile and fuse pass would create iter_args.
scf.for %arg0 = %3 to %c32 step %4 {
scf.for %arg1 = %5 to %c48 step %6 {
%7 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor<readonly:32x48xf32> -> tensor<16x16xf32>
%8 = flow.dispatch.tensor.load %1, offsets = [%arg0, %arg1], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor<readonly:32x48xf32> -> tensor<16x16xf32>
%9 = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %7) -> (tensor<16x16xf32>) {
%10 = scf.for %arg4 = %c0 to %c16 step %c4 iter_args(%arg5 = %arg3) -> (tensor<16x16xf32>) {
%11 = tensor.extract_slice %8[%arg2, %arg4] [1, 4] [1, 1] : tensor<16x16xf32> to tensor<1x4xf32>
%12 = tensor.extract_slice %arg5[%arg2, %arg4] [1, 4] [1, 1] : tensor<16x16xf32> to tensor<1x4xf32>
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%11 : tensor<1x4xf32>) outs(%12 : tensor<1x4xf32>) attrs = {lowering.config = #iree_codegen.lowering.config<tile_sizes = [[16, 16], [1, 4], [0, 0]], native_vector_size = []>} {
^bb0(%arg6: f32, %arg7: f32):
%15 = arith.addf %arg7, %arg6 : f32
linalg.yield %15 : f32
} -> tensor<1x4xf32>
%14 = tensor.insert_slice %13 into %arg5[%arg2, %arg4] [1, 4] [1, 1] : tensor<1x4xf32> into tensor<16x16xf32>
scf.yield %14 : tensor<16x16xf32>
}
scf.yield %10 : tensor<16x16xf32>
}
flow.dispatch.tensor.store %9, %2, offsets = [%arg0, %arg1], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !flow.dispatch.tensor<writeonly:32x48xf32>
}
}
This makes bufferization allocate an additional buffer even though the ops are all vectorized.
func @add_dispatch_0() {
%c48 = arith.constant 48 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.alloca() {alignment = 128 : i64} : memref<16x16xf32>
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<32x48xf32>
memref.assume_alignment %1, 64 : memref<32x48xf32>
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:32x48xf32>
%3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<32x48xf32>
memref.assume_alignment %3, 64 : memref<32x48xf32>
%4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:32x48xf32>
%5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<32x48xf32>
memref.assume_alignment %5, 64 : memref<32x48xf32>
%6 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:32x48xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%7 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
%8 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_y]
%9 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
%10 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
scf.for %arg0 = %7 to %c32 step %8 {
scf.for %arg1 = %9 to %c48 step %10 {
%11 = memref.subview %1[%arg0, %arg1] [16, 16] [1, 1] : memref<32x48xf32> to memref<16x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 48 + s0 + d1)>>
%12 = memref.subview %3[%arg0, %arg1] [16, 16] [1, 1] : memref<32x48xf32> to memref<16x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 48 + s0 + d1)>>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%11 : memref<16x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 48 + s0 + d1)>>) outs(%0 : memref<16x16xf32>) {
^bb0(%arg2: f32, %arg3: f32):
linalg.yield %arg2 : f32
}
scf.for %arg2 = %c0 to %c16 step %c1 {
scf.for %arg3 = %c0 to %c16 step %c4 {
%14 = vector.transfer_read %12[%arg2, %arg3], %cst {in_bounds = [true, true]} : memref<16x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 48 + s0 + d1)>>, vector<1x4xf32>
%15 = vector.transfer_read %0[%arg2, %arg3], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<1x4xf32>
%16 = arith.addf %15, %14 : vector<1x4xf32>
vector.transfer_write %16, %0[%arg2, %arg3] {in_bounds = [true, true]} : vector<1x4xf32>, memref<16x16xf32>
}
}
%13 = memref.subview %5[%arg0, %arg1] [16, 16] [1, 1] : memref<32x48xf32> to memref<16x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 48 + s0 + d1)>>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : memref<16x16xf32>) outs(%13 : memref<16x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 48 + s0 + d1)>>) {
^bb0(%arg2: f32, %arg3: f32):
linalg.yield %arg2 : f32
}
}
}
return
}
I think we can try to vectorize all the operations. If it works, we don't need the patch. I'll take a stab at it.
We missed checking whether the inputs come from a readonly tensor. The follow-up fix https://github.com/google/iree/pull/8499 addressed the issue.
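As an illustration of the kind of guard described here (the PR itself is not reproduced, and both function names and the access labels below are hypothetical), a Python sketch might look like this: an input is only considered as the destination if it does not come from a readonly binding.

```python
def can_reuse_as_output(access: str) -> bool:
    """Hypothetical check: a value loaded from a readonly binding must not be overwritten."""
    return access != "readonly"


def pick_destination(inputs: dict, init_tensor: str) -> str:
    """Return the first input that may be overwritten, else fall back to the init tensor."""
    for name, access in inputs.items():
        if can_reuse_as_output(access):
            return name
    return init_tensor


# Both inputs of the add dispatch come from readonly bindings, so the
# init tensor %5 stays as the destination.
print(pick_destination({"%3": "readonly", "%4": "readonly"}, "%5"))  # %5
```

For the add dispatch in this thread, that keeps %5 as outs, so the tiled loops would not carry a tile loaded from a readonly input.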
I think that works today. The problem comes in cases where they don't vectorize because of dynamic dimensions. On those paths, it would be good to be able to bufferize without stack allocations. I know sandbox is exploring padding and peeling options to always vectorize. Those don't work with fusion (at least not to the extent they need to for what is seen in IREE), so they aren't directly usable yet in IREE.
Hi @MaheshRavishankar could you please give me more insight on these exploratory directions to use padding and peeling to vectorize? I am attempting to do something similar. Do you have any of this code in IREE? Would you mind pointing me to it? Thanks a lot.
This looks like the issue I've hit before. If the ops are not vectorized, we will see memref.alloca ops. I think we should get the same behavior whether all the ops get vectorized or none of them do. I feel that we are missing some propagation. To repro:
IR before bufferization:
If we can vectorize all the operations, we'll see a pure form and things go well in bufferization. E.g.,
I think we have to fix the issue. In practice, not all operations are vectorizable even if we can pad everything correctly. E.g., the gather op is not vectorizable, and we'll see vectorizable matmul + gather in some models. In this context, we have to figure out how to propagate the buffer aliasing.
@matthias-springer do you have suggestions about not generating memref.alloca ops in this case?