iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

`linalg_ext.topk` needs to be more general than the current definition. #9256

MaheshRavishankar commented 2 years ago

Based on review of PR #9162, it seems to me that the same linalg_ext.topk op is being used in two separate contexts. To me that is an indication that linalg_ext.topk is too specific and needs to be generalized.

Background

Current definition of linalg_ext.topk is

%result:2 = iree_linalg_ext.topk dimension(...) ins(%input_values : ....) outs(%output_values, %output_indices : ...) {
  ....
} -> (tensor<...>, tensor<...>)

In this incarnation, %result#0 contains the top K values from %input_values (with the notion of "top" defined by the region of the op). %result#1 provides the positions of these top K values within %input_values along the specified dimension.
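Concretely, an instance of this first form might look like the following (a minimal sketch; the shapes are illustrative, and the region acts as the comparator that defines "top"):

%0:2 = iree_linalg_ext.topk dimension(1)
    ins(%input_values : tensor<2x10xf32>)
    outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi32>) {
^bb0(%arg0: f32, %arg1: f32):
  // "Greater than" comparator: selects the largest K values.
  %1 = arith.cmpf ogt, %arg0, %arg1 : f32
  iree_linalg_ext.yield %1 : i1
} -> tensor<2x3xf32>, tensor<2x3xi32>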

In another incarnation, the operation is

%result:2 = iree_linalg_ext.topk dimension(...) ins(%input_values, %input_indices : ....) outs(%output_values, %output_indices : ...) {
  ....
} -> (tensor<...>, tensor<...>)

%result#0 is the same as above. %input_indices has the same size as %input_values. Here %result#1 contains values taken from %input_indices (instead of positions in %input_values).

The second incarnation is needed when the dimension being reduced is distributed across threads and the original index has to be tracked during the merge step; a sketch of that flow follows. The specification of this operation, though, is a bit strange (IMO). It seems to suggest that the op definition is being massaged to fit the eventual need, while the computation itself does not represent what one would expect from a topk.
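Roughly, the distributed flow might look like this (a sketch; the shapes, number of slices, and the index-offsetting step are illustrative):

// Per thread: topk in the first form over one slice of the reduction dim.
// %partial#1 holds positions within the slice; adding the slice's start
// offset turns them into positions in the full input.
%partial:2 = iree_linalg_ext.topk dimension(0)
    ins(%slice : tensor<256xf32>)
    outs(%pv, %pi : tensor<3xf32>, tensor<3xi32>) {
  ...
} -> tensor<3xf32>, tensor<3xi32>

// Merge step: topk in the second form over the concatenated partial results
// (here 8 slices x 3 candidates each). Because %cand_idx carries the adjusted
// indices, %merged#1 reports positions in the original input directly.
%merged:2 = iree_linalg_ext.topk dimension(0)
    ins(%cand_val, %cand_idx : tensor<24xf32>, tensor<24xi32>)
    outs(%fv, %fi : tensor<3xf32>, tensor<3xi32>) {
  ...
} -> tensor<3xf32>, tensor<3xi32>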

Suggestion

It seems to me some generalization is needed. I think it is better to have a linalg_ext.maxk operation; both of the above ops would then be instances of that op. The op would be something like

%result:n = linalg_ext.maxk dimension(...)
    ins(%input0, %input1, %input2, ...., %inputm : ....)
    outs(%output0, %output1,... %outputn : ...) {
^bb0(%arg0, %arg1, %arg2...., %argm, %argmp1, %argmp2, ... %argmpn) :
  ...
  linalg_ext.yield %t0, %t1, ... %tn : ....
}

For the case of topk proper (the first instance above) you could use linalg_ext.index to get the position along the reduction dimension (or some other op, like a linalg_ext.reduction_dim placeholder for the position along the reduction dim, but linalg_ext.index seems good enough for this case). Both instances above can be represented using this op; they just have different bodies. A hypothetical instantiation is sketched below.
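For instance (hypothetical syntax; neither linalg_ext.maxk nor this use of linalg_ext.index exists today, and the body only keeps a running best per output slot, eliding the compare-and-shift a full top-K needs):

%result:2 = linalg_ext.maxk dimension(0)
    ins(%input_values : tensor<10xf32>)
    outs(%output_values, %output_indices : tensor<3xf32>, tensor<3xi32>) {
^bb0(%in: f32, %out_val: f32, %out_idx: i32):
  // Position of %in along the reduction dimension.
  %pos = linalg_ext.index 0 : index
  %idx = arith.index_cast %pos : index to i32
  %cmp = arith.cmpf ogt, %in, %out_val : f32
  %val = arith.select %cmp, %in, %out_val : f32
  %sel = arith.select %cmp, %idx, %out_idx : i32
  linalg_ext.yield %val, %sel : f32, i32
}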

FYI @KoolJBlack @ThomasRaoux (also @hanhanW). Thanks Kojo and Thomas for the initial exploration of the topk op; it was extremely helpful.

hanhanW commented 2 years ago

Big +1 on generalizing the op! Thanks for the exploration. Now I understand the difference between Linalg ops and the linalg_ext.top_k op better.

The main differences are:

  1. There are memref loads before scf ops.
  2. The scf.for ops yield values in linalg_ext.top_k ops.

We're not able to represent such ops in Linalg, because Linalg aims to model perfectly nested loops.

func.func @_topk_1d_dim0_max_dispatch_0() {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %c3 = arith.constant 3 : index
  %c64 = arith.constant 64 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<10xf32>
  memref.assume_alignment %0, 64 : memref<10xf32>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<10xi32>
  memref.assume_alignment %1, 64 : memref<10xi32>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<3xf32>
  memref.assume_alignment %2, 64 : memref<3xf32>
  %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c64) alignment(64) : memref<3xi32>
  memref.assume_alignment %3, 64 : memref<3xi32>
  scf.for %arg0 = %c0 to %c10 step %c1 {
    %4 = memref.load %0[%arg0] : memref<10xf32>
    %5 = memref.load %1[%arg0] : memref<10xi32> // This will be `%arg0` if input_indices is not specified.
    %6:2 = scf.for %arg1 = %c0 to %c3 step %c1 iter_args(%arg2 = %4, %arg3 = %5) -> (f32, i32) {
      %7 = memref.load %2[%arg1] : memref<3xf32>
      %8 = memref.load %3[%arg1] : memref<3xi32>
      %9 = arith.cmpf ogt, %arg2, %7 : f32
      %10 = arith.cmpf ogt, %7, %arg2 : f32
      %11 = arith.cmpi eq, %9, %10 : i1
      %12 = arith.cmpi slt, %arg3, %8 : i32
      %13 = arith.andi %11, %12 : i1
      %14 = arith.ori %9, %13 : i1
      %15 = arith.select %9, %arg2, %7 : f32
      %16 = arith.select %14, %arg3, %8 : i32
      memref.store %15, %2[%arg1] : memref<3xf32>
      memref.store %16, %3[%arg1] : memref<3xi32>
      %17 = arith.select %9, %7, %arg2 : f32
      %18 = arith.select %14, %8, %arg3 : i32
      scf.yield %17, %18 : f32, i32
    }
  }
  return
}

The needs are:

  1. Initial values for scf.for op. In this case, they are input_values and input_indices.
  2. The body of computation.

I borrowed the idea from Mahesh and extended it a bit more. To generalize the op, we can ask linalg_ext.yield to yield #outputs + #inputs values. The leading #outputs values will be stored into outs, and the rest will be yielded by scf.yield.

%result:n = linalg_ext.maxk dimension(...)
    ins(%input0, %input1, %input2, ...., %inputm : ....)
    outs(%output0, %output1,... %outputn : ...) {
^bb0(%arg0, %arg1, %arg2...., %argm, %argmp1, %argmp2, ... %argmpn) :
  ...
  linalg_ext.yield %t0, %t1, ... %tn, %y0, %y1, ..., %ym : ....
}
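Concretely, for the topk case above this might look like the following (a hypothetical sketch; it omits the index-based tie-breaking from the scf lowering above, the first two yielded values are stored into outs, and the displaced value/index pair feeds the next scf.for iteration):

%result:2 = linalg_ext.maxk dimension(0)
    ins(%input_values, %input_indices : tensor<10xf32>, tensor<10xi32>)
    outs(%output_values, %output_indices : tensor<3xf32>, tensor<3xi32>) {
^bb0(%in_val: f32, %in_idx: i32, %out_val: f32, %out_idx: i32):
  %cmp = arith.cmpf ogt, %in_val, %out_val : f32
  // Winner is stored into the current output slot.
  %keep_val = arith.select %cmp, %in_val, %out_val : f32
  %keep_idx = arith.select %cmp, %in_idx, %out_idx : i32
  // Loser sinks through to the next output slot via scf.yield.
  %pass_val = arith.select %cmp, %out_val, %in_val : f32
  %pass_idx = arith.select %cmp, %out_idx, %in_idx : i32
  linalg_ext.yield %keep_val, %keep_idx, %pass_val, %pass_idx : f32, i32, f32, i32
}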

This approach does not satisfy the need to make the indices optional, but it does generalize to other cases. We'll be able to use it in ArgMin/Max lowering as well (in that context, we no longer need a stack allocation for ArgMin/Max); see the sketch below.
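For example, ArgMax over a 1-D input could become a k = 1 instance of the same op (a hypothetical sketch, ignoring NaN and tie handling):

%result:2 = linalg_ext.maxk dimension(0)
    ins(%input : tensor<128xf32>)
    outs(%max_val, %max_idx : tensor<1xf32>, tensor<1xi32>) {
^bb0(%in: f32, %cur_val: f32, %cur_idx: i32):
  %pos = linalg_ext.index 0 : index
  %idx = arith.index_cast %pos : index to i32
  %cmp = arith.cmpf ogt, %in, %cur_val : f32
  %val = arith.select %cmp, %in, %cur_val : f32
  %sel = arith.select %cmp, %idx, %cur_idx : i32
  // #outputs values are stored into outs; the trailing #inputs value (the
  // rejected element) is carried by scf.yield, though with k = 1 it is unused.
  %rej = arith.select %cmp, %cur_val, %in : f32
  linalg_ext.yield %val, %sel, %rej : f32, i32, f32
}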

I don't have a good solution for the indices at this moment. Maybe having a second block could address the issue, but it makes the op a bit more complicated. E.g.,

%result:2 = linalg_ext.maxk dimension(...)
    ins(%input0: ....)
    outs(%output0, %output1 : ...) {
^bb0(%in):
  %idx = linalg_ext.index ...
  linalg_ext.yield %in, %idx ...  // these will be passed to scf.for iters.
^bb1(%arg0, %arg1, %argmp0, %argmp1) :
  ...
  linalg_ext.yield %t0, %t1, %y0, %y1: .... // The first two will be stored into outs, and the rest will be used in the next iteration.
}
MaheshRavishankar commented 2 years ago

Thinking a little bit more about this, there is an advantage to making linalg_ext.maxk an interface and linalg_ext.topk an op that implements that interface (maxk doesn't have to be an interface to start with; it could just be the base op in Tablegen, for something with lower overhead). The reason is that linalg_ext.topk, when tiled + distributed, will need a different operation to represent the merge. That operation can be a linalg_ext.maxk. Knowing that the op is a linalg_ext.topk up front (as opposed to recognizing it as a topk by inspecting the op's region) is much easier to manage.

To some extent this analogy holds: linalg_ext.topk is to linalg_ext.maxk as linalg.matmul is to linalg.generic. In other words, if linalg_ext.maxk were an interface similar to LinalgOp, linalg_ext.topk would be a named op.
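For reference, here is the Linalg side of that analogy (standard upstream MLIR, not specific to this issue): linalg.matmul is just a named spelling of this linalg.generic:

#map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
#map_b = affine_map<(d0, d1, d2) -> (d2, d1)>
#map_c = affine_map<(d0, d1, d2) -> (d0, d1)>
%0 = linalg.generic
    {indexing_maps = [#map_a, #map_b, #map_c],
     iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%C : tensor<?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
  // C[m, n] += A[m, k] * B[k, n]
  %mul = arith.mulf %a, %b : f32
  %add = arith.addf %c, %mul : f32
  linalg.yield %add : f32
} -> tensor<?x?xf32>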

So let's stay the course with linalg_ext.topk for now to learn more about how this all shakes out.