Open MaheshRavishankar opened 2 years ago
Big +1 on generalizing the op! Thanks for the exploration. Now I understand better what the difference is between Linalg ops and `linalg_ext.top_k` ops.

The main difference is that `linalg_ext.top_k` ops have an imperfectly nested loop structure; we're not able to represent such ops in Linalg because Linalg aims to have perfectly nested loops. E.g., the current lowering to loops is:
```mlir
func.func @_topk_1d_dim0_max_dispatch_0() {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %c3 = arith.constant 3 : index
  %c64 = arith.constant 64 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<10xf32>
  memref.assume_alignment %0, 64 : memref<10xf32>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<10xi32>
  memref.assume_alignment %1, 64 : memref<10xi32>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<3xf32>
  memref.assume_alignment %2, 64 : memref<3xf32>
  %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c64) alignment(64) : memref<3xi32>
  memref.assume_alignment %3, 64 : memref<3xi32>
  scf.for %arg0 = %c0 to %c10 step %c1 {
    %4 = memref.load %0[%arg0] : memref<10xf32>
    %5 = memref.load %1[%arg0] : memref<10xi32> // This will be `%arg0` if input_indices is not specified.
    %6:2 = scf.for %arg1 = %c0 to %c3 step %c1 iter_args(%arg2 = %4, %arg3 = %5) -> (f32, i32) {
      %7 = memref.load %2[%arg1] : memref<3xf32>
      %8 = memref.load %3[%arg1] : memref<3xi32>
      %9 = arith.cmpf ogt, %arg2, %7 : f32
      %10 = arith.cmpf ogt, %7, %arg2 : f32
      %11 = arith.cmpi eq, %9, %10 : i1
      %12 = arith.cmpi slt, %arg3, %8 : i32
      %13 = arith.andi %11, %12 : i1
      %14 = arith.ori %9, %13 : i1
      %15 = arith.select %9, %arg2, %7 : f32
      %16 = arith.select %14, %arg3, %8 : i32
      memref.store %15, %2[%arg1] : memref<3xf32>
      memref.store %16, %3[%arg1] : memref<3xi32>
      %17 = arith.select %9, %7, %arg2 : f32
      %18 = arith.select %14, %8, %arg3 : i32
      scf.yield %17, %18 : f32, i32
    }
  }
  return
}
```
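To pin down what the `cmpf`/`cmpi`/`select` chain above computes: the lowering is an insertion pass where each input element bubbles down the sorted output list, and value ties are broken toward the smaller index. A hedged Python model of that loop nest (the `-inf` initialization of the output buffers is my assumption; it is not shown in the IR):

```python
import math

def topk_insert(values, indices, k):
    """Python model of the scf.for lowering above: each (value, index)
    pair is inserted into a sorted top-k list; value ties are broken
    in favor of the smaller index."""
    out_vals = [-math.inf] * k  # assumed initialization of the outs buffers
    out_idx = [0] * k
    for cand_v, cand_i in zip(values, indices):
        for j in range(k):
            greater = cand_v > out_vals[j]                      # %9
            equal = not greater and not (out_vals[j] > cand_v)  # %11
            take = greater or (equal and cand_i < out_idx[j])   # %14
            new_v = cand_v if greater else out_vals[j]          # %15
            new_i = cand_i if take else out_idx[j]              # %16
            # The demoted element becomes the candidate for slot j+1.
            cand_v = out_vals[j] if greater else cand_v         # %17
            cand_i = out_idx[j] if take else cand_i             # %18
            out_vals[j], out_idx[j] = new_v, new_i
    return out_vals, out_idx
```

For example, `topk_insert([5.0, 2.0, 8.0, 8.0], [0, 1, 2, 3], 3)` returns `([8.0, 8.0, 5.0], [2, 3, 0])`: the duplicate `8.0` at the smaller original index comes first.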
The needs are: `input_values` and `input_indices`.

I borrowed the idea from Mahesh and extended it a bit more. To generalize the op, we can ask `linalg_ext.yield` to yield `#output + #input` values. The leading `#output` values will be stored into `outs`, and the rest of the values will be yielded by `scf.yield`:
```mlir
%result:n = linalg_ext.maxk dimension(...)
    ins(%input0, %input1, %input2, ...., %inputm : ....)
    outs(%output0, %output1, ... %outputn : ...) {
  ^bb0(%arg0, %arg1, %arg2, ...., %argm, %argmp1, %argmp2, ... %argmpn):
    ...
    linalg_ext.yield %t0, %t1, ... %tn, %y0, %y1, ..., %ym : ....
}
```
The approach does not satisfy the need to make indices optional, but it generalizes the op to other cases. We'll be able to use it in ArgMin/Max lowering as well. (In this context, we no longer need stack allocation for ArgMin/Max.)

I don't have a good solution for indices at this moment. Maybe having a second block could address the issue, but it makes the op a bit more complicated. E.g.,
```mlir
%result:2 = linalg_ext.maxk dimension(...)
    ins(%input0 : ....)
    outs(%output0, %output1 : ...) {
  ^bb0(%in):
    %idx = linalg_ext.index ...
    linalg_ext.yield %in, %idx ... // these will be passed to scf.for iters.
  ^bb1(%arg0, %arg1, %argmp0, %argmp1):
    ...
    linalg_ext.yield %t0, %t1, %y0, %y1 : .... // The first two will be stored into outs, and the rest will be used in the next iteration.
}
```
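To see why the generalized form subsumes ArgMin/Max without stack allocation, here is a hedged Python sketch (the function name and the `-inf`/`-1` initialization are my own, not part of the proposal): ArgMax is just the `k == 1` case of the same carried (value, index) reduction.

```python
import math

def maxk_generalized(values, k):
    """Sketch of the generalized reduction: the body conceptually yields
    the stored values (into outs) plus the carried values (into the next
    iteration). ArgMax is the k == 1 case, so no stack allocation is
    needed. Value ties prefer the smaller index."""
    out_v = [-math.inf] * k  # assumed initialization
    out_i = [-1] * k
    for i, v in enumerate(values):
        cand_v, cand_i = v, i
        for j in range(k):
            if cand_v > out_v[j] or (cand_v == out_v[j] and cand_i < out_i[j]):
                # Swap: the displaced element is carried to the next slot.
                out_v[j], cand_v = cand_v, out_v[j]
                out_i[j], cand_i = cand_i, out_i[j]
    return out_v, out_i
```

For example, `maxk_generalized([2.0, 7.0, 7.0, 1.0], 1)` returns `([7.0], [1])`, i.e. an ArgMax that keeps the first occurrence of the maximum.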
Thinking a little bit more about this, there is an advantage to making `linalg_ext.maxk` an interface and `linalg_ext.topk` an op that implements this interface (`maxk` doesn't have to be an interface to start with, but could just be the base op in Tablegen for something with low overhead). The reason for this is that `linalg_ext.topk`, when tiled + distributed, will need a different operation to represent the merge. That operation can be a `linalg_ext.maxk`. Knowing the op is a `linalg_ext.topk` (as opposed to recognizing that the operation is a `linalg_ext.topk` by looking at the op region) is much easier to manage.

To some extent this analogy holds: `linalg_ext.topk` is to `linalg_ext.maxk` as `linalg.matmul` is to `linalg.generic`. In other words, if `linalg_ext.maxk` were an interface similar to `LinalgOp`, `linalg_ext.topk` would be a named op.

So let's stay the course with `linalg_ext.topk` for now to learn more about how this all shakes out.
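To make the tile-and-distribute argument concrete: if each shard computes a partial top-k while carrying the original indices (the `input_indices` incarnation), then the merge step is itself just another top-k-style reduction over the concatenated partial results. A hedged Python sketch (the shard count and helper names are hypothetical, not from the proposal):

```python
def topk(values, indices, k):
    """Reference top-k over (value, original-index) pairs: larger values
    first, value ties broken toward the smaller original index."""
    order = sorted(zip(values, indices), key=lambda p: (-p[0], p[1]))
    top = order[:k]
    return [v for v, _ in top], [i for _, i in top]

def distributed_topk(values, k, num_shards=2):
    """Each shard computes a partial top-k that carries the original
    indices; merging the shards is just another top-k (the maxk-like
    merge op discussed above)."""
    n = len(values)
    part_v, part_i = [], []
    for s in range(num_shards):
        lo, hi = s * n // num_shards, (s + 1) * n // num_shards
        v, i = topk(values[lo:hi], list(range(lo, hi)), k)
        part_v += v
        part_i += i
    return topk(part_v, part_i, k)  # the merge step
```

Both paths agree: `distributed_topk([4.0, 9.0, 1.0, 9.0, 5.0, 2.0], 3)` and the single-pass `topk` over the whole input return `([9.0, 9.0, 5.0], [1, 3, 4])`.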
Based on review of PR #9162, it seems to me that the same `linalg_ext.topk` is being used in two separate contexts. To me that is an indication that `linalg_ext.topk` is very specific and needs to be generalized.

### Background

In the current definition of `linalg_ext.topk`, `%result#0` contains the top K values from `%input_values` (the definition of "top" is provided by the region of the op), and `%result#1` provides the positions of these top K values in `%input_values` along the `dimension` specified.

In another incarnation, the operation takes an extra `%input_indices` operand of the same size as `%input_values`. `%result#0` is the same as above, but here `%result#1` contains values from `%input_indices` (instead of positions in `%input_values`).

The second incarnation is needed for when the `dimension` (which is a reduction dimension) is distributed across threads, to track the original index during the merge step. The specification of this operation, though, is a bit strange (IMO). It seems to suggest that the op definition is being massaged to fit the eventual need, but the computation itself does not seem to represent what one would expect from a `topk`.

### Suggestion

It seems to me some generalization is needed. I think it is better to have a `linalg_ext.maxk` operation; both of the above ops would be instances of that op. For the case of `topk` proper (the first instance above) you could use `linalg_ext.index` to get the position along the reduction dimension (or some other op, like `linalg_ext.reduction_dim`, that is a placeholder for the position along the reduction dim, but `linalg_ext.index` seems good enough for this case). Both instances above can be represented using this op; they just have different bodies.

FYI @KoolJBlack @ThomasRaoux (also @hanhanW). Thanks Kojo and Thomas for the initial exploration on the `topk` op, it was extremely helpful.