iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Tensor-to-primitive/detensoring HLO conversion #1159

Closed: benvanik closed this issue 2 years ago.

benvanik commented 4 years ago

HLO can only represent tensors, so values that should not be tensors are still wrapped in them and operated on as if they were real dense data. This is most visible in loops and conditionals, where the loop iterator initialization, increment, and condition are all tensors. This results in host readbacks and a bunch of other extraneous work, when these should really be modeled as simple primitive values (i32/index, etc.).

For example, this input loop:

  func @main() -> tensor<i32> attributes {iree.module.export} {
    %cst = constant dense<1> : tensor<i32>
    %cst_0 = constant dense<3> : tensor<i32>
    %0 = iree.do_not_optimize(%cst) : tensor<i32>
    %1 = iree.do_not_optimize(%cst_0) : tensor<i32>
    %2 = "xla_hlo.while"(%0) ( {
    ^bb0(%arg0: tensor<i32>):   // no predecessors
      %3 = "xla_hlo.compare"(%arg0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "xla_hlo.return"(%3) : (tensor<i1>) -> ()
    },  {
    ^bb0(%arg0: tensor<i32>):   // no predecessors
      %3 = xla_hlo.add %arg0, %arg0 : tensor<i32>
      "xla_hlo.return"(%3) : (tensor<i32>) -> ()
    }) : (tensor<i32>) -> tensor<i32>
    return %2 : tensor<i32>
  }

is turned into the following CFG:

  func @main() -> tensor<i32> attributes {iree.module.export} {
    %cst = constant dense<1> : tensor<i32>
    %cst_0 = constant dense<3> : tensor<i32>
    %0 = iree.do_not_optimize(%cst) : tensor<i32>
    %1 = iree.do_not_optimize(%cst_0) : tensor<i32>
    br ^bb1(%0 : tensor<i32>)
  ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
    %3 = "xla_hlo.compare"(%2, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %4 = extract_element %3[] : tensor<i1>
    cond_br %4, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
  ^bb2(%5: tensor<i32>):  // pred: ^bb1
    %6 = xla_hlo.add %5, %5 : tensor<i32>
    br ^bb1(%6 : tensor<i32>)
  ^bb3(%7: tensor<i32>):  // pred: ^bb1
    return %7 : tensor<i32>
  }

After lowering through flow, the comparison is dispatched to the device and its result is read back to the host (via flow.tensor.load):

  func @main() -> tensor<i32> attributes {iree.module.export} {
    %cst = constant dense<1> : tensor<i32>
    %cst_0 = constant dense<3> : tensor<i32>
    %cst_1 = constant dense<1> : vector<3xi32>
    %0 = iree.do_not_optimize(%cst) : tensor<i32>
    %1 = iree.do_not_optimize(%cst_0) : tensor<i32>
    br ^bb1(%0 : tensor<i32>)
  ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
    %3 = flow.ex.stream.fragment(%arg0 = %cst_1 : vector<3xi32>, %arg1 = %2 : tensor<i32>, %arg2 = %1 : tensor<i32>) -> tensor<i1> {
      %8 = flow.dispatch @main_ex_dispatch_0::@main_ex_dispatch_0[%arg0 : vector<3xi32>](%arg1, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
      flow.return %8 : tensor<i1>
    }
    %4 = flow.tensor.load %3 : tensor<i1>
    cond_br %4, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
  ^bb2(%5: tensor<i32>):  // pred: ^bb1
    %6 = flow.ex.stream.fragment(%arg0 = %cst_1 : vector<3xi32>, %arg1 = %5 : tensor<i32>) -> tensor<i32> {
      %8 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%arg0 : vector<3xi32>](%arg1) : (tensor<i32>) -> tensor<i32>
      flow.return %8 : tensor<i32>
    }
    br ^bb1(%6 : tensor<i32>)
  ^bb3(%7: tensor<i32>):  // pred: ^bb1
    return %7 : tensor<i32>
  }
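To put a rough number on that overhead, here is a hypothetical Python cost model (not IREE code) that replays the lowered CFG: each `dispatches` increment stands in for one flow.dispatch and each `readbacks` increment for one flow.tensor.load of the tensor<i1> condition. The host-only variant assumes the iterator is held as a plain integer in the host VM.

```python
from dataclasses import dataclass


@dataclass
class Counters:
    dispatches: int = 0  # stands in for flow.dispatch ops issued
    readbacks: int = 0   # stands in for flow.tensor.load host readbacks


def run_tensorized(init: int, limit: int, c: Counters) -> int:
    """Mirror the lowered CFG: compare and add both run as device dispatches."""
    value = init
    while True:
        c.dispatches += 1      # @main_ex_dispatch_0: compare LT
        c.readbacks += 1       # flow.tensor.load of the tensor<i1> condition
        if value >= limit:
            return value
        c.dispatches += 1      # @main_ex_dispatch_1: add %5, %5
        value += value


def run_on_host(init: int, limit: int) -> int:
    """The same loop with the iterator held as a host primitive value."""
    value = init
    while value < limit:
        value += value
    return value


c = Counters()
print(run_tensorized(1, 3, c), c)  # 4 Counters(dispatches=5, readbacks=3)
print(run_on_host(1, 3))           # 4
```

For the example input (start at 1, limit 3) both versions produce 4, but the tensorized one needs 5 dispatches and 3 blocking readbacks to get there.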

If we instead identified these host-only values (even if only scalar tensors to start), we could run the whole loop in the host VM and avoid the readback.
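One way to find such values, sketched in Python under heavy simplifying assumptions (a toy SSA IR, a hypothetical allow-list of host-friendly ops, and the iree.do_not_optimize barriers dropped): treat every rank-0 tensor as a candidate, group values that flow into the same block argument or through the same scalar compare/add, and keep a group on the device if any member touches an op outside the allow-list. This illustrates the idea only; it is not a description of any existing pass.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    name: str                        # op mnemonic, e.g. "xla_hlo.add"
    operands: tuple[str, ...] = ()   # SSA values consumed
    results: tuple[str, ...] = ()    # SSA values produced
    # br/cond_br only: (forwarded operand, destination block argument) pairs.
    edges: tuple[tuple[str, str], ...] = ()


# Hypothetical allow-list: ops whose rank-0 tensor values could be rewritten to
# plain i32/i1 host arithmetic and branches. "return" is allowed on the
# assumption that the returned value is re-wrapped into a tensor there.
HOST_FRIENDLY = {"constant", "xla_hlo.compare", "xla_hlo.add",
                 "extract_element", "br", "cond_br", "return"}
SCALAR_COMPUTE = {"xla_hlo.compare", "xla_hlo.add"}


def host_only_values(ops: list[Op], rank0: set[str]) -> set[str]:
    """Return the rank-0 tensor values whose whole group can move to the host."""
    parent = {v: v for v in rank0}

    def find(v: str) -> str:
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving
            v = parent[v]
        return v

    def union(a: str, b: str) -> None:
        parent[find(a)] = find(b)

    for op in ops:
        # A block argument has one type, so every value forwarded into it must
        # be converted together with it.
        for src, dst in op.edges:
            if src in parent and dst in parent:
                union(src, dst)
        # Operands and results of a scalar compute op must live in the same
        # place, or computing it on the host would need a readback anyway.
        if op.name in SCALAR_COMPUTE:
            vals = [v for v in (*op.operands, *op.results) if v in parent]
            for v in vals[1:]:
                union(vals[0], v)

    # Any rank-0 value touched by an op outside the allow-list pins its whole
    # group to tensor form.
    pinned = {find(v) for op in ops if op.name not in HOST_FRIENDLY
              for v in (*op.operands, *op.results) if v in parent}
    return {v for v in rank0 if find(v) not in pinned}


# The pre-flow loop CFG from this issue, with the barriers dropped.
loop = [
    Op("constant", results=("%0",)),
    Op("constant", results=("%1",)),
    Op("br", operands=("%0",), edges=(("%0", "%2"),)),
    Op("xla_hlo.compare", operands=("%2", "%1"), results=("%3",)),
    Op("extract_element", operands=("%3",), results=("%4",)),
    Op("cond_br", operands=("%4", "%2", "%2"),
       edges=(("%2", "%5"), ("%2", "%7"))),
    Op("xla_hlo.add", operands=("%5", "%5"), results=("%6",)),
    Op("br", operands=("%6",), edges=(("%6", "%2"),)),
    Op("return", operands=("%7",)),
]
rank0 = {"%0", "%1", "%2", "%3", "%5", "%6", "%7"}
```

For this loop every rank-0 value qualifies. If the body's add were replaced by an op that must stay on the device, the whole iterator/condition group would be pinned, while unrelated scalar values elsewhere in the function would be unaffected.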

This is also visible in inputs that have dynamic update slices (translated through to flow.tensor.update), where the current loop iterator value is needed to know where to update the tensor (mapping back to which timestep is being processed, etc.). These updates need to be recorded into the command buffer on the host, which means we perform a readback just to compute an offset and then send it back to the device.
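A toy Python model of the per-timestep update case, where all names are hypothetical and the element type is assumed to be a 4-byte i32: the recorded byte offsets are identical either way, and only where the loop counter lives decides whether each step needs a readback.

```python
from dataclasses import dataclass, field


@dataclass
class RecordedUpdates:
    """Hypothetical stand-in for a host-side command buffer, not an IREE API."""
    copies: list[tuple[int, int]] = field(default_factory=list)  # (byte offset, byte length)
    readbacks: int = 0


def record_timestep_updates(num_steps: int, row_elems: int, elem_bytes: int,
                            iterator_on_device: bool) -> RecordedUpdates:
    """Record one row update per timestep into a [num_steps, row_elems] tensor."""
    cb = RecordedUpdates()
    row_bytes = row_elems * elem_bytes
    for t in range(num_steps):
        if iterator_on_device:
            # The counter lives in a device tensor: the host must wait for it
            # and read it back before it knows where the update goes.
            cb.readbacks += 1
        # With the counter as a host integer, the offset is plain arithmetic.
        cb.copies.append((t * row_bytes, row_bytes))
    return cb
```

For 4 timesteps of 8 i32 elements, both variants record updates at byte offsets 0, 32, 64 and 96, but the device-resident counter costs one readback per step.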

Other enhancements around indirect dispatch and dynamic flow.tensor.update (#1160) will make some of these cases less costly when device->host data dependencies really do exist. However, if we can remove all of the trivial cases without relying on that, we'll have much more readable IR and much lower overhead at runtime.

GMNGeoffrey commented 3 years ago

I'm hitting an issue that I think is related here that actually results in a compilation failure (as opposed to performance issues) because IREE doesn't handle i1 tensors well at all: https://github.com/google/iree/issues/3102#issuecomment-963409877

ScottTodd commented 3 years ago

> I'm hitting an issue that I think is related here that actually results in a compilation failure (as opposed to performance issues) because IREE doesn't handle i1 tensors well at all: #3102 (comment)

Does -iree-flow-enable-linalg-detensorize help there? It still isn't enabled by default. (There are many things overlapping here and we should try to keep the issues focused...)

GMNGeoffrey commented 3 years ago

It does not, unfortunately.

Also, unrelated, the only mention I see of that flag in the codebase is in its definition. It should probably have some tests...

ScottTodd commented 3 years ago

> Also, unrelated, the only mention I see of that flag in the codebase is in its definition. It should probably have some tests...

Well yeah :P I spent about three weeks trying to generate/find representative test cases and didn't get very far. Any models I tried to compile hit frontend issues, either in their own frameworks or in TF->MLIR. The MLIR code has unit test coverage in MLIR core and is mostly an optimization in IREE itself. I have been wanting to flip the flag (https://github.com/google/iree/pull/6863) to enable coverage in the few tests (if, while, collatz, etc.) that are affected.

allieculp commented 2 years ago

@GMNGeoffrey @ScottTodd Stale P1 item here; please take a look when you get a chance and update or deprioritize!

ScottTodd commented 2 years ago

I think @rsuderman was going to look at getting this work over the finish line.