iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Dispatch region formation is capturing constants as read/write and mustn't. #12233

Open - benvanik opened this issue 1 year ago

benvanik commented 1 year ago

I've noticed that constants get captured as !flow.dispatch.tensor<readwrite...>. This has a lot of implications and currently only barely works due to a combination of other things we do (like very early constant inlining), and it may be hiding performance issues (we can't mark bindings as read-only when lowering to LLVM/SPIR-V/etc). This seems to have been introduced by the changes to dispatch region formation that first go through the region op and then to the workgroups op.

The root issue is that linalg ops specify constant values as outputs:

func.func @capture_cst(%arg0: tensor<4x32xi32>) -> tensor<32xi32> {
  %cst = arith.constant dense<0> : tensor<32xi32>
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c1_0 = arith.constant 1 : index
  %0 = affine.apply affine_map<()[s0, s1, s2] -> ((s1 - s0) ceildiv s2)>()[%c0, %c1, %c1_0]
  %c0_1 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c1_2 = arith.constant 1 : index
  %1 = affine.apply affine_map<()[s0, s1, s2] -> ((s1 - s0) ceildiv s2)>()[%c0_1, %c32, %c1_2]
  %2 = flow.dispatch.region[%0, %1] -> (tensor<32xi32>) {
    %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<4x32xi32>) outs(%cst : tensor<32xi32>) {
    ^bb0(%in: i32, %out: i32):
      %4 = arith.addi %in, %out : i32
      linalg.yield %4 : i32
    } -> tensor<32xi32>
    flow.return %3 : tensor<32xi32>
  } count(%arg1: index, %arg2: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
    flow.return %x, %y, %z : index, index, index
  }
  return %2 : tensor<32xi32>
}

When forming the dispatch region, the code looks at the linalg op, sees that %cst is an output, and marks it as readwrite:

func.func @capture_cst(%arg0: tensor<4x32xi32>) -> tensor<32xi32> {
  %cst = arith.constant dense<0> : tensor<32xi32>
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c1_0 = arith.constant 1 : index
  %0 = affine.apply affine_map<()[s0, s1, s2] -> ((s1 - s0) ceildiv s2)>()[%c0, %c1, %c1_0]
  %c0_1 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c1_2 = arith.constant 1 : index
  %1 = affine.apply affine_map<()[s0, s1, s2] -> ((s1 - s0) ceildiv s2)>()[%c0_1, %c32, %c1_2]
  %2 = flow.dispatch.workgroups[%0, %1](%arg0, %cst) : (tensor<4x32xi32>, tensor<32xi32>) -> %cst =
      (%arg1: !flow.dispatch.tensor<readonly:tensor<4x32xi32>>, %arg2: !flow.dispatch.tensor<readwrite:tensor<32xi32>>) {
    %3 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x32xi32>> -> tensor<4x32xi32>
    %4 = flow.dispatch.tensor.load %arg2, offsets = [0], sizes = [32], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<32xi32>> -> tensor<32xi32>
    %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"]} ins(%3 : tensor<4x32xi32>) outs(%4 : tensor<32xi32>) {
    ^bb0(%in: i32, %out: i32):
      %6 = arith.addi %in, %out : i32
      linalg.yield %6 : i32
    } -> tensor<32xi32>
    flow.dispatch.tensor.store %5, %arg2, offsets = [0], sizes = [32], strides = [1] : tensor<32xi32> -> !flow.dispatch.tensor<readwrite:tensor<32xi32>>
    flow.return
  } count(%arg1: index, %arg2: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
    flow.return %x, %y, %z : index, index, index
  }
  return %2 : tensor<32xi32>
}

The stream dialect correctly handles this by introducing a copy-on-write, but that is the least efficient way to do this kind of work. We should not be doing this in-place - it's a pessimization: a read of the input + a write to the temporary + a full barrier + another read inside the dispatch + a write of the output. We end up reserving the new memory for the output (which we need to do anyway) and then issuing a copy from the constant memory to the temporary output that must complete prior to launching the dispatch. We don't see this much in benchmarks because the stream dialect schedules the copies concurrently with other dispatches, but it increases memory consumption (something we don't track), increases code size (which we do track, though no one is looking at it), and adds memory bandwidth pressure that shows up as hurting the dispatches around the ones doing this instead of being attributed to the dispatches that introduce the issue.

The IR above comes from disabling inlining of small constants during FormDispatchRegions, but it'll look like this today whenever the constant doesn't qualify for inlining (because the maximum inline size is set smaller or the tensor is larger than the 256-byte default). This means that if you have a 100MB constant we'll reserve 100MB of memory, memcpy the whole constant over, and then do all those extra reads/writes to it in a very unlikely-to-be-cache-friendly manner.

Maybe there's a pass we could run to decouple linalg ops from writing into constants? The above could be transformed to use a tensor.empty as the output and just read the constant instead. Maybe this is related to the linalg decomposition stuff?

Since we barely squeak by today for small constants thanks to the current FormDispatchRegions constant inlining, most people don't notice this in unit tests, but it's blocking the deduplication improvements we need to reduce compile time and code size. It's also confounding other optimization work deeper down: with readwrite we can't add no-alias attributes during code generation, and we're generating more code than we need to (all the additional loads/etc).

MaheshRavishankar commented 1 year ago

I think the issue is that the input linalg.generic

 %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<4x32xi32>) outs(%cst : tensor<32xi32>) {
    ^bb0(%in: i32, %out: i32):
      %4 = arith.addi %in, %out : i32
      linalg.yield %4 : i32
    } -> tensor<32xi32>

is ill-formed. This should be split out as

%0 = tensor.empty()
%1 = linalg.fill ins(%cst : f32) outs(%0)
%2 = linalg.matmul ins(...) outs(%1)

I would say it's a front-end problem, but we can try to massage the IR within IREE to the right form...
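
Concretely, for the reduction in this issue that would look something like the following (just a sketch; since the init here is a zero splat, filling with the i32 zero reproduces the dense<0> constant):

%empty = tensor.empty() : tensor<32xi32>
%zero = arith.constant 0 : i32
%init = linalg.fill ins(%zero : i32) outs(%empty : tensor<32xi32>) -> tensor<32xi32>
%result = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<4x32xi32>) outs(%init : tensor<32xi32>) {
^bb0(%in: i32, %out: i32):
  %sum = arith.addi %in, %out : i32
  linalg.yield %sum : i32
} -> tensor<32xi32>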

@julianwa @jpienaar anyone who could pick this up?

benvanik commented 1 year ago

Is that something we could add a check for to prevent frontends from passing it in cases where we couldn't fix it? (Feels like we'll likely need to either accept it and handle that transformation ourselves or make it a verification failure or something?)

MaheshRavishankar commented 1 year ago

But what I suggested above only works if this is a splat constant; it won't work for non-splat constants... So if we don't know that the operand to the dispatch is a constant, we have to treat it as read-write (or else bufferization will end up creating an allocation). Not sure how to fix this...

benvanik commented 1 year ago

I'm ok with copy-on-write if we don't know - but in 99% of cases we do know :) (perfect-enemy-of-good situation)

jpienaar commented 1 year ago

Was this original lowering from the TOSA one discussed on Discord? (And so the front-end fix here would be in the lowering to linalg, which should emit the other form.)

benvanik commented 1 year ago

The one I saw above was in a test, but I've seen this in several real models from tosa. It's hard to piece apart what's what - there may be multiple independent issues or several related to this same process (what to capture).

For example, this from https://gist.github.com/bjacob/4d1a0728f9e6814178b50646e62b27ce:

    %55 = flow.tensor.splat %cst_748 : tensor<1536xf32>
    %57 = flow.dispatch @main_dispatch_32::@main_dispatch_32_generic_1536x384[%c1536, %c1](%56, %55) : (tensor<1536x384xf32>, tensor<1536xf32>) -> %55

(should be capturing the splat and returning it, not in-place operating on it)
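
i.e., something like this instead, keeping the same operands but producing an untied result so %55 stays read-only (a sketch of one possible form):

    %57 = flow.dispatch @main_dispatch_32::@main_dispatch_32_generic_1536x384[%c1536, %c1](%56, %55) : (tensor<1536x384xf32>, tensor<1536xf32>) -> tensor<1536xf32>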

(nevermind the tensor.empty issue - that's a concat and correct)

Some of the models I've been looking at have constants moved out to globals prior to lowering into linalg; if that happens we'll need a slightly smarter check than just isa<ConstantOp>, but I'm happy to do that stuff (walk util globals and such) if there's a place I can put that logic. TOSA starts with constants inline (tosa.const) and they remain as arith.constant until dispatch region formation.
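
For reference, a global-backed constant looks roughly like this (a sketch; the global name is made up), so the check would need to look through the util.global.load to the global's initial value rather than just matching arith.constant:

    util.global private @_constant_0 = dense<0> : tensor<32xi32>
    ...
    %cst = util.global.load @_constant_0 : tensor<32xi32>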

MaheshRavishankar commented 1 year ago

Yeah, today's readwrite-setting logic does not look past the dispatch. Maybe that just needs to be added. But making it write-only also puts us in a tough spot when we fuse with this operation: if we don't vectorize, it will end up as a stack allocation. For GEMM kind of stuff we always vectorize, so it's OK. The issue is that if we don't mark it read-write we won't be in destination-passing style for bufferization to work as expected. So yeah, there are a few things mixed up here...

benvanik commented 1 year ago

Yikes, yeah... sigh. Stack would probably be better, as it's at least bounded to the workgroup and in local cache, whereas this kind of stuff materializes global memory and does expensive dispatches: extending the %55 lifetime up into whatever is above it, launching the splat operation, and synchronizing before launching the dispatch that consumes it pretty much ensures cold caches and non-trivial additional overhead (1-10 microseconds even if the splat is just 1 byte).

MaheshRavishankar commented 1 year ago

The stack is bounded, but it is still not the code you want to generate for the dispatch. Worth the trade-off though it seems like...

benvanik commented 1 year ago

Oh, I guess I see a lot more of these with #12235 and --iree-flow-inline-constants-max-byte-length=0 (which prevents inlining of all constants, including splat attrs, and is something we need for reducing executable count). We may have to do something smarter in that case with splats specifically: turn them into broadcasts of a dynamic value and then not inline that dynamic value - the broadcast should still end up in the dispatch regions and we'd just be capturing an int/float primitive.

Any ideas on the best way to do that? My uneducated guess is that a pass running prior to FormDispatchRegions could split arith.constant tensor splats into an arith.constant primitive + a linalg.generic broadcast (or some other op sequence like empty + fill) and let FormDispatchRegions do the right thing (isClonableIntoDispatchOp always takes fills). I try to avoid order-dependent passes, but materializing new ops like that within FormDispatchRegions feels like it'd be more complex?
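
As a rough sketch of what that pre-pass could emit (names and values are made up; the exact op sequence is up for debate), a splat constant like

    %cst = arith.constant dense<1.000000e+00> : tensor<1536xf32>

would become

    %splat_value = arith.constant 1.000000e+00 : f32
    %empty = tensor.empty() : tensor<1536xf32>
    %cst = linalg.fill ins(%splat_value : f32) outs(%empty : tensor<1536xf32>) -> tensor<1536xf32>

and FormDispatchRegions would then clone the empty + fill into consuming dispatches so only the f32 primitive gets captured.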

allieculp commented 1 year ago

@benvanik @julianwa @jpienaar @MaheshRavishankar This one got a bit stale, should we prioritize this or no?

MaheshRavishankar commented 1 year ago

Well, I can't do this right now... It is worth having someone look into this a bit, though... there are a lot of these things that are not exactly bite-sized but are good starter-ish tasks that will help a lot.