powderluv opened this issue 9 months ago
@erwei-xilinx , @vivek-amd and a few of us had a discussion on some related aspects. I think we need to be a bit more concrete about the scheduling we set up using tiling + packing transformations. Probably worth discussing in our sync tomorrow. @jtuyls it would be good to have you chip in as well.
@MaheshRavishankar , @yzhang93 , @stephenneuendorffer In our discussion earlier today, we identified an issue with the currently proposed IR driving MLIR-AIR's scheduling workflows. I am putting together an alternative scheduling strategy; the new frontend IR for AIR is attached here: https://gist.gitenterprise.xilinx.com/erweiw/5bf859b389ee009223e5dfab21bbef07 Comments are greatly appreciated.
Hi @MaheshRavishankar, I'm not sure if I was tagged here by mistake or intentionally, since I'm not aware of any such discussion.
Hi @vivekkhandelwal1, sorry, it's a mistake; I should have been the one to tag.
Notes from the meeting on getting to the "streaming K" workflow:
1. Tile using scf.forall for first-level tiling (for now a 2x2 logical view of the AIE cores)
2. Tile using scf.forall for second-level tiling (from 2x2 -> 1x1 core)
3. Promote the output operand of the fill + matmul to L1 memory
4. Tile the K dimension of the matmul with tile size 256
5. Promote the LHS and RHS operands of the matmul to L2 memory (using pad, not pack)
6. Tile the K dimension with tile size 32
7. Promote the LHS, RHS and result using pack to get an inner matmul of size 4x8x8 (MxKxN)
Let's checkpoint after these steps. We will probably need some pack hoisting to remove some unnecessary memory.
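As a rough sketch of steps 1 and 2 in the transform dialect (handle names are illustrative; the tile sizes assume a 64x64 block tile split into 32x32 per core, as in the IR below):
// Sketch only: first-level tiling to an scf.forall mapped onto blocks, then
// second-level tiling to an scf.forall mapped onto threads (one AIE core each).
%tiled_l1, %block_forall =
  transform.structured.tile_using_forall %matmul tile_sizes [64, 64]
    ( mapping = [#gpu.block<y>, #gpu.block<x>] )
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%tiled_l2, %core_forall =
  transform.structured.tile_using_forall %tiled_l1 tile_sizes [32, 32]
    ( mapping = [#gpu.thread<y>, #gpu.thread<x>] )
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)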
Thanks @MaheshRavishankar! I will try as discussed. I just want to archive what I have right now: two versions of the transformed IR.
I followed the proposed workflow and generated the IR before bufferization. To make this useful, we'll have to add a pass to hoist pack and unpack ops out of the scf.for loops.
@MaheshRavishankar @yzhang93 IR dump for (2k, 2k, 2k, i32) GEMM, showing K-streaming schedule: https://gist.gitenterprise.xilinx.com/erweiw/524d374e4d2741e70c8188fbc7732cc7
The link is behind a firewall. Please use gist.github.com
Moved the gist to GitHub: https://gist.github.com/erwei-xilinx/f1482b2b1c796e92be313a5091ab42e6
In any case, I finally got some time to take a look at the schedule. After some playing around with things, I can now generate this output
#map = affine_map<(d0) -> (d0 * 64)>
#map1 = affine_map<(d0) -> (d0 * 32)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
module {
func.func @matmul_i32(%arg0: tensor<1024x2048xi32>, %arg1: tensor<2048x512xi32>) -> tensor<1024x512xi32> {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%0 = bufferization.to_memref %arg1 : memref<2048x512xi32, strided<[?, ?], offset: ?>>
%1 = bufferization.to_memref %arg0 : memref<1024x2048xi32, strided<[?, ?], offset: ?>>
%alloca = memref.alloca() {alignment = 64 : i64} : memref<1024x512xi32>
scf.forall (%arg2, %arg3) in (16, 8) {
%3 = affine.apply #map(%arg2)
%4 = affine.apply #map(%arg3)
%subview = memref.subview %alloca[%3, %4] [64, 64] [1, 1] : memref<1024x512xi32> to memref<64x64xi32, strided<[512, 1], offset: ?>>
%alloc = memref.alloc() : memref<64x64xi32, 1>
scf.forall (%arg4, %arg5) in (2, 2) {
%5 = affine.apply #map1(%arg4)
%6 = affine.apply #map1(%arg5)
%subview_0 = memref.subview %1[%3, 0] [64, 2048] [1, 1] : memref<1024x2048xi32, strided<[?, ?], offset: ?>> to memref<64x2048xi32, strided<[?, ?], offset: ?>>
%subview_1 = memref.subview %subview_0[%5, 0] [32, 2048] [1, 1] : memref<64x2048xi32, strided<[?, ?], offset: ?>> to memref<32x2048xi32, strided<[?, ?], offset: ?>>
%subview_2 = memref.subview %0[0, %4] [2048, 64] [1, 1] : memref<2048x512xi32, strided<[?, ?], offset: ?>> to memref<2048x64xi32, strided<[?, ?], offset: ?>>
%subview_3 = memref.subview %subview_2[0, %6] [2048, 32] [1, 1] : memref<2048x64xi32, strided<[?, ?], offset: ?>> to memref<2048x32xi32, strided<[?, ?], offset: ?>>
%subview_4 = memref.subview %alloc[%5, %6] [32, 32] [1, 1] : memref<64x64xi32, 1> to memref<32x32xi32, strided<[64, 1], offset: ?>, 1>
linalg.fill ins(%c0_i32 : i32) outs(%subview_4 : memref<32x32xi32, strided<[64, 1], offset: ?>, 1>)
scf.for %arg6 = %c0 to %c2048 step %c256 {
%subview_5 = memref.subview %subview_1[0, %arg6] [32, 256] [1, 1] : memref<32x2048xi32, strided<[?, ?], offset: ?>> to memref<32x256xi32, strided<[?, ?], offset: ?>>
%subview_6 = memref.subview %subview_3[%arg6, 0] [256, 32] [1, 1] : memref<2048x32xi32, strided<[?, ?], offset: ?>> to memref<256x32xi32, strided<[?, ?], offset: ?>>
scf.for %arg7 = %c0 to %c256 step %c32 {
%alloc_7 = memref.alloc() : memref<32x256xi32, 1>
linalg.copy ins(%subview_5 : memref<32x256xi32, strided<[?, ?], offset: ?>>) outs(%alloc_7 : memref<32x256xi32, 1>)
%subview_8 = memref.subview %alloc_7[0, %arg7] [32, 32] [1, 1] : memref<32x256xi32, 1> to memref<32x32xi32, strided<[256, 1], offset: ?>, 1>
%alloc_9 = memref.alloc() : memref<256x32xi32, 1>
linalg.copy ins(%subview_6 : memref<256x32xi32, strided<[?, ?], offset: ?>>) outs(%alloc_9 : memref<256x32xi32, 1>)
%subview_10 = memref.subview %alloc_9[%arg7, 0] [32, 32] [1, 1] : memref<256x32xi32, 1> to memref<32x32xi32, strided<[32, 1], offset: ?>, 1>
%alloc_11 = memref.alloc() : memref<4x8x4x8xi32, 2>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %alloc_11 : (memref<32x32xi32, strided<[256, 1], offset: ?>, 1> memref<4x8x4x8xi32, 2>)
%alloc_12 = memref.alloc() : memref<4x4x8x8xi32, 2>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %alloc_12 : (memref<32x32xi32, strided<[32, 1], offset: ?>, 1> memref<4x4x8x8xi32, 2>)
%alloc_13 = memref.alloc() : memref<4x8x4x8xi32, 2>
iree_linalg_ext.pack %subview_4 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %alloc_13 : (memref<32x32xi32, strided<[64, 1], offset: ?>, 1> memref<4x8x4x8xi32, 2>)
linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%alloc_11, %alloc_12 : memref<4x8x4x8xi32, 2>, memref<4x4x8x8xi32, 2>) outs(%alloc_13 : memref<4x8x4x8xi32, 2>) {
^bb0(%in: i32, %in_14: i32, %out: i32):
%7 = arith.muli %in, %in_14 : i32
%8 = arith.addi %out, %7 : i32
linalg.yield %8 : i32
}
iree_linalg_ext.unpack %alloc_13 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %subview_4 : (memref<4x8x4x8xi32, 2> memref<32x32xi32, strided<[64, 1], offset: ?>, 1>)
memref.dealloc %alloc_13 : memref<4x8x4x8xi32, 2>
memref.dealloc %alloc_11 : memref<4x8x4x8xi32, 2>
memref.dealloc %alloc_12 : memref<4x4x8x8xi32, 2>
memref.dealloc %alloc_7 : memref<32x256xi32, 1>
memref.dealloc %alloc_9 : memref<256x32xi32, 1>
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
linalg.copy ins(%alloc : memref<64x64xi32, 1>) outs(%subview : memref<64x64xi32, strided<[512, 1], offset: ?>>)
memref.dealloc %alloc : memref<64x64xi32, 1>
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
%2 = bufferization.to_tensor %alloca : memref<1024x512xi32>
return %2 : tensor<1024x512xi32>
}
}
I think this is the schedule we want for the most part. The only missing piece is that the linalg.fill is happening at L2 memory. I think that can be fixed by a custom fusion pass (while I add support for that upstream).
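For reference, the fused form we are after would initialize the packed L1 accumulator directly. Roughly (a hand-written sketch reusing buffer names from the dump above; the accumulator allocation would need to be hoisted out of the loops as well):
// Desired placement (sketch): fill the packed L1 accumulator once, before the
// K reduction loops, instead of filling the 32x32 L2 subview.
linalg.fill ins(%c0_i32 : i32) outs(%alloc_13 : memref<4x8x4x8xi32, 2>)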
This is the transform dialect script I used to get this output
func.func @matmul_i32(%arg0 : tensor<1024x2048xi32>, %arg1 : tensor<2048x512xi32>) -> tensor<1024x512xi32> {
%c0 = arith.constant 0: i32
%empty = tensor.empty() : tensor<1024x512xi32>
%0 = linalg.fill ins(%c0 : i32) outs(%empty : tensor<1024x512xi32>) -> tensor<1024x512xi32>
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<1024x2048xi32>, tensor<2048x512xi32>)
outs(%0 : tensor<1024x512xi32>) -> tensor<1024x512xi32>
return %1 : tensor<1024x512xi32>
}
module attributes { transform.with_named_sequence } {
transform.named_sequence @cleanup(%variant_op: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.iree.apply_licm %func : !transform.any_op
transform.apply_cse to %func : !transform.any_op
transform.yield
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%ops = transform.structured.match ops{["linalg.fill", "linalg.matmul"]} in %variant_op : (!transform.any_op) -> !transform.any_op
%fill, %matmul = transform.split_handle %ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// First level tile to forall.
%first_level_tiled_matmul, %outer_forall =
transform.structured.tile_using_forall %matmul tile_sizes [64, 64]
( mapping = [#gpu.block<y>, #gpu.block<x>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Fuse fill operation into the forall loop.
%0, %1 = transform.structured.fuse_into_containing_op %fill into %outer_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Pad operation.
%first_level_tiled_padded_matmul, %pad, %__ = transform.structured.pad %first_level_tiled_matmul {
padding_values=[0 : i32, 0 : i32, 0 : i32],
padding_dimensions=[0, 1, 2],
pack_paddings=[0, 0, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad : (!transform.any_op) -> !transform.any_op
// Promote the result to shared memory.
%2 = transform.get_producer_of_operand %first_level_tiled_padded_matmul[2] : (!transform.any_op) -> (!transform.any_op)
%3, %4 = transform.structured.bufferize_to_allocation %2
{memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Second level tile to forall with tile_sizes.
%second_level_tiled_matmul, %inner_forall =
transform.structured.tile_using_forall %first_level_tiled_padded_matmul tile_sizes [32, 32]
( mapping = [#gpu.thread<y>, #gpu.thread<x>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Clean up.
transform.include @cleanup failures(propagate) (%variant_op) : (!transform.any_op) -> ()
// Find the copies for the operands.
%lhs_slice = transform.get_producer_of_operand %second_level_tiled_matmul[0] : (!transform.any_op) -> (!transform.any_op)
%lhs_copy = transform.get_producer_of_operand %lhs_slice[0] : (!transform.any_op) -> (!transform.any_op)
%rhs_slice = transform.get_producer_of_operand %second_level_tiled_matmul[1] : (!transform.any_op) -> (!transform.any_op)
%rhs_copy = transform.get_producer_of_operand %rhs_slice[0] : (!transform.any_op) -> (!transform.any_op)
// Fuse the copies into the inner loop.
%5, %6 = transform.structured.fuse_into_containing_op %lhs_copy into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%7, %8 = transform.structured.fuse_into_containing_op %rhs_copy into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Fuse fill operation into the forall loop.
%first_level_fused_fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!transform.any_op) -> !transform.any_op
%second_level_fused_fill, %9 = transform.structured.fuse_into_containing_op %first_level_fused_fill into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// First level for loop.
%first_level_tiled_reduction_matmul, %outer_for_loop =
transform.structured.tile_using_for %second_level_tiled_matmul [0, 0, 256]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Second level for loop.
%second_level_tiled_reduction_matmul, %inner_for_loop =
transform.structured.tile_using_for %first_level_tiled_reduction_matmul [0, 0, 32]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Pack by applying data tiling, and the linalg.matmul becomes linalg.generic.
%packed = transform.structured.pack %second_level_tiled_reduction_matmul packed_sizes = [4, 8, 8]
: (!transform.any_op) -> (!transform.any_op)
// Transpose A matrix from [M K m k m0 k0] to [M K k m m0 k0]
%packed_lhs = transform.get_producer_of_operand %packed[0]
: (!transform.any_op) -> (!transform.any_op)
%lhs_packed_matmul, %lhs_pack_op, %lhs_unpack_op =
transform.structured.pack_transpose %packed_lhs with_compute_op(%packed)
outer_perm = [1, 0] : (!transform.any_op, !transform.any_op)
-> (!transform.any_op, !transform.any_op, !transform.any_op)
// Transpose B matrix from [K N k n n0 k0] to [K N n k k0 n0]
%packed_rhs = transform.get_producer_of_operand %lhs_packed_matmul[1]
: (!transform.any_op) -> (!transform.any_op)
%operands_packed_matmul, %rhs_pack_op, %rhs_unpack_op =
transform.structured.pack_transpose %packed_rhs with_compute_op(%lhs_packed_matmul)
outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op)
-> (!transform.any_op, !transform.any_op, !transform.any_op)
// Transpose C matrix from [M N m n m0 n0] to [M N n m m0 n0]
%packed_output = transform.get_consumers_of_result %operands_packed_matmul[0]
: (!transform.any_op) -> (!transform.any_op)
%packed_matmul, %output_pack_op, %output_unpack_op =
transform.structured.pack_transpose %packed_output with_compute_op(%operands_packed_matmul)
outer_perm = [1, 0] : (!transform.any_op, !transform.any_op)
-> (!transform.any_op, !transform.any_op, !transform.any_op)
// Bufferize result to local memory allocation
%buffer_c, %new_c = transform.structured.bufferize_to_allocation %output_pack_op
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Promote the inputs to local memory.
%buffer_a, %new_a = transform.structured.bufferize_to_allocation %lhs_pack_op
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
%buffer_b, %new_b = transform.structured.bufferize_to_allocation %rhs_pack_op
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Promote the inputs to the pack to shared memory.
%lhs_pack_op_source = transform.get_producer_of_operand %lhs_pack_op[0] : (!transform.any_op) -> (!transform.any_op)
%lhs_pack_op_source_buffer, %lhs_pack_op_source_new = transform.structured.bufferize_to_allocation %lhs_pack_op_source
{memory_space = 1, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc } : !transform.any_op
%rhs_pack_op_source = transform.get_producer_of_operand %rhs_pack_op[0] : (!transform.any_op) -> (!transform.any_op)
%rhs_pack_op_source_buffer, %rhs_pack_op_source_new = transform.structured.bufferize_to_allocation %rhs_pack_op_source
{memory_space = 1, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op
// Clean up.
transform.include @cleanup failures(propagate) (%variant_op) : (!transform.any_op) -> ()
transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
// Bufferize and drop the HAL descriptor from memref ops.
%variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
You should be able to get the output with
iree-opt --iree-transform-dialect-interpreter --canonicalize --cse
cc @jtuyls as well. WDYT?
@MaheshRavishankar This IR has the L3->L2 linalg.copy ops inside both scf.for loops, whereas we would actually want them inside the outer for loop but outside the inner for loop. Otherwise the copies happen more frequently than in the schedule we actually want.
Inside the inner scf.for loop, there are three L2->L1 packs and one L1->L2 unpack, whereas we would want to hoist the third pack (fusing it with the linalg.fill) and the unpack outside of both scf.for loops, so that we get the schedule of "only copy the results out once the entire K reduction is done".
I think once the above two points, plus the linalg.fill fusion into L1, are addressed, we should get the schedule we are aiming for.
Thanks @MaheshRavishankar! @erwei-xilinx and I went through several rounds of trials with different IRs, and found that there are certain structural constraints we need to satisfy to make the large GEMM run through AIR. To summarize what Erwei just pointed out, we need to make sure:
1. Initialization should stay in L1 with the correct packed shape, e.g. linalg.fill ins(%c0_i32 : i32) outs(%alloc_3 : memref<4x8x4x8xi32, 2>). We should avoid moving outputs back in from L3->L2->L1.
2. Input data movement from L3->L2 and L2->L1 should happen under separate scf.for loops, with the K dimension tiled from 2048->256 and from 256->32. Only the inputs are copied into L2 and L1, so there should be two copies/packs within each scf.for loop (see the sketch below).
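Schematically, the structure we are after looks like the following (a hand-written sketch, not generated IR; buffer names and the elided "..." operands/types are placeholders):
// Target schedule (sketch): inputs stream through L2/L1 inside the K loops,
// while the packed L1 accumulator is initialized once and written out once.
linalg.fill ins(%c0_i32 : i32) outs(%acc_l1 : memref<4x8x4x8xi32, 2>)
scf.for %k0 = %c0 to %c2048 step %c256 {          // L3 -> L2: inputs only
  linalg.copy ins(%lhs_l3_slice : ...) outs(%lhs_l2 : memref<32x256xi32, 1>)
  linalg.copy ins(%rhs_l3_slice : ...) outs(%rhs_l2 : memref<256x32xi32, 1>)
  scf.for %k1 = %c0 to %c256 step %c32 {          // L2 -> L1: inputs only
    iree_linalg_ext.pack %lhs_l2_slice ... into %lhs_l1 ...
    iree_linalg_ext.pack %rhs_l2_slice ... into %rhs_l1 ...
    linalg.generic ... ins(%lhs_l1, %rhs_l1 : ...) outs(%acc_l1 : ...) { ... }
  }
}
// Results leave L1 only once, after the full K reduction.
iree_linalg_ext.unpack %acc_l1 ... into %c_l2_subview ...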
Here's the current IR generated by the pad-pack pipeline: https://gist.github.com/yzhang93/b4cf2bfa8f4a0ab0b09c459b368fcffe.
And here's the pack/pad IR written by hand, which shows the ideal structure: https://gist.github.com/yzhang93/025365c515350e04ea97a2c49bc32d79.
Thanks @erwei-xilinx and @yzhang93. There are a few artifacts here that I think we need to look at, and that are going to make this harder (and I think these are being handled within AIR using some non-trivial dependence tracking, which is what I am really trying to avoid).
https://gist.github.com/yzhang93/b4cf2bfa8f4a0ab0b09c459b368fcffe#file-pad-pack-pipeline-ir-L38
There is a loop of copies here but it is really hard to know that every copy within this loop is servicing data that can be used to kick off the compute. This all seems to be very implicit and I am trying to avoid that.
Your points are right, but I am trying to avoid a path that requires non-trivial program-state tracking. Things we could even do one-off if needed:
1. Move the initialization into L1. That's what I was referring to as the one thing I haven't done; I think a pass that fuses the fill into the innermost for loop behind a conditional would do the trick (sketched below).
2. Hoist the pack in the innermost loop.
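Something along these lines, post-fusion (a hand sketch; value names follow the IR dump above, and the packed accumulator allocation is assumed to have been hoisted as well):
// Sketch: the fill lives inside the innermost K loop, guarded so it only runs
// on the first iteration of both reduction loops.
%outer_first = arith.cmpi eq, %arg6, %c0 : index
%inner_first = arith.cmpi eq, %arg7, %c0 : index
%is_first = arith.andi %outer_first, %inner_first : i1
scf.if %is_first {
  linalg.fill ins(%c0_i32 : i32) outs(%alloc_13 : memref<4x8x4x8xi32, 2>)
}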
I do need to get the right shape of the data movement from L2 to L1. Let me look into that a bit more.
There are a few artifacts here that I think we need to look at, and that are going to make this harder (and I think these are being handled within AIR using some non-trivial dependence tracking, which is what I am really trying to avoid). https://gist.github.com/yzhang93/b4cf2bfa8f4a0ab0b09c459b368fcffe#file-pad-pack-pipeline-ir-L38 There is a loop of copies here, but it is really hard to know that every copy within this loop is servicing data that can be used to kick off the compute. This all seems to be very implicit and I am trying to avoid that.
I agree, and I think it results in what's happening on the AIE being different from what the IR initially described. @erwei-xilinx, is this decision made inside the AIRDependency pass? I see a transformation there adding an async token + yield:
scf.for %arg12 = %c0_1 to %c2048 step %c256 {
air.dma_memcpy_nd (%alloc[%c0_1, %arg12] [%c8, %c256] [%c2048, %c1_0], %arg10[%3, %arg12] [%c8, %c256] [%c2048, %c1_0]) {id = 1 : i32} : (memref<8x2048xi32, 1 : i32>, memref<8x2048xi32>)
}
into
%3 = scf.for %arg12 = %c0_8 to %c2048 step %c256 iter_args(%arg13 = %2) -> (!air.async.token) {
%c0_22 = arith.constant 0 : index
%c8_23 = arith.constant 8 : index
%c256_24 = arith.constant 256 : index
%c2048_25 = arith.constant 2048 : index
%c1_26 = arith.constant 1 : index
%8 = air.dma_memcpy_nd async [%arg13, %arg13] (%results_14[%c0_22, %arg12] [%c8_23, %c256_24] [%c2048_25, %c1_26], %arg10[%results_10, %arg12] [%c8_23, %c256_24] [%c2048_25, %c1_26]) {id = 1 : i32} : (memref<8x2048xi32, 1 : i32>, memref<8x2048xi32>)
%9 = air.wait_all async [%arg13, %8] {id = 1 : i32}
scf.yield %9 : !air.async.token
}
From @erwei-xilinx :
Inside the inner scf.for loop, there are three L2->L1 packs and one L1->L2 unpack, whereas we would want to hoist the third pack (fusing it with the linalg.fill) and the unpack outside of both scf.for loops, so that we get the schedule of "only copy the results out once the entire K reduction is done".
From @yzhang93 :
Input data movement from L3->L2 and L2->L1 should happen under separate scf.for loops, with the K dimension tiled from 2048->256 and from 256->32. Only the inputs are copied into L2 and L1, so there should be two copies/packs within each scf.for loop.
@yzhang93 Isn't this already done, or where do we still have three copies/packs? I see two packs on the input side in your generated IR: https://gist.github.com/yzhang93/b4cf2bfa8f4a0ab0b09c459b368fcffe#file-pad-pack-pipeline-ir-L63.
I also see the fill already on L1, correct? https://gist.github.com/yzhang93/b4cf2bfa8f4a0ab0b09c459b368fcffe#file-pad-pack-pipeline-ir-L57
Thanks for looking into the IR dump, @MaheshRavishankar. I agree with your point that this post-processing needs to be avoided. I'd love to see if we can come up with an alternative to replace it.
@jtuyls The pass that infers the subview of the buffer being serviced is air-loop-fusion (https://github.com/Xilinx/mlir-air/blob/1cff164373045e5c16bcd5b25778693a4423af05/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp#L3545).
I am working through a sequence that I think works, but really, at a very high level, I don't see things connecting unless we do the following:
1. Tile, fuse and distribute the M and N dimensions spatially.
2. Tile the K dimension temporally.
3. Pad everything to a multiple of (64, 64, 256).
4. Promote the operands to shared memory.
5. Tile and distribute M and N of the matmul alone.
6. Tile the K dimension again.
7. Pack the operands and promote to local memory.
8. Use peeling to avoid filling in shared memory.
So really we need the peeling path to work through the stack.
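To illustrate what peeling the reduction loop would give (a hand-written sketch with placeholder names, not generated IR): the first K iteration carries the initialization, and the remainder loop only accumulates, so the output never needs to be initialized in shared memory.
// Peeled first iteration (k0 = 0): initialize the packed L1 accumulator and
// compute the first partial product.
linalg.fill ins(%c0_i32 : i32) outs(%acc_l1 : memref<4x8x4x8xi32, 2>)
...  // pack LHS/RHS for k0 = 0 and run the packed matmul
// Remainder loop: accumulate only, no fill.
scf.for %k0 = %c256 to %c2048 step %c256 {
  ...  // pack LHS/RHS for this k0 and accumulate into %acc_l1
}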
@MaheshRavishankar Peeling results in duplication of the L2->L1 linalg.copy / iree_linalg_ext.pack ops. In AIR, we currently materialize every instance of an air.dma op into discrete data-mover hardware (DMA channels and stream interconnects). It is not clear to me how, with peeling, to map the peeled linalg.copy / iree_linalg_ext.pack ops back so that they share the hardware.
@MaheshRavishankar I have an example using the GPU block/thread id for scf.forall loops. Is this what we want? https://gist.github.com/yzhang93/9702c546462d16c70d451ed0e5f02cdb
We don't want the GPU launch; that needs to be an scf.for. The inner scf.forall resolution looks fine. If you want, we can chat tomorrow and I can show you what I think we should do in the IR. Please set up some time for tomorrow.
Most of IREE's code generation relies on tensor-based approaches. To aid that, the scf.forall operation allows you to do parallel spatial decomposition. This is some part of what is modeled by air.launch and air.herd. The latter are more general constructs, but for the current code-generation pipeline these constructs might be too general. This issue is to track the use of tensor-based code generation (and therefore scf.forall) farther down the stack. Post-bufferization, air.launch and air.herd might be useful constructs to use anyway for managing the lowering to IPU instructions.
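For concreteness, a minimal illustration of the tensor-based scf.forall form (shapes, the i32 element type, and the fill body are arbitrary placeholders):
// A 2x2 spatial decomposition on tensors: each (i, j) instance produces one
// 32x32 tile of the shared 64x64 output in parallel.
%c0_i32 = arith.constant 0 : i32
%init = tensor.empty() : tensor<64x64xi32>
%result = scf.forall (%i, %j) in (2, 2) shared_outs(%out = %init) -> (tensor<64x64xi32>) {
  %off_i = affine.apply affine_map<(d0) -> (d0 * 32)>(%i)
  %off_j = affine.apply affine_map<(d0) -> (d0 * 32)>(%j)
  %tile = tensor.extract_slice %out[%off_i, %off_j] [32, 32] [1, 1]
      : tensor<64x64xi32> to tensor<32x32xi32>
  %filled = linalg.fill ins(%c0_i32 : i32) outs(%tile : tensor<32x32xi32>) -> tensor<32x32xi32>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %filled into %out[%off_i, %off_j] [32, 32] [1, 1]
        : tensor<32x32xi32> into tensor<64x64xi32>
  }
} {mapping = [#gpu.block<y>, #gpu.block<x>]}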