iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Tensor-to-primitive/detensoring HLO conversion #1159

Closed benvanik closed 2 years ago

benvanik commented 4 years ago

HLO can only represent tensors, meaning that values that should not be tensors are still wrapped in them and operated on just as if they were real dense data. This is most easily visible with loops and conditionals, where the loop iterator initialization, increment, and condition are all in tensors. This results in host readbacks and a bunch of other extraneous work when really these should just be modeled as simple primitive values (i32/index, etc).

For example, this input loop:

  func @main() -> tensor<i32> attributes {iree.module.export} {
    %cst = constant dense<1> : tensor<i32>
    %cst_0 = constant dense<3> : tensor<i32>
    %0 = iree.do_not_optimize(%cst) : tensor<i32>
    %1 = iree.do_not_optimize(%cst_0) : tensor<i32>
    %2 = "xla_hlo.while"(%0) ( {
    ^bb0(%arg0: tensor<i32>):   // no predecessors
      %3 = "xla_hlo.compare"(%arg0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "xla_hlo.return"(%3) : (tensor<i1>) -> ()
    },  {
    ^bb0(%arg0: tensor<i32>):   // no predecessors
      %3 = xla_hlo.add %arg0, %arg0 : tensor<i32>
      "xla_hlo.return"(%3) : (tensor<i32>) -> ()
    }) : (tensor<i32>) -> tensor<i32>
    return %2 : tensor<i32>
  }

is turned into the following CFG:

func @main() -> tensor<i32> attributes {iree.module.export} {
  %cst = constant dense<1> : tensor<i32>
  %cst_0 = constant dense<3> : tensor<i32>
  %0 = iree.do_not_optimize(%cst) : tensor<i32>
  %1 = iree.do_not_optimize(%cst_0) : tensor<i32>
  br ^bb1(%0 : tensor<i32>)
^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
  %3 = "xla_hlo.compare"(%2, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
  %4 = extract_element %3[] : tensor<i1>
  cond_br %4, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
^bb2(%5: tensor<i32>):  // pred: ^bb1
  %6 = xla_hlo.add %5, %5 : tensor<i32>
  br ^bb1(%6 : tensor<i32>)
^bb3(%7: tensor<i32>):  // pred: ^bb1
  return %7 : tensor<i32>
}

which, after lowering through flow, has the condition computation dispatched to the device and the result read back (via flow.tensor.load):

func @main() -> tensor<i32> attributes {iree.module.export} {
  %cst = constant dense<1> : tensor<i32>
  %cst_0 = constant dense<3> : tensor<i32>
  %cst_1 = constant dense<1> : vector<3xi32>
  %0 = iree.do_not_optimize(%cst) : tensor<i32>
  %1 = iree.do_not_optimize(%cst_0) : tensor<i32>
  br ^bb1(%0 : tensor<i32>)
^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
  %3 = flow.ex.stream.fragment(%arg0 = %cst_1 : vector<3xi32>, %arg1 = %2 : tensor<i32>, %arg2 = %1 : tensor<i32>) -> tensor<i1> {
    %8 = flow.dispatch @main_ex_dispatch_0::@main_ex_dispatch_0[%arg0 : vector<3xi32>](%arg1, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
    flow.return %8 : tensor<i1>
  }
  %4 = flow.tensor.load %3 : tensor<i1>
  cond_br %4, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
^bb2(%5: tensor<i32>):  // pred: ^bb1
  %6 = flow.ex.stream.fragment(%arg0 = %cst_1 : vector<3xi32>, %arg1 = %5 : tensor<i32>) -> tensor<i32> {
    %8 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%arg0 : vector<3xi32>](%arg1) : (tensor<i32>) -> tensor<i32>
    flow.return %8 : tensor<i32>
  }
  br ^bb1(%6 : tensor<i32>)
^bb3(%7: tensor<i32>):  // pred: ^bb1
  return %7 : tensor<i32>
}

If instead we found these host-only values (even if only scalar tensors to start) we could run the whole loop in the host VM and avoid the readback.
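
For illustration, here is a hand-written sketch (not actual compiler output, with the iree.do_not_optimize wrappers and the ABI re-tensorization of the result omitted) of what the same CFG could look like once the loop state is detensored to a plain i32 - the compare and branch then run entirely on the host with no dispatch or readback:

func @main_detensored() -> i32 {
  %c1 = constant 1 : i32
  %c3 = constant 3 : i32
  br ^bb1(%c1 : i32)
^bb1(%0: i32):  // 2 preds: ^bb0, ^bb2
  %1 = cmpi "slt", %0, %c3 : i32
  cond_br %1, ^bb2(%0 : i32), ^bb3(%0 : i32)
^bb2(%2: i32):  // pred: ^bb1
  %3 = addi %2, %2 : i32
  br ^bb1(%3 : i32)
^bb3(%4: i32):  // pred: ^bb1
  return %4 : i32
}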

This is also visible in inputs that have dynamic update slices (translated through to flow.tensor.update), where the current loop iterator value is needed to know where to update the tensor (mapping back to which timestep is being processed, etc). These updates need to be recorded into the command buffer on the host which means that we perform a readback to effectively compute an offset and then throw it back to the device.

Other enhancements around indirect dispatch and dynamic flow.tensor.update (#1160) will make some of these cases not so bad when device->host data dependencies really do exist, however if we can remove all of the trivial cases without relying on that we'll have much more readable IR and much lower overhead at runtime.

benvanik commented 3 years ago

This is still critically important and will become more so next quarter. Tagging @silvasean as we've talked about it before, in case he knows someone who may be willing to work on this (it'd be useful upstream in MLIR for IREE as well as anything routing through MHLO/TOSA/etc).

ergawy commented 3 years ago

Hi @benvanik , I have been contributing to MLIR (especially the spv dialect, with a lot of help from @antiagainst 😄) for a while (https://reviews.llvm.org/people/revisions/14762/). I would love to contribute to IREE as well.

Is this issue still something that needs to be picked up?

I had a quick look and got the gist of what needs to be done. If it's still fine to work on that, I will inspect it in more detail.

benvanik commented 3 years ago

Yeah! It'd be useful to sketch out the full pipeline of MHLO|TOSA -> linalg-on-tensors -> IREE's flow and see where this might be upstreamed (if at all). Brainstorming a bit:


Since we are migrating to linalg-on-tensors as our input dialect over the next few months it's possible that something could be done there which would let such an analysis/pass happen after MHLO/TOSA/etc such that it can be shared across all of them. This may also help solve a lot of the corner cases and avoid needing to specialize things to any other input dialect - such as handling broadcasting behavior. Instead of having specific patterns for converting mhlo.add to std.addi, for example, we let the existing patterns for linalg do that and instead perform the simplification on linalg.generic/indexed_generic ops (and those may literally be the only two ops we deal with, which is why it would be so nice!).

The major work that then happens could be a parameterized analysis pass that goes through the IR and determines which value chains should be detensored. The pass would look for values that flow between other tensor ops vs. those that are extracted out into scalars, and maybe in the future take parameters like "limit detensoring to ops producing a tensor of <= 4/16/32 elements" (so we could pull out a vector<4xi32>, for example). If a value is determined to be detensored, the conversion should then be fairly straightforward and may just require adding a helper method in linalg that takes the body of a linalg.generic/indexed_generic and clones it out.

For example:

%A = ...
%B = ...
%tensor = linalg.generic ins(%A, %B : tensor<i32>, tensor<i32>) -> tensor<i32> {
 ^bb0(%a: i32, %b: i32) :
  %c = addi %a, %b: i32
  linalg.yield %c : i32
}
%scalar = extract_element %tensor[] : tensor<i32>

The analysis would look at the extract_element's source %tensor, note that it was a candidate for detensoring (single element, etc), and then clone out the op:

%A = ...
%B = ...
%a = extract_element %A[] : tensor<i32>
%b = extract_element %B[] : tensor<i32>
%scalar = addi %a, %b: i32

The nice thing here is that if this were applied greedily, then whenever whatever was producing %A or %B could also be detensored, the next step would handle that too, all the way up the chain. If the transformation was done by cloning, then any other users of the value that couldn't be detensored would still have the original op to use (e.g. if %tensor was used by other actual tensor operations).

This may end up just being some canonicalizations on linalg ops looking for this kind of thing - then the greedy application is taken care of by MLIR's canonicalization pass. For example, a pattern on linalg.generic that checks whether any user of the result is an extract_element and, if so, clones the body out and updates the uses might be the entirety of it! If more control is desired, mlir::applyOpPatternsAndFold would let this happen outside of canonicalization (useful, for example, if we wanted to use analysis results).


An alternative to all of this is to do something specific to HLO - which has the advantage of working in IREE today, but is more difficult since each of the dozens of MHLO ops would need to be supported. You'd still be performing the same analysis, but then instead of touching linalg ops you'd be mapping things like mhlo.add : tensor<i32> -> std.addi : i32 directly. It could also be hacked together under integrations/tensorflow/compiler/ before seeing how it may be upstreamed into tensorflow's repo.
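
As a rough, hand-written illustration of that direct mapping (not an existing pattern; names are made up), for a scalar-tensor op whose operands have already been detensored:

  // before: scalar math wrapped in rank-0 tensors
  %sum = mhlo.add %lhs, %rhs : tensor<i32>
  // after: mapped directly to the std op on primitive values
  %sum = addi %lhs, %rhs : i32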

This may even be useful as an experiment before doing something like the above linalg approach as it'll help identify where the transformation is useful and act as a verification for whether the linalg-based one is performing well. Iterating in the simpler scope here on what the analysis pass looks like before then working to upstream and apply it to linalg may help to get the design in a good state for broader review on MLIR's discourse.

Writing MHLO->std patterns in addition to the ones already being written for converting to linalg-on-tensors feels like a lot of work and code to maintain, though. So even if it takes a little time for IREE's linalg-on-tensors support to come online, the linalg-based approach could still help others looking to use linalg-on-tensors today and then be extremely reusable for all MLIR users in the future (even if they are not using IREE).

Happy to chat more about this on discord or here!

ergawy commented 3 years ago

Thanks a lot for the detailed response. This all sounds really exciting and I am eager to contribute.

I am a newbie so please excuse any possible misunderstanding/not-really-getting-it.

From your description, it seems that the more immediate course of action would be to do detensoring exclusively on HLO (the 2nd alternative above). This might not be as general and useful as doing it for linalg-on-tensors, but it would serve as a good experiment that can later guide a similar effort within linalg-on-tensors.

From my side, I don't mind working on something specific to HLO that might be reverted later and replaced with a more general/useful implementation. In any case, it will be a wonderful learning experience and I can help with "porting" those optimizations to linalg on tensors later on.

Based on that, my naive understanding is that the first concrete steps would be the following:

Does that make sense to you?

benvanik commented 3 years ago

Makes sense! Starting small with simple cases likely to be encountered from normal tensorflow code will help prevent the need for an exhaustive MHLO pattern mapping. For example, using the simple test linked in the description means you only need to support compare and add ops. If you take the tblgen approach like you mention then the initial work may mostly be setting up the new td files and build system goo :)

A larger, more realistic example would be https://github.com/google/iree/blob/473186214a19447932b02bc78d9cf47a332b2dc3/iree%2Ftest%2Fe2e%2Fmodels%2Funidirectional_lstm.mlir#L100 - probably a good goal for initial proof of concept as the source tf.while ends up as that mhlo.while and supporting that case will enable almost anything using tf.while to work.

A larger representative example will also help flush out any issues that may be discovered: I suspect there are cases where we'll need to carry along information through the CFG and getting some concrete examples will help figure out next steps. A particular gotcha will be cases where there's no unidirectional flow (so not host -> device or device -> host but instead device -> host -> device or some other kind of sandwich). Crossing that bridge when we come to it is fine and anything you were able to get together would be really valuable in informing that follow-on work!

antiagainst commented 3 years ago

Oh nice to see you here, @KareemErgawy-TomTom! :D Just wanted to say hi and thanks for considering contributing to IREE!

ergawy commented 3 years ago

Update:

Apologies for taking almost a month to provide an update. IREE is a bit overwhelming and it took me some time to find my way around and get familiar with the code base a bit more.

I spent the past few weeks:

With that, I believe a somewhat good starting point would be:

I still need to work out the details of how the pass will actually be implemented, but at least this is what I believe an MHLO-specific solution might look like.

Let me know if you have any comments and hopefully the next update would be much quicker.

silvasean commented 3 years ago

My gut is that there are two parts to this:

  1. The actual conversion of all HLOs to corresponding scalar code (e.g. hlo.add -> std.addi or std.addf)
  2. The policy of which HLOs we should convert.

For 1., we are moving to a world of linalg-on-tensors, so you should make this pass exclusively work on linalg on tensors (not on HLO). One benefit of this is that the payload of a rank-0 linalg-on-tensors op will already contain ops like std.addi/addf, so you won't need to do any conversion there! You will literally just be inlining the linalg op payloads!

For 2., my intuition is an algorithm that first identifies connected components of the use-def graph which consist exclusively of rank-0 tensor code, and then a policy that determines for each connected component whether it makes sense to detensorize. (for the purpose of connected components, a block argument is considered connected to its corresponding successor arg in the predecessor blocks).

ergawy commented 3 years ago

Thanks @silvasean for the comment. I went through the available docs and talks that I could find to understand the current status of linalg-on-tensors.

I have a few questions:

  1. This might be trivial to you guys, but what would the syntax for a rank-0 linalg-on-tensors op (say linalg.generic) look like? I tried to find an example but couldn't find any, and I tried to modify some of the available examples to work on rank-0 tensors but didn't manage to get anything to compile.
  2. Is there any flow currently implemented anywhere (IREE or anything else) that implements some kind of lowering that passes through linalg-on-tensors? As far as I understand, IREE hasn't migrated to linalg-on-tensors yet. This is not really crucial and I can start prototyping with any snippet that has rank-0 ops, but I'm just curious.
  3. Is linalg-on-tensors the better place (as opposed to HLO) to prototype the detensoring logic at the moment? I mean in terms of maturity and completeness of implementation.
  4. This might be a totally dumb question, but if the end goal is to convert linalg (or HLO) ops to std ops whenever possible/beneficial (i.e. as a start, when such ops work on rank-0 tensors), what difference does it make if we work with linalg-on-tensors vs. linalg-on-buffers? I mean, wouldn't an analysis on memrefs also work? I didn't research the answer to this question very thoroughly, so I might be missing something very obvious here.

Edit: found the answer to the 4th question here: https://llvm.discourse.group/t/an-update-on-linalg-on-tensors/1878

silvasean commented 3 years ago

  1. Example:
#map = affine_map<() -> ()>
module  {
  func @tensor(%arg0: tensor<f32>, %arg1: tensor<f32>) attributes {iree.module.export} {
    %0 = linalg.init_tensor [] : tensor<f32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
      %2 = addf %arg2, %arg3 : f32
      linalg.yield %2 : f32
    } -> tensor<f32>
    return
  }
}
  2. @MaheshRavishankar @nicolasvasilache can give you pointers. There is enough working for you to start this work I think.

  3. For prototyping feel free to do it wherever you want. HLO might be slightly easier if you want to try to slot this into our existing pipelines for now (the bulk of the algorithm will be the same -- converting it to linalg-on-tensors will be a simplification). But we want this pass to work well for TOSA (and other frontends) and the way to do that is with linalg-on-tensors.

  4. At the place in the pipeline that we need to run this, there is only linalg on tensors. The analysis on memrefs would also be much more difficult.

ergawy commented 3 years ago

Update:

It would be great if you could have a look and provide any feedback that can push this further.

ScottTodd commented 3 years ago

Detensoring is becoming a priority for performance optimization work in IREE, so we're taking a closer look at integrating the upstream work into IREE (and possibly extending it as needed) now.

ergawy commented 3 years ago

Sorry for taking so much time on integrating detensoring. I just wrote a message to Sean on Discourse a few hours ago saying that I tried to follow the changes in IREE, but since I don't have enough time I couldn't keep up with the fast pace of development.

I still would love to give this a try if someone from IREE can help me. I can provide more details about the current detensoring implementation in linalg-on-tensors and patch it whenever needed, and whoever can help me from the IREE team (maybe @ScottTodd) can give pointers on where to start.

If you have the time and energy to collaborate with me on this, it would be great.

ScottTodd commented 3 years ago

I'm just getting up to speed on this now, but here's what I'm trying:

It looks like detensoring is working how we'd want (at least for this case), but we have some IR that needs further folding to lower through the rest of our pipeline. One case from that gist is here:

  %6 = zexti %5 : i1 to i8
  %7 = tensor.from_elements %6 : tensor<1xi8>
  %8 = linalg.tensor_collapse_shape %7 [] : tensor<1xi8> into tensor<i8>
  %9 = flow.tensor.load %8 : tensor<i8>

IR like that is surviving all the way down to something like this:

  ^bb3(%13: i32):  // pred: ^bb1
    %14 = tensor.from_elements %13 : tensor<1xi32>
    %15 = flow.ex.stream.fragment(%14) : (tensor<1xi32>) -> tensor<i32> =
        (%arg2: tensor<1xi32>) -> tensor<i32> {
      %17 = flow.tensor.reshape %arg2 : tensor<1xi32> -> tensor<i32>
      flow.return %17 : tensor<i32>
    }
    %16 = hal.tensor.cast %15 : tensor<i32> -> !hal.buffer_view
    return %16 : !hal.buffer_view
  }

We shouldn't need to go i32 -> 1xi32 -> i32, and that tensor.from_elements should be turned into some flow.* op (or just be folded away).

benvanik commented 3 years ago

RE tensor.from_elements - it looks like I added a flow.tensor.splat that we could lower this to, then our flow.tensor.load canonicalization can just look for flow.tensor.splat and pull out the value (better than mixing dialects, and the flow.tensor.* ops are dynamic shape aware). flow.tensor.splat doesn't do anything yet (it should be lowered into a hal.command_buffer.fill in flow->hal conversion) but to start if we fold them all away it's fine to leave that TODO.
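
A hand-written sketch of how that would play out on the IR above (assuming the linalg.tensor_collapse_shape also folds away):

  // tensor.from_elements lowered to the shape-aware flow op:
  %8 = flow.tensor.splat %6 : tensor<i8>
  %9 = flow.tensor.load %8 : tensor<i8>
  // a flow.tensor.load canonicalization that recognizes flow.tensor.splat can
  // then forward %6 directly, so both ops disappear and no tensor is ever
  // materialized.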

benvanik commented 3 years ago

(I don't know if I like where it's located - this should be under Dialect/Flow/Conversions/TensorToFlow/ - but https://github.com/google/iree/blob/cd247ca50cfeeb3ff0800aa7190bbeed94b1b70d/iree/compiler/InputConversion/Common/ConvertUpstreamToIREE.cpp is where the tensor.from_elements -> flow.tensor.splat conversion would live)

ergawy commented 3 years ago

I can work on these 2 changes that will clean up the detensored code in the above example:

Unless someone else did that, I can do it quickly.

ScottTodd commented 3 years ago

I did start on those at https://github.com/ScottTodd/iree/commits/detensorize , and I have some IR notes in this gist. Feel free to work on those patterns, starting from my changes or from scratch :)

I can try finding/making a test model that will show the performance benefits from detensoring, since our open source coverage is a bit sparse in that area.

MaheshRavishankar commented 3 years ago

From discussion on Discord, it seems like there are some linalg.generic detensorizations missing as well?

https://cdn.discordapp.com/attachments/689906000043573354/867108835838656522/unknown.png

benvanik commented 3 years ago

The linalg.tensor_collapse_shape folding would be a good one for upstream.

ScottTodd commented 3 years ago

From discussion on Discord, it seems like there are some linalg.generic detensorizations missing as well?

https://cdn.discordapp.com/attachments/689906000043573354/867108835838656522/unknown.png

Here's a more controlled repro of that: https://gist.github.com/ScottTodd/e6b035b658a4b0e05a9336dba1dc8452

ergawy commented 3 years ago

@MaheshRavishankar The behavior in the example you posted is due to the specific cost model used by default. Right now, we have 2 models:

  • PureControlFlowDetectionModel (used by default): basically looks for detensorable ops involved in conditional branches and from that tries to discover loops and if-conditions. This is why the 2nd linalg.generic in the picture you posted is detensored while the others aren't.
  • AggressiveDetensoringModel (used only in test code at the moment): looks for all detensorable ops (regardless of their involvement in loops or conditions) and detensors them. I guess this model would have detected the other 2 linalg.generic ops in the example.

I think we can either:

  • Extend the first model to detensor ops whose results are used in branches.
  • Add a new model that detensors ops whose results are used in branches and combine it with the first model above.

Both are equivalent approaches. But I am not sure whether there is an argument against the idea in general. WDYT?

MaheshRavishankar commented 3 years ago

@MaheshRavishankar The behavior in the example you posted is due to the specific cost model used by default. Right now, we have 2 models:

  • PureControlFlowDetectionModel (used by default): basically looks for detensorable ops involved in conditional branches and from that tries to discover loops and if-conditions. This is why the 2nd linalg.generic in the picture you posted is detensored while the others aren't.
  • AggressiveDetensoringModel (used only in test code at the moment): looks for all detensorable ops (regardless of their involvement in loops or conditions) and detensors them. I guess this model would have detected the other 2 linalg.generic ops in the example.

I think we can either:

  • Extend the first model to detensor ops whose results are used in branches.
  • Add a new model that detensors ops whose results are used in branches and combine it with the first model above.

Both are equivalent approaches. WDYT?

It's cool that you have these models! That should be easy to control then. I think extending the first one to detensor ops whose results are used in branches makes sense to me right now.

benvanik commented 3 years ago

I think extending the first one to detensor ops whose results are used in branches makes sense to me right now.

Agreed! The biggest impact this will have to start is preventing readbacks in loops, so it'd be great to see that handled.

ScottTodd commented 3 years ago

The program in control_flow_test.py (https://github.com/google/iree/blob/main/integrations/tensorflow/e2e/control_flow_test.py) also shows how this will help: https://gist.github.com/ScottTodd/34523eb5850a0884c2f3c2e342ce69e3

[image: https://user-images.githubusercontent.com/4010439/126543843-24a9358e-62e3-430c-bb64-2c2d2837a55e.png]

stellaraccident commented 3 years ago

Nice. Some further work on 1d-1 element vs 0d should take care of the last bits.

ScottTodd commented 3 years ago

Continuing to make progress here. With the refactoring in https://github.com/google/iree/pull/6586 and passes in https://github.com/google/iree/blob/main/iree/compiler/Dialect/Flow/Transforms/Passes.cpp like:

  passManager.addNestedPass<FuncOp>(createFusionOfTensorOpsPass());
  passManager.addNestedPass<FuncOp>(mlir::createCSEPass());
+ // NOTE: must run flow->tensor ops and canonicalization after detensorizing.
+ passManager.addNestedPass<FuncOp>(mlir::createLinalgDetensorizePass());
  passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
  passManager.addNestedPass<FuncOp>(
      IREE::Flow::createConvertToFlowTensorOpsPass(
          /*runBeforeDispatchRegionFormation=*/true));
+ passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
  passManager.addNestedPass<FuncOp>(
      IREE::Flow::createDispatchLinalgOnTensorsPass());

Test model status:


The Collatz model ends up with IR like this (full IR dump here):

// -----// IR Dump After LinalgDetensorize //----- //
...
^bb5(%28: f32):  // pred: ^bb1
  %29 = tensor.from_elements %28 : tensor<1xf32>
  %30 = linalg.tensor_collapse_shape %29 [] : tensor<1xf32> into tensor<f32>
  %31 = hal.tensor.cast %30 : tensor<f32> -> !hal.buffer_view
  return %31 : !hal.buffer_view
}
...
// -----// IR Dump After ConvertToFlowTensorOps //----- //
...
// -----// IR Dump After Canonicalizer //----- //
...
^bb5(%18: f32):  // pred: ^bb1
  %19 = flow.tensor.splat %18 : tensor<f32>
  %20 = hal.tensor.cast %19 : tensor<f32> -> !hal.buffer_view
  return %20 : !hal.buffer_view
}
...
// from flow.tensor.splat -> hal.command_buffer.fill_buffer lowering
error: 'hal.command_buffer.fill_buffer' op operand #4 must be 32-bit signless integer, but got 'f32'
// -----// IR Dump After mlir::iree_compiler::IREE::HAL::`anonymous-namespace'::ConvertToHALPass Failed //----- //

Do we need some different way to get from the f32 scalar to a hal.buffer_view, or should that not be detensored? (I can think of a few ways to get around this, but I'm wondering what would make the most sense - @benvanik ?)


Here's some IR from the TOSA if test before/after: https://gist.github.com/ScottTodd/24aea1903bebe41dc74af234ee325b68. I don't yet understand why the %3 = flow.tensor.load %2 : tensor<i8> is not getting turned into a %0 = hal.buffer.load<%buffer_1 : !hal.buffer>[%off] : i8 like it is in the baseline.

benvanik commented 3 years ago

For the f32 fill we should insert a bitcast - which I don't think exists upstream or in the VM yet. We can allow the hal op to take any 32-bit value (or anything <= 4 bytes) and then insert the cast when converting hal->vm. I had thought I needed this for something else and added it, but maybe that was in some long-lost branch.

ScottTodd commented 3 years ago

For the f32 fill we should insert a bitcast - which I don't think exists upstream or in the VM yet. We can allow the hal op to take any 32-bit value (or anything <= 4 bytes) and then insert the cast when converting hal->vm. I had thought I needed this for something else and added it, but maybe that was in some long-lost branch.

bitcast could be coming to upstream (standard ops) with @GMNGeoffrey 's https://llvm.discourse.group/t/rfc-introduce-a-bitcast-op/3774 + https://reviews.llvm.org/D105376

benvanik commented 3 years ago

Ah maybe that's what I was remembering - needing it, seeing that work getting done, but then it not landing and me getting distracted and forgetting why I needed it :P

ScottTodd commented 3 years ago

I'll plan on implementing that f32 fill via a bitcast. May have some more specific questions as I get into the details :)

ergawy commented 3 years ago

tosa_ops/if.mlir fails (and is not fully detensored, see the cost model discussions above)

I am working on patching the cost model to account for that situation.

ergawy commented 3 years ago

Here is a patch for this: https://reviews.llvm.org/D107358.

ScottTodd commented 3 years ago

Here is a patch for this: https://reviews.llvm.org/D107358.

Nice! The TOSA if program looks like this with that commit, after detensoring and cleanup:

func @if_true_test(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %c10_i32 = constant 10 : i32
  %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<i1>
  %1 = hal.tensor.cast %arg1 : !hal.buffer_view -> tensor<i32>
  %2 = tensor.extract %0[] : tensor<i1>
  %3 = tensor.extract %1[] : tensor<i32>
  cond_br %2, ^bb1, ^bb2(%3 : i32)
^bb1:  // pred: ^bb0
  %4 = tensor.extract %1[] : tensor<i32>
  %5 = addi %4, %c10_i32 : i32
  br ^bb2(%5 : i32)
^bb2(%6: i32):  // 2 preds: ^bb0, ^bb1
  %7 = flow.tensor.splat %6 : tensor<i32>
  %8 = hal.tensor.cast %7 : tensor<i32> -> !hal.buffer_view
  return %8 : !hal.buffer_view
}

(it still fails to compile completely, with logs similar to my comment above, but this does simplify things further :D)

benvanik commented 3 years ago

do you have CSE running after that? I'd expect %4 to be deduped with %3

ScottTodd commented 3 years ago

do you have CSE running after that? I'd expect %4 to be deduped with %3

Not immediately, but that does get picked up by a later run: https://github.com/google/iree/blob/c3ea7a879abfa86aa7c0994afba2b19d7354a459/iree/compiler/Dialect/Flow/Transforms/Passes.cpp#L180-L183

benvanik commented 3 years ago

Good to look at the IR after that point - otherwise, like above, it looks like there are still tensor ops in the inner loop :) If you are running this via iree-opt you can build your own pipeline: -pass-pipeline=iree-flow-blah-blah,canonicalize,cse

ScottTodd commented 3 years ago

I'm noticing that after detensoring and using the bitcast op, all of the computation in the collatz model was moved to the VM (IR dumps here), which then requires the -iree-vm-target-extension=f32 feature (this is expected, since the VM ops I'm adding are using the f32 and f64 extensions). There is no clear error message for that when the extension is not enabled, just a segfault here: https://github.com/google/iree/blob/36d3efcb3fb4fda86c4cbb947ebf17c8ef22131f/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp#L156-L157 after https://github.com/google/iree/blob/36d3efcb3fb4fda86c4cbb947ebf17c8ef22131f/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp#L84-L90

Is that (requiring the f32 extension when detensoring is used) acceptable, or should those ops/types be converted using emulation in the VM?

Edit: Maybe we'd want to limit detensoring of f32 tensors when the extension is not enabled? Or otherwise guide users through setting compiler flags based on their runtime requirements (to avoid performance cliffs and tricky flag experimentation).

benvanik commented 3 years ago

We should fix it so that it emits a proper error :P We build f32 into the runtime by default so we should probably also enable the extension in the compiler by default. We'll want some modes for doing things like disabling f32 detensoring for devices that don't support it on the host, though I feel like the best route there is to not use models that use floats if your target hardware doesn't support them and you want tiny builds (no soft float).

ScottTodd commented 3 years ago

I've tested with a few different models coming from TensorFlow and TOSA, and we can now correctly compile them with detensoring enabled.

The programs that we have in open source right now don't really use control flow or other features that would show the benefits of this clearly, so I hacked around in Colab a bit: https://colab.research.google.com/gist/ScottTodd/08656d9c68b023260b42fa5a907fc340/testfordetensoring.ipynb

This Python code:

import tensorflow as tf

class FlowControlModule(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec([256, 256], tf.int32), tf.TensorSpec([], tf.int32)])
  def flow_control(self, values, iterations):
    result = 0
    for _ in range(0, iterations):
      values = tf.linalg.matmul(values, values)
      sum = tf.math.reduce_sum(values)
      if sum % 2 == 0:
        result += sum
      else:
        result -= sum
    return result

produces this MLIR, which benefits from detensoring, but only when the "aggressive mode" option is set (is that working as intended? should we set it to true for all of our programs?). Here is the IR before/after detensoring to see what LinalgDetensorize had to work with.

Here's a runtime comparison of that program on the CPU (running 10 iterations of the loop, with 0s for the input matrix): image

and with Vulkan: image

benvanik commented 3 years ago

I'm confused - why is ext so much slower?

ScottTodd commented 3 years ago

I'm confused - why is ext so much slower?

Hmmm... not sure - the overall time is improved though. Probably has to do with the "normalized values" checkbox. Here are the Vulkan traces if you want to look: https://drive.google.com/file/d/1TaZ4vY6Qiq6WY6Bct_myfDHlLo7rzn4_/view?usp=sharing

benvanik commented 3 years ago

Oh, maybe because you are measuring submit_and_wait, which runs outside of the main execution - measuring "iree_vm_invoke" would be better - it goes from 74 to 60 in the Vulkan case.

benvanik commented 3 years ago

(also a good idea to not run with system tracing when trying to get accurate numbers - the system tracing adds a lot of overhead/variance as it has some threads at 100% and such)

silvasean commented 3 years ago

I would expect the default mode to work on that program. If not, it's probably just a bug / feature request.

ergawy commented 3 years ago

As @silvasean mentioned, the default mode should have worked here, so from a quick look at the code it's most probably a bug. Will take a look...

ergawy commented 3 years ago

I finally had the time to look into this, sorry I am in the process of changing jobs and moving from Berlin to Munich at the moment.

The issue is in detecting the control flow ops for the if condition. Here is what happens:

I think it's time to make it a bit more flexible; the examples we encountered so far didn't need such flexibility. Will look into that, and again sorry for the slowness here - life is very busy at the moment.

ergawy commented 3 years ago

I have a potential fix for the issue. It is more a removal of a self-imposed restriction in the code than a fix. You can find the relevant commit here: https://github.com/ergawy/llvm-project/commit/0dbd1e8c28557be3b4d673120caa1f6c470a6ff3.

With this change, the only remaining linalg.generic op after detensoring is the reduction op, as expected (and of course the other regression tests pass, except for one that was written specifically to test that detensoring fails with the above-mentioned restriction).

You can find the output of the previously failed example here: https://gist.github.com/ergawy/b528f1db884b59ec60d2d7ebac69776f.

@ScottTodd I think it makes sense to test with the above fix since the amount of changed code is tiny and then if things work smoothly, I will open a PR to MLIR. WDYT?

ScottTodd commented 3 years ago

SGTM, thanks @ergawy ! (and no worries about being busy, the help is appreciated!)

I ran a build of IREE with detensoring enabled and that change, and the affected tests still build+pass.

I spent the last few weeks trying to find open source models with interesting control flow and couldn't actually import any of the models that I looked at :/. I think for now we'll have to live with the existing small-scale programs for primary correctness + performance coverage.

ergawy commented 3 years ago

I started a patch to address the above issue: https://reviews.llvm.org/D109965.