Open matth2k opened 11 months ago
I agree that generating iteration variables may be helpful for some compiler passes. However, it is not easy to determine from the frontend whether a variable is a reduction variable, so we currently do not support this feature. The allo.reduction function is just an annotation, and it does not generate a loop with iteration variables.
I haven't figured out a good way to resolve this issue. A more sophisticated frontend analysis pass could probably help generate this kind of reduction loop.
I think I found a solution to this problem: https://github.com/cornell-zhang/amc-dialect/pull/64
This looks cool! Could you provide an example of the original MLIR code and the code after this pass? @andrewb1999
Yeah so this is the code before the pass:
module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "sum"} : memref<1xi32>
    affine.store %c0_i32, %alloc[0] {to = "sum"} : memref<1xi32>
    affine.for %arg1 = 0 to 20 {
      %1 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      %2 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
      %3 = arith.addi %2, %1 : i32
      affine.store %3, %alloc[0] {to = "sum"} : memref<1xi32>
    } {loop_name = "i", op_name = "S_i_0", reduction}
    %0 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    return %0 : i32
  }
}
and this is the code after the pass:
module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "sum"} : memref<1xi32>
    affine.store %c0_i32, %alloc[0] {to = "sum"} : memref<1xi32>
    %0 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    %1 = affine.for %arg1 = 0 to 20 iter_args(%arg2 = %0) -> (i32) {
      %3 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      %4 = arith.addi %arg2, %3 : i32
      affine.yield %4 : i32
    }
    affine.store %1, %alloc[0] : memref<1xi32>
    %2 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    return %2 : i32
  }
}
You can see that the load and store on sum inside the loop have been removed and replaced with iter_args and an affine.yield. The sum memref should then be removable entirely using store-load forwarding.
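As a hand-written sketch (not the output of any actual pass run), this is roughly what the kernel might look like once store-load forwarding succeeds. It assumes that a pass such as --affine-scalrep forwards the initial store of %c0_i32 into the iter_args initializer, forwards the stored loop result into the return, and then drops the now-dead stores and the sum allocation:

```mlir
module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    // Assumed result: the zero constant seeds the loop-carried value directly.
    %c0_i32 = arith.constant 0 : i32
    %0 = affine.for %arg1 = 0 to 20 iter_args(%arg2 = %c0_i32) -> (i32) {
      %1 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      %2 = arith.addi %arg2, %1 : i32
      affine.yield %2 : i32
    }
    // The sum memref is gone; the loop result is returned as an SSA value.
    return %0 : i32
  }
}
```

Whether the forwarding actually fires here depends on the pass and MLIR version in use, so the output should be checked rather than assumed.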
Describe the bug
Neither writing kernels with primitives like matmul() nor using allo.grid() makes use of affine.for's ability to carry iteration arguments. For us, this is important for pipelining. Here is an example of the MLIR produced by test_reduce() (shown further down).
To Reproduce
The linalg dialect compounds the issue, because it lowers linalg to affine loops without an accumulator:
But even with an explicit accumulator in a single memref cell, I can't get it to be raised to SSA values:
Buggy output
I was not hopeful that the existing MLIR passes would help with this issue, but I tried anyway by running
mlir-opt --convert-linalg-to-affine-loops --affine-scalrep --lower-affine --convert-scf-to-cf --mem2reg
The mem2reg pass is only expected to work on unstructured control flow, but I could not get it to work even for that.
Expected behavior
Here is an example of how we do matmul in affine that uses iteration arguments to assist the pipelining pass:
Perhaps there are the right patterns/passes in MLIR to accomplish what we want, but I haven't found them yet. Maybe we will have to write our own pass for this or lower the AST differently.