iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Optimize `scf.index_switch`/`cf.switch` for known unreachable default cases. #19036

Open benvanik opened 1 week ago

benvanik commented 1 week ago

When multi-targeting on a single device, each dispatch gets a switch on the chosen executable variant. When we emit the scf.index_switch (which later becomes cf.switch) we already know that load-time verification guarantees at least one of the cases will be reached. There's currently no way to communicate this to those ops, though, so we end up generating some pretty terrible code.

This is what 99.9% of the switches look like (as the workgroup counts and parameters often don't differ):

    %0 = ...
    scf.index_switch %0 
    case 0 {
      %__device_a_executable_0_iree_run_module_multi_linked = util.global.load immutable @__device_a_executable_0_iree_run_module_multi_linked : !hal.executable
      hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%__device_a_executable_0_iree_run_module_multi_linked : !hal.executable)[%c1] workgroups([%c1, %c1, %c1]) bindings([
        (%c2 : index)[%c0, %c64], 
        (%c1 : index)[%c0, %c16]
      ]) flags("None")
      scf.yield
    }
    case 1 {
      %__device_a_executable_0_iree_run_module_multi_linked = util.global.load immutable @__device_a_executable_0_iree_run_module_multi_linked : !hal.executable
      hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%__device_a_executable_0_iree_run_module_multi_linked : !hal.executable)[%c1] workgroups([%c1, %c1, %c1]) bindings([
        (%c2 : index)[%c0, %c64], 
        (%c1 : index)[%c0, %c16]
      ]) flags("None")
      scf.yield
    }
    default {
    }

This lowers into cf:

  %0 = ...
  %1 = arith.index_cast %0 : index to i32
  cf.switch %1 : i32, [
    default: ^bb3,
    0: ^bb1,
    1: ^bb2
  ]
^bb1:  // pred: ^bb0
  %__device_a_executable_0_iree_run_module_multi_linked = util.global.load immutable @__device_a_executable_0_iree_run_module_multi_linked : !hal.executable
  hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%__device_a_executable_0_iree_run_module_multi_linked : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) bindings([
    (%c0 : index)[%c0, %c16], 
    (%c1 : index)[%c0, %c16]
  ]) flags("None")
  cf.br ^bb4
^bb2:  // pred: ^bb0
  %__device_a_executable_0_iree_run_module_multi_linked_0 = util.global.load immutable @__device_a_executable_0_iree_run_module_multi_linked : !hal.executable
  hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%__device_a_executable_0_iree_run_module_multi_linked_0 : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) bindings([
    (%c0 : index)[%c0, %c16], 
    (%c1 : index)[%c0, %c16]
  ]) flags("None")
  cf.br ^bb4
^bb3:  // pred: ^bb0
  cf.br ^bb4
^bb4:  // 3 preds: ^bb1, ^bb2, ^bb3

Which eventually gets folded due to the duplication:

    %0 = ...
    %1 = arith.index_cast %0 : index to i32
    cf.switch %1 : i32, [
      default: ^bb2,
      0: ^bb1,
      1: ^bb1
    ]
  ^bb1:  // 2 preds: ^bb0, ^bb0
    hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%__device_a_executable_0_iree_run_module_multi_linked : !hal.executable)[%c1] workgroups([%c1, %c1, %c1]) bindings([
      (%c2 : index)[%c0, %c64], 
      (%c1 : index)[%c0, %c16]
    ]) flags("None")
    cf.br ^bb2
  ^bb2:  // 2 preds: ^bb0, ^bb1

That cf.switch default case is not reachable, so the op is no longer required and we should be folding it away:

    // What we need a canonicalizer to produce:
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%__device_a_executable_0_iree_run_module_multi_linked : !hal.executable)[%c1] workgroups([%c1, %c1, %c1]) bindings([
      (%c2 : index)[%c0, %c64], 
      (%c1 : index)[%c0, %c16]
    ]) flags("None")
    cf.br ^bb2
  ^bb2:

Which will go away in block simplification:

    // What we need:
    hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%__device_a_executable_0_iree_run_module_multi_linked : !hal.executable)[%c1] workgroups([%c1, %c1, %c1]) bindings([
      (%c2 : index)[%c0, %c64], 
      (%c1 : index)[%c0, %c16]
    ]) flags("None")
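
The block-simplification step can be illustrated outside MLIR. The following is only a toy sketch in Python, not IREE or MLIR code: the dict-based CFG, the tuple terminators, and the names `successors` and `merge_linear_chains` are all made up for illustration. It assumes the switch has already been replaced by an unconditional branch, and shows how a block with exactly one incoming edge gets spliced into its predecessor.

```python
def successors(term):
    """Return the successor block names of a toy terminator tuple."""
    if term[0] == "br":                # ("br", target)
        return [term[1]]
    if term[0] == "switch":            # ("switch", default, {value: target})
        return [term[1], *term[2].values()]
    return []                          # ("return",) has no successors


def merge_linear_chains(cfg, entry):
    """Splice a block into its predecessor when it has exactly one incoming edge."""
    changed = True
    while changed:
        changed = False
        # Count incoming edges; duplicate switch targets count once per edge.
        preds = {name: 0 for name in cfg}
        for block in cfg.values():
            for succ in successors(block["term"]):
                preds[succ] += 1
        for name, block in cfg.items():
            term = block["term"]
            if term[0] != "br":
                continue
            succ = term[1]
            if succ == name or succ == entry or preds[succ] != 1:
                continue
            # The successor is only reachable from here: inline it.
            block["ops"] += cfg[succ]["ops"]
            block["term"] = cfg[succ]["term"]
            del cfg[succ]
            changed = True
            break  # the dict changed, so recompute the edge counts
    return cfg


# Hypothetical stand-in for the IR after the cf.switch became `cf.br ^bb1`.
cfg = {
    "bb0": {"ops": [], "term": ("br", "bb1")},
    "bb1": {"ops": ["hal.command_buffer.dispatch"], "term": ("br", "bb2")},
    "bb2": {"ops": [], "term": ("return",)},
}
merge_linear_chains(cfg, entry="bb0")
print(cfg)
```

After the call only `bb0` remains and it holds the single dispatch, which mirrors the "what we need" snippet above.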

There are a few approaches we could take. One would be to change the scf/cf ops to avoid the need for defaults (perhaps only in cases where there are no returns): we'd just emit the op with no default region. That would help any code that already knows the default is unreachable at the point it emits the ops, but it has caveats if the information only becomes available later on (via const eval, IPO, specialization, etc).

Another approach would be to insert a util.unreachable op in the default region of the scf.index_switch and then have a CFG cleanup that walks back from any block with that to a predecessor and fixes up the terminator to make it go the other direction. That also would be applicable to a lot of other parts of the codebase and is a good lowering target for assume ops/int range analysis.
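
A rough sketch of what such a cleanup might do, written as a toy Python model rather than against the real MLIR/IREE APIs. The `Block` class, the tuple terminators, and the function names are hypothetical. It also assumes, conservatively, that a block only counts as dead when it contains no ops before its unreachable terminator.

```python
from dataclasses import dataclass, field


@dataclass
class Block:
    """Toy basic block: op names plus one terminator tuple.

    Terminators: ("br", target), ("switch", default, {value: target}),
    ("unreachable",), or ("return",).
    """
    ops: list = field(default_factory=list)
    term: tuple = ("return",)


def rewrite_switch(term, dead):
    """Drop switch edges into dead blocks and fold the switch if possible."""
    _, default, cases = term
    live = {v: t for v, t in cases.items() if t not in dead}
    if default in dead:
        if not live:
            return ("unreachable",)      # every edge leads to a dead block
        # The value must hit some case, so one live case can act as default.
        default = list(live.values())[-1]
    # Cases that go where the default goes are redundant.
    live = {v: t for v, t in live.items() if t != default}
    if not live:
        return ("br", default)           # a single possible target remains
    return ("switch", default, live)


def prune_unreachable_edges(cfg):
    """Walk back from dead blocks and fix up predecessor terminators."""
    changed = True
    while changed:
        changed = False
        dead = {name for name, b in cfg.items()
                if not b.ops and b.term == ("unreachable",)}
        for block in cfg.values():
            new_term = block.term
            if block.term[0] == "br" and block.term[1] in dead:
                new_term = ("unreachable",)
            elif block.term[0] == "switch":
                new_term = rewrite_switch(block.term, dead)
            if new_term != block.term:
                block.term = new_term
                changed = True
    return cfg


# Hypothetical stand-in for the folded cf.switch with an unreachable default.
cfg = {
    "bb0": Block(term=("switch", "bb_default", {0: "bb1", 1: "bb1"})),
    "bb1": Block(ops=["hal.command_buffer.dispatch"], term=("br", "bb2")),
    "bb_default": Block(term=("unreachable",)),
    "bb2": Block(),
}
prune_unreachable_edges(cfg)
print(cfg["bb0"].term)  # ('br', 'bb1')
```

The sketch leaves the now-orphaned `bb_default` block in place; removing blocks with no predecessors is assumed to be handled by existing dead-block elimination.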

If we did util.unreachable we could also support optimizations with the new util.assume.int ops. In this case, for example, if int range analysis or something else added an assume op:

    %0 = ...
    %assumed_range = util.assume.int %0<umin = 0, umax = 1> : index
    scf.index_switch %assumed_range
    case 0 {}
    case 1 {}
    default {}

a pass could insert the util.unreachable:

    %0 = ...
    %assumed_range = util.assume.int %0<umin = 0, umax = 1> : index
    scf.index_switch %assumed_range
    case 0 {}
    case 1 {}
    default {
      util.unreachable
    }

And then the unreachable default would get removed later on by the CFG cleanup.
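
The check such a pass would need is small. As a hedged sketch in plain Python (the function name and signature are invented for illustration and do not correspond to any IREE API), the default is unreachable exactly when every value in the assumed `[umin, umax]` range has an explicit case:

```python
def default_is_unreachable(umin, umax, case_values):
    """Return True when every value in [umin, umax] has an explicit case.

    umin/umax model the unsigned bounds from a util.assume.int-style
    annotation; case_values are the integer case labels of the switch.
    """
    if umin > umax:
        raise ValueError("empty range: umin must not exceed umax")
    # Count distinct case labels that fall inside the assumed range.
    covered = sum(1 for v in set(case_values) if umin <= v <= umax)
    return covered == umax - umin + 1


# The example from above: range [0, 1] with cases 0 and 1.
print(default_is_unreachable(0, 1, [0, 1]))  # True
# A value of 2 would fall through to the default, so it stays reachable.
print(default_is_unreachable(0, 2, [0, 1]))  # False
```

Counting covered labels instead of iterating over the whole range keeps the check cheap even when the assumed range is large.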

mvvsmk commented 6 days ago

Hey @benvanik, I would like to work on this. Could you assign it to me?

I would like to go with the second approach that you pointed out:

Another approach would be to insert a util.unreachable op in the default region of the scf.index_switch and then have a CFG cleanup that walks back from any block with that to a predecessor and fixes up the terminator to make it go the other direction. That also would be applicable to a lot of other parts of the codebase and is a good lowering target for assume ops/int range analysis.

I plan to implement this as a match-and-rewrite pattern, along the lines of IndexSwitchToIfPattern and MergeIndexSwitchPattern in the commit here.

I will create some synthetic test cases like the ones mentioned in the issue and the commit referenced above, but if you have a test case from an ML model that you would like me to look at, feel free to include it.