iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Upstream improvements for scf.for folding on loop ranges. #5547

Open benvanik opened 3 years ago

benvanik commented 3 years ago

Motivating example:

    func @simple_mul_dispatch_0(%arg0: !vmvx.interface, %arg1: !vmvx.buffer, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: index) {
      %c0 = constant 0 : index
      %c2 = constant 2 : index
      %c1 = constant 1 : index
      %c-1 = constant -1 : index
      %c4 = constant 4 : index
      %c4096 = constant 4096 : index
      %0 = vmvx.interface.binding<%arg0 : !vmvx.interface>[0] : !vmvx.buffer
      %1 = vmvx.interface.binding<%arg0 : !vmvx.interface>[1] : !vmvx.buffer
      %2 = vmvx.interface.binding<%arg0 : !vmvx.interface>[2] : !vmvx.buffer
      scf.for %arg11 = %arg3 to %c2 step %arg9 {
        %3 = muli %arg11, %c-1 : index
        %4 = addi %3, %c2 : index
        %5 = cmpi slt, %c1, %4 : index
        %6 = select %5, %c1, %4 : index
        scf.for %arg12 = %arg2 to %c4096 step %arg8 {
          %7 = muli %arg12, %c-1 : index
          %8 = addi %7, %c4096 : index
          %9 = cmpi slt, %c1, %8 : index
          %10 = select %9, %c1, %8 : index
          scf.for %arg13 = %c0 to %6 step %c1 {
            %11 = addi %arg11, %arg13 : index
            %12 = muli %11, %c4096 : index
            scf.for %arg14 = %c0 to %10 step %c1 {
              %13 = addi %arg12, %arg14 : index
              %14 = addi %12, %13 : index
              %15 = muli %14, %c4 : index
              %16 = vmvx.buffer.load<%0 : !vmvx.buffer>[%15] : i32
              %17 = vmvx.buffer.load<%1 : !vmvx.buffer>[%15] : i32
              %18 = muli %16, %17 : i32
              vmvx.buffer.store<%2 : !vmvx.buffer>[%15], %18 : i32
            }
          }
        }
      }
      return
    }

There are a lot of muls and adds in here that could be moved up into the loop ranges. For example, %13 should be able to fold into the %arg14 loop:

%arg12p10 = addi %arg12, %10
scf.for %arg14 = %arg12 to %arg12p10 step %c1 {
  %14 = addi %12, %arg14 : index
  %15 = muli %14, %c4 : index
  %16 = vmvx.buffer.load<%0 : !vmvx.buffer>[%15] : i32
  %17 = vmvx.buffer.load<%1 : !vmvx.buffer>[%15] : i32
  %18 = muli %16, %17 : i32
  vmvx.buffer.store<%2 : !vmvx.buffer>[%15], %18 : i32
}
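
To make the fold concrete, here's a small Python model of it (the names are stand-ins for the SSA values above: `arg12` for %arg12, `ub` for %10); both loops feed the same index sequence to the body:

```python
def original(arg12, ub):
    # scf.for %arg14 = %c0 to %10 step %c1 { %13 = addi %arg12, %arg14 ; ... }
    return [arg12 + arg14 for arg14 in range(0, ub)]

def folded(arg12, ub):
    # scf.for %arg14 = %arg12 to (%arg12 + %10) step %c1 { ... use %arg14 ... }
    return list(range(arg12, arg12 + ub))

assert original(7, 5) == folded(7, 5) == [7, 8, 9, 10, 11]
```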

The folded loop can then be folded again:

%arg12p10 = addi %arg12, %10
%arg12p12 = addi %arg12, %12
%arg12p10p12 = addi %arg12p10, %12
scf.for %arg14 = %arg12p12 to %arg12p10p12 step %c1 {
  %15 = muli %arg14, %c4 : index
  %16 = vmvx.buffer.load<%0 : !vmvx.buffer>[%15] : i32
  %17 = vmvx.buffer.load<%1 : !vmvx.buffer>[%15] : i32
  %18 = muli %16, %17 : i32
  vmvx.buffer.store<%2 : !vmvx.buffer>[%15], %18 : i32
}

And then the muli can be folded in as well to scale the range:

%arg12p10 = addi %arg12, %10
%arg12p12 = addi %arg12, %12
%arg12p10p12 = addi %arg12p10, %12
%arg12p12x4 = muli %arg12p12, %c4
%arg12p10p12x4 = muli %arg12p10p12, %c4
scf.for %arg14 = %arg12p12x4 to %arg12p10p12x4 step %c4 {
  %16 = vmvx.buffer.load<%0 : !vmvx.buffer>[%arg14] : i32
  %17 = vmvx.buffer.load<%1 : !vmvx.buffer>[%arg14] : i32
  %18 = muli %16, %17 : i32
  vmvx.buffer.store<%2 : !vmvx.buffer>[%arg14], %18 : i32
}

(and then of course all those hoisted values can be themselves canonicalized/cse'd/hoisted again/etc)

This will eliminate a large amount of the addressing math by folding it into math that needed to happen anyway (the loop variables), or at least make much of it invariant to the inner loop so it can be hoisted out.
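
As a sanity check on the whole chain, a Python model (hypothetical names: `off12` stands for %12, `ub10` for %10) can verify that the fully folded loop visits exactly the byte offsets the original body computed; note that the lower bound has to be offset and scaled the same way as the upper bound:

```python
def original(arg12, off12, ub10):
    # innermost loop of the motivating example with the outer IVs fixed:
    #   %13 = addi %arg12, %arg14
    #   %14 = addi %12, %13
    #   %15 = muli %14, %c4
    return [(off12 + arg12 + arg14) * 4 for arg14 in range(0, ub10)]

def fully_folded(arg12, off12, ub10):
    # all three folds applied: both bounds offset by %arg12 and %12,
    # then both bounds and the step scaled by %c4
    lb = (arg12 + off12) * 4
    ub = (arg12 + ub10 + off12) * 4
    return list(range(lb, ub, 4))

assert original(3, 4096, 2) == fully_folded(3, 4096, 2) == [16396, 16400]
```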

These could be canonicalizations on scf.for or a set of patterns that could be applied independently (there seems to be a bit of both approaches used upstream).

This is critical to VMVX but would also help the VM when we start using SCF there as an input (as well as any upstream dialect usage on the other codegen paths/etc).

anthonycanino1 commented 3 years ago

Hi @benvanik, I'm interested in taking a crack at this. I've been following the project for a little while and this looks like a good first issue.

Can you clarify this statement for me...

> These could be canonicalizations on scf.for or a set of patterns that could be applied independently (there seems to be a bit of both approaches used upstream).

Are there other transformations in IREE that do something similar to the above, or are you referring to the way the transformation is applied?

benvanik commented 3 years ago

Hi @anthonycanino1! (sorry for the delay - was on vacation)

scf.for has some canonicalization patterns here: https://github.com/llvm/llvm-project/blob/4184018253e720b0f2449b2b83ce27fc682f8579/mlir/lib/Dialect/SCF/SCF.cpp#L841-L845

The idea would be to add some new ones that perform the range folding and hoisting like in the examples above. ForOpIterArgsFolder does loop manipulation, but for direct folds (e.g. when the entire scf.for reduces to a single returned value), rather than the range arithmetic manipulation called for here.

scf.if has some similar patterns that may be useful as examples too: https://github.com/llvm/llvm-project/blob/4184018253e720b0f2449b2b83ce27fc682f8579/mlir/lib/Dialect/SCF/SCF.cpp#L841-L845

If added upstream in those files then everyone benefits, which would be fantastic! The alternative mentioned is to do it somewhere in IREE as a pass that applies some patterns, however that'd only be if what we wanted to do was unsafe and required some conditions that were IREE-specific (using IREE dialect types, etc). I think these are all general purpose and safe, though.

As far as I know there isn't any active work on this upstream, but you could ask in the MLIR discord or discourse to see if anyone would want to collaborate.

anthonycanino1 commented 3 years ago

No problem @benvanik! Thanks for the tips, I'm digging into the code you pointed out, and I'll look into chatting with the MLIR folks and working on this upstream.

anthonycanino1 commented 3 years ago

@benvanik I have a small proof of concept that works on your example above, and I'm starting to see how it behaves in the larger MLIR test suite now.

Can you provide me with a little more context on when this optimization will be useful, so I can start the discussion with the MLIR folks? Does this pattern of add/mul index offsets appear on the induction variable when compiling from linalg on tensors/memrefs?

anthonycanino1 commented 3 years ago

A pass for this has landed upstream in https://github.com/llvm/llvm-project/commit/3f429e82d3ea1710ee0a841675acba9bb7b658d2.

Please let me know how you plan to work it into the VMVX pipeline. I can help with that as well if you give me some tips.

benvanik commented 3 years ago

Nice work @anthonycanino1!

After our next llvm integrate we should be able to add it here: https://github.com/google/iree/blob/552d3f888e46f96c8227b1bd2e4600b85338237b/iree/compiler/Dialect/VM/Transforms/Passes.cpp#L24

As well as in the various other codegen lowering pipelines (as they all route through scf at some point). Placing it prior to the existing createForOpCanonicalizationPass in iree/compiler/Conversion/LinalgToLLVM and LinalgToSPIRV looks like a good fit. (That pass looks like it should eventually be upstreamed and renamed ForOpVectorCanonicalizationPass, as it's just optimizing vector variables.)

There may also be places higher up in the stack where it's useful (as this kind of optimization can be applied many times, like canonicalization or CSE) that we could experiment with. Making the IR more readable and giving more opportunity for canonicalizers/pattern matchers to identify similar loop constructs thanks to the folding will have some good ripple effects throughout the whole stack.

benvanik commented 2 years ago

Unfortunately the upstream pass causes miscompilation and incorrect results ;( I'm going to close this issue for now until we want to do a big push on improving things and fix the upstream dialects.

stellaraccident commented 2 years ago

Any value in keeping it open with breadcrumbs for upstream fixes needed?

benvanik commented 2 years ago

Yeah, I'll stash it in the backlog. I feel like we could get 2x vmvx improvements from some basic work here (and those would help everything else) but it needs some investment.

stellaraccident commented 2 years ago

Sent out a flag to de-specialize FuncOp SCF stuff: https://reviews.llvm.org/D128614

stellaraccident commented 2 years ago

Also de-specialized Affine passes: https://reviews.llvm.org/D128616

stellaraccident commented 2 years ago

Iterating on a repro but can't spot the illegal transformation. I did bisect the other optimization passes and was able to repro only with this one, so I don't believe this is an interaction issue (I think this pass is faulty). Here is the before/after output of two dispatches for a simple add that fails when SCFForLoopRangeFolding is enabled: https://gist.github.com/stellaraccident/f88a28cb935cf9f12efffae8cb8fdcb9

stellaraccident commented 2 years ago

Ugh: https://github.com/llvm/llvm-project/issues/56235

benvanik commented 2 years ago

Damn my short-term memory; was just bit by this again: "ew all this indexing math" -> "oh neat a range folding pass" -> "whoops that doesn't work..." 🤦

The majority of the loops I see ending up in common vmvx programs are of a form that could be handled by the pass (I see a note about negative steps but I'm fine with ignoring those), so it'd still be good to have. Now that there's some dataflow stuff upstream it may also be possible to do something nicer (tracking values across control edges and such and using value range analysis to avoid divs/rems/etc).