iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Fix SDXL dispatch count regression from llvm integrate #19002

Open IanWood1 opened 2 days ago

IanWood1 commented 2 days ago

The most recent llvm integrate, https://github.com/iree-org/iree/pull/18987, introduced a minor regression in SDXL clip dispatch count (1139 ⇾ 1141). I tracked it to https://github.com/llvm/llvm-project/commit/df0d249b6511289f1e8c1389f4fd33d7b4c083fa. I was able to restore the dispatch count by locally reverting this single commit.

Here are the two additional dispatches after the LLVM integrate:

Command used:

iree-compile artifacts/sdxl_clip/model.mlirbc -o extra-dispatches.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-opt-const-eval=false --iree-global-opt-propagate-transposes=true --iree-dispatch-creation-enable-fuse-horizontal-contractions=true --iree-dispatch-creation-enable-aggressive-fusion=true --iree-opt-aggressively-propagate-transposes=true --iree-opt-outer-dim-concat=true --iree-llvmgpu-enable-prefetch=true --iree-opt-data-tiling=false --iree-codegen-gpu-native-math-precision=true --iree-codegen-llvmgpu-use-vector-distribution --iree-hip-waves-per-eu=2 --iree-execution-model=async-external --iree-scheduling-dump-statistics-format=json --iree-scheduling-dump-statistics-file=compilation_info.json '--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)' --compile-to=dispatch-creation

MLIR:

util.global private @__hoisted_tensor_64x768xf16_255 {stream.affinity.default = #hal.device.affinity<@__device_0>} : tensor<64x768xf16>
  util.initializer attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
    %cst = arith.constant dense_resource<torch_tensor_1_77_torch.int64> : tensor<1x77xi64>
    %_params.text_encoder_model_1.text_model.embeddings.position_embedding.weight = util.global.load immutable @_params.text_encoder_model_1.text_model.embeddings.position_embedding.weight : tensor<77x768xf16>
    %0 = flow.dispatch.workgroups(%cst, %_params.text_encoder_model_1.text_model.embeddings.position_embedding.weight) : (tensor<1x77xi64>, tensor<77x768xf16>) -> tensor<64x768xf16> =
        (%arg0: !flow.dispatch.tensor<readonly:tensor<1x77xi64>>, %arg1: !flow.dispatch.tensor<readonly:tensor<77x768xf16>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<64x768xf16>>) {
      %c77 = arith.constant 77 : index
      %c0_i64 = arith.constant 0 : i64
      %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [77, 768], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<77x768xf16>> -> tensor<77x768xf16>
      %2 = tensor.empty() : tensor<64x768xf16>
      %3 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [1, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x77xi64>> -> tensor<64xi64>
      %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3 : tensor<64xi64>) outs(%2 : tensor<64x768xf16>) {
      ^bb0(%in: i64, %out: f16):
        %5 = arith.index_cast %in : i64 to index
        %6 = linalg.index 1 : index
        %7 = arith.cmpi slt, %5, %c77 : index
        cf.assert %7, "index must be smaller than dim size"
        %8 = arith.cmpi sge, %in, %c0_i64 : i64
        cf.assert %8, "index must be larger or equal to 0"
        %extracted = tensor.extract %1[%5, %6] : tensor<77x768xf16>
        linalg.yield %extracted : f16
      } -> tensor<64x768xf16>
      flow.dispatch.tensor.store %4, %arg2, offsets = [0, 0], sizes = [64, 768], strides = [1, 1] : tensor<64x768xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x768xf16>>
      flow.return
    } count() -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      flow.return %x, %y, %z : index, index, index
    }
    util.global.store %0, @__hoisted_tensor_64x768xf16_255 : tensor<64x768xf16>
    util.return
  }
  util.global private @__hoisted_tensor_64x1280xf16_256 {stream.affinity.default = #hal.device.affinity<@__device_0>} : tensor<64x1280xf16>
  util.initializer attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
    %cst = arith.constant dense_resource<torch_tensor_1_77_torch.int64_1> : tensor<1x77xi64>
    %_params.text_encoder_model_2.text_model.embeddings.position_embedding.weight = util.global.load immutable @_params.text_encoder_model_2.text_model.embeddings.position_embedding.weight : tensor<77x1280xf16>
    %0 = flow.dispatch.workgroups(%cst, %_params.text_encoder_model_2.text_model.embeddings.position_embedding.weight) : (tensor<1x77xi64>, tensor<77x1280xf16>) -> tensor<64x1280xf16> =
        (%arg0: !flow.dispatch.tensor<readonly:tensor<1x77xi64>>, %arg1: !flow.dispatch.tensor<readonly:tensor<77x1280xf16>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<64x1280xf16>>) {
      %c77 = arith.constant 77 : index
      %c0_i64 = arith.constant 0 : i64
      %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [77, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<77x1280xf16>> -> tensor<77x1280xf16>
      %2 = tensor.empty() : tensor<64x1280xf16>
      %3 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [1, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x77xi64>> -> tensor<64xi64>
      %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3 : tensor<64xi64>) outs(%2 : tensor<64x1280xf16>) {
      ^bb0(%in: i64, %out: f16):
        %5 = arith.index_cast %in : i64 to index
        %6 = linalg.index 1 : index
        %7 = arith.cmpi slt, %5, %c77 : index
        cf.assert %7, "index must be smaller than dim size"
        %8 = arith.cmpi sge, %in, %c0_i64 : i64
        cf.assert %8, "index must be larger or equal to 0"
        %extracted = tensor.extract %1[%5, %6] : tensor<77x1280xf16>
        linalg.yield %extracted : f16
      } -> tensor<64x1280xf16>
      flow.dispatch.tensor.store %4, %arg2, offsets = [0, 0], sizes = [64, 1280], strides = [1, 1] : tensor<64x1280xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x1280xf16>>
      flow.return
    } count() -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      flow.return %x, %y, %z : index, index, index
    }
    util.global.store %0, @__hoisted_tensor_64x1280xf16_256 : tensor<64x1280xf16>
    util.return
  }

It appears these linalg.generic ops were getting CSE'd before the change but can't be anymore because of the cf.assert ops, which have side effects.
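
A minimal standalone sketch of the behavior (hand-reduced for illustration, not taken from the model; the function and values are made up): two identical pure linalg.generic ops would normally be folded into one by CSE, but once the body contains a side-effecting cf.assert they are no longer removable duplicates.

func.func @cse_blocked(%in: tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>) {
  %c0_i64 = arith.constant 0 : i64
  %empty = tensor.empty() : tensor<4xi64>
  // Without the cf.assert in the body, %a and %b are identical pure ops and
  // CSE would replace all uses of %b with %a.
  %a = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%in : tensor<4xi64>) outs(%empty : tensor<4xi64>) {
  ^bb0(%v: i64, %out: i64):
    %nonneg = arith.cmpi sge, %v, %c0_i64 : i64
    // The assert gives the region a side effect, so the enclosing
    // linalg.generic is no longer memory-effect-free and CSE keeps both copies.
    cf.assert %nonneg, "index must be larger or equal to 0"
    linalg.yield %v : i64
  } -> tensor<4xi64>
  %b = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%in : tensor<4xi64>) outs(%empty : tensor<4xi64>) {
  ^bb0(%v: i64, %out: i64):
    %nonneg = arith.cmpi sge, %v, %c0_i64 : i64
    cf.assert %nonneg, "index must be larger or equal to 0"
    linalg.yield %v : i64
  } -> tensor<4xi64>
  return %a, %b : tensor<4xi64>, tensor<4xi64>
}

Running a pair like this through mlir-opt --cse keeps both generics now that cf.assert carries side effects; deleting the asserts lets them fold back into a single op.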

benvanik commented 2 days ago

ew, we should really not be seeing those asserts - what's adding those?

IanWood1 commented 2 days ago

I think this is coming from torch-mlir's lowering of some op (not entirely sure which). Don't these get dropped somewhere around flow/stream anyway?

benvanik commented 2 days ago

Nope - they make it all the way to the runtime if they are outside of dispatches, and as you're seeing here they have bad influences during dispatch region formation/executable generation. Asserts should only be added explicitly by users unless a debug mode is enabled, IMO. Asserts inside of dispatches are no-ops today and get removed too late, so they just make compilation worse.

They could be used for int range analysis hints in a release build - but if that's the case we should probably absorb them into the int range ops at input time instead.

IanWood1 commented 2 days ago

> Asserts should only be added explicitly by users unless a debug mode is enabled, IMO. Asserts inside of dispatches are no-ops today and get removed too late, so they just make compilation worse.

That makes sense; I think they are there to conform with the PyTorch op specs. We don't currently have a "debug/release mode", right?

@MaheshRavishankar do you have any suggestions on how to fix this?

benvanik commented 2 days ago

--iree-opt-strip-assertions can be used to strip them near input time (I forget if it walks into linalg ops, but it should). As a middle-stage compiler, we want debug options for assertions that arrive as user input to be controlled by whoever creates that input - it's not possible for us to know whether an assert was added by the user, by a dialect conversion above us (like this one), etc. If a dialect is inserting assertions, it would be nice if it had an option to stop doing so.

For now though, use --iree-opt-strip-assertions unless you're testing correctness (and even then, as seen here, we'll never report assertions inside of dispatches today, though we could in debug modes if it ever proved useful - it's just really tricky logic per backend).
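
For reference, that just means adding the flag to the reproducer command from the issue description (a sketch; all other flags unchanged, and the output file name here is arbitrary):

iree-compile artifacts/sdxl_clip/model.mlirbc -o stripped-dispatches.mlir --iree-opt-strip-assertions [remaining flags as in the command above] --compile-to=dispatch-creation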

MaheshRavishankar commented 2 days ago

I thought that was on by default?

benvanik commented 2 days ago

Doesn't seem like it. It could be. I suspect only a fraction of users care about the assertions and more would be confused by how badly they mess up performance.

MaheshRavishankar commented 2 days ago

@IanWood1 maybe start by adding this flag to all the CI tests, and then do a separate PR that turns it on by default.

IanWood1 commented 1 day ago

I tried turning it on in https://github.com/iree-org/iree/pull/19014, but I didn't realize assertions don't get stripped until after hoisting, so it has no effect on the dispatch count. Should this pass be moved? There is a comment explaining why it needs to run after optimizations:

https://github.com/iree-org/iree/blob/2a5d12323c216e275dcc5f955b70aa60d89d47ed/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp#L226-L229

benvanik commented 1 day ago

Good catch. That may not be true anymore now that we have information coming from the frontend and util.assume - we could lower the assertions to those assume ops prior to removal as one of the first steps.

benvanik commented 1 day ago

(oh and I'm pretty sure we don't derive information from the assertions today - so it'd be safe to move now!)