iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Improve code size by folding identical dispatches #9895

Open dcaballe opened 2 years ago

dcaballe commented 2 years ago

Request description

Opening this mostly for discussion. I'm seeing quite a few dispatches in mobilebert-quant that are identical except for the constants consumed by tosa.apply_scale. They look like this:

        %3 = linalg.generic {indexing_maps = [#map9, #map10, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0, %1 : tensor<4x1x384x384xi8>, tensor<384x1x384xi8>) outs(%2 : tensor<4x384x1x384xf32>) {                                                                                                        
        ^bb0(%arg3: i8, %arg4: i8, %arg5: f32):                                                                                                                         
          %4 = arith.extsi %arg4 : i8 to i32                                                                                                                            
          %5 = arith.subi %4, %c127_i32 : i32                                                                                                                           
          %6 = "tosa.apply_scale"(%5, %c1073741824_i32, %c11_i8) {double_round = false} : (i32, i32, i8) -> i32                                                         
          %7 = arith.cmpi slt, %6, %c-2147483648_i32 : i32                                                                                                              
          %8 = arith.select %7, %c-2147483648_i32, %6 : i32                                                                                                             
          %9 = arith.cmpi slt, %c2147483647_i32, %6 : i32                                                                                                               
          %10 = arith.select %9, %c2147483647_i32, %8 : i32                                                                                                             
          %11 = arith.extsi %arg3 : i8 to i32                                                                                                                           
          %12 = arith.subi %11, %c36_i32 : i32                                                                                                                          
          %13 = "tosa.apply_scale"(%12, %c1289686363_i32, %c12_i8) {double_round = false} : (i32, i32, i8) -> i32                                                       
          %14 = arith.cmpi slt, %13, %c-2147483648_i32 : i32                                                                                                            
          %15 = arith.select %14, %c-2147483648_i32, %13 : i32                                                                                                          
          %16 = arith.cmpi slt, %c2147483647_i32, %13 : i32                                                                                                             
          %17 = arith.select %16, %c2147483647_i32, %15 : i32                                                                                                           
          %18 = arith.addi %17, %10 : i32                                                                                                                               
          %19 = "tosa.apply_scale"(%18, %c1341710240_i32, %c50_i8) {double_round = true} : (i32, i32, i8) -> i32                                                        
          %20 = arith.addi %19, %c93_i32 : i32                                                                                                                          
          %21 = arith.cmpi slt, %20, %c-128_i32 : i32                                                                                                                   
          %22 = arith.select %21, %c-128_i32, %20 : i32                                                                                                                 
          %23 = arith.cmpi slt, %c127_i32, %20 : i32                                                                                                                    
          %24 = arith.select %23, %c127_i32, %22 : i32                                                                                                                  
          %25 = arith.trunci %24 : i32 to i8                                                                                                                            
          %26 = arith.sitofp %25 : i8 to f32                                                                                                                            
          %27 = arith.subf %26, %cst : f32                                                                                                                              
          %28 = arith.mulf %27, %cst_0 : f32                                                                                                                            
          %29 = math.exp %28 : f32                                                                                                                                      
          linalg.yield %29 : f32                                                                                                                                        
        } -> tensor<4x384x1x384xf32>

Not sure if keeping all these instances of the same dispatch makes sense. Perhaps keeping the constants as constants enables some constant folding at the LLVM level... but this looks like a lot of duplicated code. I wonder if it would make sense to do some CSE at the dispatch level. WDYT?
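For readers unfamiliar with the op, here is a minimal Python sketch of the scalar arithmetic. It follows my reading of the `apply_scale_32` pseudocode in the TOSA specification (an assumption, not IREE code) and omits the spec's range checks. The helper names `apply_scale` and `rescale_operand` are hypothetical. It illustrates that the two operand-rescaling chains in the dispatch above are the same function applied to different (zero point, multiplier, shift) constants.

```python
def apply_scale(value: int, multiplier: int, shift: int,
                double_round: bool = False) -> int:
    """Fixed-point multiply: roughly value * multiplier / 2**shift, rounded."""
    round_term = 1 << (shift - 1)
    if double_round and shift > 31:
        # Extra rounding term that only applies to large shifts.
        round_term += (1 << 30) if value >= 0 else -(1 << 30)
    # Python's >> on negative ints is an arithmetic (flooring) shift.
    return (value * multiplier + round_term) >> shift


def rescale_operand(x: int, zero_point: int, multiplier: int, shift: int) -> int:
    """One arith.subi + tosa.apply_scale chain with the constants as parameters."""
    return apply_scale(x - zero_point, multiplier, shift)


# The two operand chains in the dispatch above differ only in these constants.
a = rescale_operand(100, 127, 1073741824, 11)
b = rescale_operand(100, 36, 1289686363, 12)
```

Seen this way, each "identical except for the constants" dispatch is one body plus a small tuple of scalars, which is what makes hoisting those scalars into dispatch arguments attractive.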

What component(s) does this issue relate to?

Compiler


benvanik commented 2 years ago

Yep! We've got a few issues tracking this (usually named something like "deduplicating dispatch executables"). #1144 is an ancient one that points out some more examples of duplication. It's a big problem that would be great to have some eyes and hands on.

Brain-dump:

The primary mechanism we have today is a single pass of deduplication after we outline the flow.executables (DeduplicateExecutables.cpp). It catches quite a bit, but it only checks the IR for an exact match. The above IR is a case where an exact match is insufficient, and it happens a lot in quantized models due to these offsets/biases/etc. We also have issues where, because the dispatches are often operating on statically-shaped tensors, we end up with one executable per unique shape, and in conv-heavy models that happens on every layer as they are usually a size cascade. We're often able to cut 20-60% of the dispatches but in most models we should be able to cut 80-90% with a small amount of work.
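To make the gap concrete, here is a toy Python sketch (not the logic in DeduplicateExecutables.cpp) that compares exact-match deduplication with deduplication after constants are hoisted into arguments. Dispatch bodies are plain strings, and `dedup_exact`, `hoist_constants`, and `dedup_after_hoisting` are hypothetical helpers. Hoisting every constant is a deliberately naive assumption.

```python
import re
from collections import defaultdict

# Toy stand-ins for outlined dispatch bodies; real executables are MLIR modules.
DISPATCHES = {
    "dispatch_0": "%5 = arith.subi %4, %c127_i32\n"
                  "%6 = apply_scale(%5, %c1073741824_i32, %c11_i8)",
    "dispatch_1": "%5 = arith.subi %4, %c36_i32\n"
                  "%6 = apply_scale(%5, %c1289686363_i32, %c12_i8)",
    "dispatch_2": "%5 = arith.subi %4, %c127_i32\n"
                  "%6 = apply_scale(%5, %c1073741824_i32, %c11_i8)",
}

# Matches constant names such as %c127_i32 or %c-128_i32.
CONSTANT = re.compile(r"%c(-?\d+)_(i\d+)")


def dedup_exact(dispatches):
    """Group dispatches whose bodies are textually identical."""
    groups = defaultdict(list)
    for name, body in dispatches.items():
        groups[body].append(name)
    return list(groups.values())


def hoist_constants(body):
    """Replace each constant with a numbered argument; return (template, values)."""
    values = []

    def to_argument(match):
        values.append(int(match.group(1)))
        return f"%arg{len(values) - 1}_{match.group(2)}"

    return CONSTANT.sub(to_argument, body), tuple(values)


def dedup_after_hoisting(dispatches):
    """Group dispatches by template, keeping per-site operand values."""
    groups = defaultdict(list)
    for name, body in dispatches.items():
        template, operands = hoist_constants(body)
        groups[template].append((name, operands))
    return list(groups.values())


print(len(dedup_exact(DISPATCHES)), "executables with exact matching")
print(len(dedup_after_hoisting(DISPATCHES)), "executable after hoisting")
```

Exact matching only merges `dispatch_0` and `dispatch_2`; after hoisting, all three share one template and differ only in the operands passed at each dispatch site.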

A general approach we'll want to take that touches several parts of the stack is to purposefully discard static information when it does not provide value. We want the maximum amount of beneficial information and the minimum amount of noise. This unfortunately requires a decision to be made as to when something is not worth knowing at a subsequent point in compilation and I think that's why there hasn't been much traction on it: it's a potential quagmire to perfection-focused engineers :)

One way we have of making that tractable is the potential value sets we populate on dispatches before final codegen: https://github.com/iree-org/iree/blob/59602cce17b87beea9e8353b6f4e765b1e968c8b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_dispatch_arguments.mlir#L13-L18 We annotate each dispatch argument with a set of all known values from all dispatch sites as well as known alignments (allowing us to know that a padded dimension is %64=0 even if the base value is dynamic and unknowable). These show in the IR on hal.interface.constant.load ops and it's up to each backend to use as they want (most ignore them today). Paired with that we also perform a late-stage inlining of dispatch arguments that have only a single potential value: so in cases where there's a single dispatch site or all dispatch sites pass the same value that'll end up as an arith.constant inside the dispatch prior to codegen lowering.
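A rough Python sketch of the idea, under stated assumptions: each dispatch site is a tuple of integer operands, alignment is approximated as the GCD of the observed values, and `annotate_arguments` is a hypothetical helper rather than the actual Stream pass.

```python
from functools import reduce
from math import gcd


def annotate_arguments(dispatch_sites):
    """Summarize each argument position across all dispatch sites."""
    annotations = []
    for values in zip(*dispatch_sites):
        potential = sorted(set(values))
        annotations.append({
            # Every value any dispatch site passes for this argument.
            "potential_values": potential,
            # Largest integer dividing all observed values (simplified alignment).
            "alignment": reduce(gcd, potential),
            # Only one value ever shows up, so it could become a constant again.
            "single_value": len(potential) == 1,
        })
    return annotations


# Hypothetical operands (size, shift, stride) at three dispatch sites.
sites = [(64, 11, 384), (128, 12, 384), (192, 11, 384)]
for index, info in enumerate(annotate_arguments(sites)):
    print(index, info)
```

In this toy input the first argument stays dynamic but is known to be a multiple of 64, while the third has a single potential value and is a candidate for the late-stage inlining described above.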

Conceptually this is like having a C++ template and then a set of all values for all instantiations of that template and folding them - today we're just baking out each instantiation independently but could be doing nearly anything and getting meaningful improvements. What this means is that we can have more dynamism in the dispatch region we form in order to get better deduplication while still providing codegen with the information we removed (potential values/alignment).

I think the major work is some larger changes for hoisting constants and making static shapes dynamic such that our IR equality deduplication triggers. I'm not sure anyone is actively looking at that (but it'd be really good to do so!). We don't need 60 bespoke convs but maybe ~4, and the 4 we decide to generate could be derived from the alignment/potential values in most cases ("we have sizes <16, <32, <64, and <512, so let's generate one for each size category"). Since we gather alignments where possible we can do this even in dynamic cases and emit the code to switch at runtime - a single dispatch function with a switch statement into 4 code paths is better than 60 nearly-identical (and often-times perfectly identical after lowering to machine code) functions or 1 sub-par one.
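The size-category idea can be sketched in a few lines of Python. The thresholds come from the example in the paragraph above; the variant names, the layer sizes, and the helpers `size_category` and `run_dispatch` are hypothetical, and the runtime switch is modeled as a tuple lookup.

```python
from bisect import bisect_right
from collections import Counter

# Exclusive upper bounds taken from the example categories above.
UPPER_BOUNDS = (16, 32, 64, 512)
VARIANTS = ("variant_lt16", "variant_lt32", "variant_lt64", "variant_lt512")


def size_category(size: int) -> int:
    """Return the index of the first category whose upper bound exceeds size."""
    index = bisect_right(UPPER_BOUNDS, size)
    if index == len(UPPER_BOUNDS):
        raise ValueError(f"size {size} is outside every category")
    return index


def run_dispatch(size: int) -> str:
    """Stand-in for the runtime switch inside a single dispatch function."""
    return VARIANTS[size_category(size)]


# Hypothetical per-layer sizes in a conv-heavy model.
layer_sizes = [8, 12, 16, 24, 48, 56, 96, 128, 256, 384]
variants_used = Counter(run_dispatch(size) for size in layer_sizes)
print(variants_used)
```

Ten distinct static sizes collapse onto four code paths here, which is the kind of reduction the "60 convs down to ~4" estimate is pointing at.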

Perfect is the enemy of good here and we're currently abysmal so I'm in full support of people chipping away at these things as they find opportunities. A way to reason about this is like lossless compression: we want to produce the same results with as little code as possible and are willing to take small losses in performance to get large gains in deduplication in all but the exceptional cases we don't focus on (HPC): going from 500KB->50KB of generated code for a 1% performance loss is favorable to the project and we can always have an "inline-and-unroll-and-expand-everything" option if needed. A pass that cloned and specialized dispatch functions for each dispatch site is easy to implement and would preserve our current behavior (every dispatch function is a snowflake) while making the much harder-to-solve problem (reducing dispatch function count) easier to solve. It'd also have the benefit of improving compilation time: the fewer dispatch functions that need to be lowered the better!

dcaballe commented 2 years ago

Thanks for the detailed description, Ben! Super interesting! I totally agree with the points that you made. The decision to discard static information is something that we could make on an op-by-op basis. We should get a clear picture of where discarding the information is a no-go with basic micro-benchmarking.

Regarding the multi-version approach that you describe, I've seen it working very well in the past, esp. when each version is aware of not only its upper bound but also its lower bound sizes (e.g., version targeting <512 won't deal with sizes <64).

There is also an outliner in LLVM, but I have the impression that it would be too late to do this kind of coarse-grained deduplication, and I'm not sure if it works across functions. Just for the record.

benvanik commented 2 years ago

The LLVM outliner may help on the CPU case (though I haven't seen it do anything - maybe there's a flag we can set?) but you're right that we want to do it early on: targets like SPIR-V and PTX can't use a technique like that as each dispatch function is compiled independently. There still may be a case for outlining, though, and we can also do that manually: each executable can have any number of entry points and each can call other functions so if we want to emit a function with multiple incoming calls having specialized arguments/nested in different loop structures/etc we could improve reuse on targets that did link everything together (just VMVX and LLVM CPU today, but possibly Metal in the future).

dcaballe commented 2 years ago

> (though I haven't seen it do anything - maybe there's a flag we can set?)

Yes, in IPO. It's not enabled by default: https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp#L97

benvanik commented 2 years ago

Ah yes MPM.add(createIROutlinerPass()); - it'd be useful to have someone who is familiar with the available passes (you? :) check what ones we're running - it looks like we're just doing the default compiler options for O3 and none of the linker passes: https://github.com/iree-org/iree/blob/2b91ddc36c84de648aa200bb82d543530d7b0626/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp#L113-L114

benvanik commented 2 years ago

(oh, and our linker options are here: https://github.com/google/iree/blob/main/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp#L154 - but if we can do anything in the bitcode before we pass it on to the linker that's going to be more robust/reliable)

dcaballe commented 2 years ago

Thanks! I can take a look and run a few tests. Unfortunately, though, I still have some gaps in my understanding of IREE's linking strategies. I need to spend some time digging into the embedded linker and the system linkers. In this regard, I wonder if it would make sense to enable LTO, if we don't have it already. I'm also not sure why we are defaulting to O3. O3 enables more aggressive heuristics, including for unrolling/inlining, etc., which usually ends up with larger code sizes and sometimes performance is not significantly better (only sometimes). We can give O2 a try. It would be great to know more about these historical decisions :)

benvanik commented 2 years ago

No one has ever really owned the options nor spent any time tuning them - for things relating to optimization consider what's there as throw-away and desperately in need of an owner :)

We only link a single .o today so much of what LTO in the linker could give us we could do ourselves when generating the .o by using the same passes. That ensures we have consistent behavior across embedded/system linkers: the system linker is toolchain-provided and not something we control. Basically, embedded linker is "lld we control and that is at the same version as the iree compiler for all platforms" and system linker is "whatever random binary can produce a dynamic library for a host system" (it could be ld, lld, link, Apple's forked stuff from Xcode, etc). When emitting static libraries we don't have any control over the linker - the user is responsible for linking in the object file - so nothing we rely on the linker for optimization can apply there. So the more we do before passing it off to the linker the better.

MaheshRavishankar commented 2 years ago

> Thanks! I can take a look and run a few tests. Unfortunately, though, I still have some gaps in my understanding of IREE's linking strategies. I need to spend some time digging into the embedded linker and the system linkers. In this regard, I wonder if it would make sense to enable LTO, if we don't have it already. I'm also not sure why we are defaulting to O3. O3 enables more aggressive heuristics, including for unrolling/inlining, etc., which usually ends up with larger code sizes and sometimes performance is not significantly better (only sometimes). We can give O2 a try. It would be great to know more about these historical decisions :)

Agreed! O2 seems reasonable if it helps balance some of the code size issues, even if the default performance drops a bit in some cases. I think most of the real optimizations we need kick in at O2.

dcaballe commented 2 years ago

(Is there a way to attach PRs to Issues without closing the Issues when the PRs are merged?)

MaheshRavishankar commented 2 years ago

Don't add the issue to the PR title. Just say Issue <issue-number> in the description field.

ScottTodd commented 1 year ago

I've been looking at this issue (similar but not identical dispatches) as the number of dispatches correlates directly with compile time and some of our CI workflows are bottlenecked on that time spent compiling.

mobilebert-float has 30 executables while mobilebert-quant has

Here's a sample IR dump: mobilebert-baseline-tf2-quant_noinlining_flow_2023_02_16.mlir. That was generated from the mobilebert-baseline-tf2-quant.tflite.mlir that we use in our performance benchmarking with these flags:

iree-compile mobilebert-baseline-tf2-quant.tflite.mlir \
  --iree-input-type=tosa \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvm-target-cpu=cascadelake \
  --iree-llvm-target-triple=x86_64-unknown-linux-gnu \
  --iree-flow-inline-constants-max-byte-length=0 \
  --compile-to=flow

(Transcribing from this discussion on Discord)

Different fusions

The dispatches in this gist are doing the same math, but with different arguments.

The dispatches in this gist follow the same pattern, but with the same number of arguments (i8 and i32 permuted?)

the first one is going to be tricky as that's different fusion, but microkernels will help there (the bulk of the code we generate today should be in the microkernels that will be shared)

I'd say the first one's solution is just "microkernels" so we can re-evaluate after that (we may still have lots of dispatches but they should be sharing the same microkernels)

Static size/offset/strides could be dynamic

There are three dispatches in this gist that just load then store at different static offsets. These don't contribute much to compile time, but they could still be folded.

the second one is static info that should be dynamic (we could take those offsets/sizes/strides and make them arguments instead of attributes)

for the second we should have an issue for despecializing those static offsets/sizes/strides (possibly we always force dynamic when constructing dispatch regions)

Same except for an extra subtraction

The dispatches in this gist look very similar - except one of them starts with an extra arith.subi op and has one extra argument for what to subtract.

last one should probably also get an issue cc'ed to benoit/mahesh about ensuring we aren't prematurely decomposing ops in a way that makes recovering the semantics too difficult