iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[compiler] VAE dispatch exceeds shared memory limit #19230

Open nithinsubbiah opened 1 day ago

nithinsubbiah commented 1 day ago

What happened?

The following dispatch exceeds the shared memory limit and causes an error while compiling the VAE model. The whole VAE model is available here (https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-vae-decode/11-20-2024/model.mlir), and the error can be reproduced with the command below.

Download the attention spec MLIR from the following link (https://raw.githubusercontent.com/nod-ai/sdxl-scripts/refs/heads/shared/sdxl_on_main/int8-model/specs/attention_and_matmul_spec.mlir) and pass its file path to --iree-codegen-transform-dialect-library:

iree-compile --iree-input-type=torch \
  --iree-vm-bytecode-module-output-format=flatbuffer-binary \
  --iree-hal-target-backends=rocm \
  --mlir-print-debuginfo \
  --mlir-print-op-on-diagnostic=false \
  --iree-hip-target=gfx942 \
  --iree-dispatch-creation-enable-aggressive-fusion \
  --iree-dispatch-creation-enable-fuse-horizontal-contractions \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-codegen-llvmgpu-use-vector-distribution=true \
  --iree-opt-data-tiling=false \
  --iree-codegen-gpu-native-math-precision=true \
  --iree-vm-target-truncate-unsupported-floats \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-const-eval=false \
  --iree-llvmgpu-enable-prefetch=true \
  --iree-execution-model=async-external \
  --iree-preprocessing-pass-pipeline="builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" \
  --iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir \
  --iree-hal-dump-executable-sources-to=dump \
  model.mlir -o /dev/null
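If it helps to script the repro, the download step and the command can be bundled in Python. This is only a sketch: the helper names `build_command`, `fetch_inputs`, and `run_repro` are hypothetical, the flag list is copied from the command in this issue with the repeated flags listed once, and it assumes an `iree-compile` binary from an affected build is on PATH.

```python
import subprocess
import urllib.request

# URLs copied from the issue text.
MODEL_URL = ("https://sharkpublic.blob.core.windows.net/sharkpublic/sai/"
             "sdxl-vae-decode/11-20-2024/model.mlir")
SPEC_URL = ("https://raw.githubusercontent.com/nod-ai/sdxl-scripts/refs/heads/shared/"
            "sdxl_on_main/int8-model/specs/attention_and_matmul_spec.mlir")

# Passed as a single argv entry, so no shell quoting is needed here.
PREPROCESSING_PIPELINE = (
    "builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), "
    "iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, "
    "util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
)

# Flags from the repro command, each listed once.
FLAGS = [
    "--iree-input-type=torch",
    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
    "--iree-hal-target-backends=rocm",
    "--mlir-print-debuginfo",
    "--mlir-print-op-on-diagnostic=false",
    "--iree-hip-target=gfx942",
    "--iree-dispatch-creation-enable-aggressive-fusion",
    "--iree-dispatch-creation-enable-fuse-horizontal-contractions",
    "--iree-opt-aggressively-propagate-transposes=true",
    "--iree-codegen-llvmgpu-use-vector-distribution=true",
    "--iree-opt-data-tiling=false",
    "--iree-codegen-gpu-native-math-precision=true",
    "--iree-vm-target-truncate-unsupported-floats",
    "--iree-global-opt-propagate-transposes=true",
    "--iree-opt-const-eval=false",
    "--iree-llvmgpu-enable-prefetch=true",
    "--iree-execution-model=async-external",
]


def build_command(model_path, spec_path, dump_dir="dump"):
    """Assemble the iree-compile argv for the repro (hypothetical helper)."""
    return ["iree-compile"] + FLAGS + [
        f"--iree-preprocessing-pass-pipeline={PREPROCESSING_PIPELINE}",
        f"--iree-codegen-transform-dialect-library={spec_path}",
        f"--iree-hal-dump-executable-sources-to={dump_dir}",
        model_path, "-o", "/dev/null",
    ]


def fetch_inputs(model_path="model.mlir", spec_path="attention_and_matmul_spec.mlir"):
    """Download the model and the transform-dialect spec to local files."""
    urllib.request.urlretrieve(MODEL_URL, model_path)
    urllib.request.urlretrieve(SPEC_URL, spec_path)
    return model_path, spec_path


def run_repro():
    """Fetch the inputs, run iree-compile, and return the CompletedProcess."""
    model_path, spec_path = fetch_inputs()
    return subprocess.run(build_command(model_path, spec_path),
                          capture_output=True, text=True)
```

Calling `run_repro()` and printing its `stderr` should surface the shared-memory diagnostic on an affected build; the dumped executable sources land in `dump/`.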

Error:

dump/compiled_vae_encode$async_dispatch_126.mlir:9:6: error: 'func.func' op uses 139296 bytes of shared memory; exceeded the limit of 65536 bytes
      func.func @encode$async_dispatch_126_matmul_like_8x128x128x8_f16xf16xf32() {
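For scale, the overshoot can be computed from the two numbers in play: the 139296 bytes reported by the diagnostic and the `max_workgroup_memory_bytes = 65536` in the gfx942 target attribute of the dispatch. The sketch also sizes the two large tensors of the dispatch to show that neither could be staged in shared memory whole; it makes no claim about how the compiler arrived at 139296 bytes, and `tensor_bytes` is a hypothetical helper.

```python
# Numbers copied from the diagnostic and from the dispatch's #iree_gpu.target attribute.
USED_BYTES = 139296   # "'func.func' op uses 139296 bytes of shared memory"
LIMIT_BYTES = 65536   # max_workgroup_memory_bytes for gfx942


def tensor_bytes(shape, elem_bytes):
    """Size in bytes of a dense tensor with the given shape and element width."""
    n = 1
    for dim in shape:
        n *= dim
    return n * elem_bytes


excess = USED_BYTES - LIMIT_BYTES
print(f"over the limit by {excess} bytes ({USED_BYTES / LIMIT_BYTES:.2f}x)")
# prints "over the limit by 73760 bytes (2.13x)"

# The 8x128x128xf16 input (%4) and the 8x128x128xf32 contraction result (%10).
print("8x128x128xf16 input:", tensor_bytes((8, 128, 128), 2))    # 262144
print("8x128x128xf32 result:", tensor_bytes((8, 128, 128), 4))   # 524288
```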

Steps to reproduce your issue

Dispatch IR

hal.executable public @encode$async_dispatch_126 {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>) {
    hal.executable.export public @encode$async_dispatch_126_matmul_like_8x128x128x8_f16xf16xf32 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) {
    ^bb0(%arg0: !hal.device loc("model.mlir":4760:12)):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice  loc("model.mlir":4760:12)
      hal.return %x, %y, %z : index, index, index loc("model.mlir":4760:12)
    } loc("model.mlir":4760:12)
    builtin.module {
      func.func @encode$async_dispatch_126_matmul_like_8x128x128x8_f16xf16xf32() {
        %cst = arith.constant 0.000000e+00 : f32 loc(unknown)
        %c162683648 = arith.constant 162683648 : index loc("model.mlir":4760:12)
        %c162683584 = arith.constant 162683584 : index loc("model.mlir":4760:12)
        %c1048576 = arith.constant 1048576 : index loc("model.mlir":4741:12)
        %c0 = arith.constant 0 : index loc("model.mlir":4760:12)
        %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c162683648) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x8xf16>> loc("model.mlir":4760:12)
        %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c1048576) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x128x128xf16>> loc("model.mlir":4741:12)
        %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c162683584) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<8xf32>> loc("model.mlir":4760:12)
        %3 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<8x128x128xf16>> loc("model.mlir":4760:12)
        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [8, 128, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x128x128xf16>> -> tensor<8x128x128xf16> loc("model.mlir":4760:12)
        %5 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [8], strides = [1] : !flow.dispatch.tensor<readonly:tensor<8xf32>> -> tensor<8xf32> loc("model.mlir":4760:12)
        %6 = tensor.empty() : tensor<8x128x128xf16> loc("model.mlir":4741:12)
        %7 = tensor.empty() : tensor<8x128x128xf32> loc("model.mlir":4760:12)
        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [8, 8], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<16x8xf16>> -> tensor<8x8xf16> loc("model.mlir":4760:12)
        %9 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8x128x128xf32>) -> tensor<8x128x128xf32> loc("model.mlir":4760:12)
        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%4, %8 : tensor<8x128x128xf16>, tensor<8x8xf16>) outs(%9 : tensor<8x128x128xf32>) {
        ^bb0(%in: f16 loc(unknown), %in_0: f16 loc(unknown), %out: f32 loc(unknown)):
          %12 = arith.extf %in : f16 to f32 loc(unknown)
          %13 = arith.extf %in_0 : f16 to f32 loc(unknown)
          %14 = arith.mulf %12, %13 : f32 loc(unknown)
          %15 = arith.addf %out, %14 : f32 loc(unknown)
          linalg.yield %15 : f32 loc(unknown)
        } -> tensor<8x128x128xf32> loc("model.mlir":4760:12)
        %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%10, %5 : tensor<8x128x128xf32>, tensor<8xf32>) outs(%6 : tensor<8x128x128xf16>) {
        ^bb0(%in: f32 loc("model.mlir":4760:12), %in_0: f32 loc("model.mlir":4760:12), %out: f16 loc("model.mlir":4760:12)):
          %12 = arith.addf %in, %in_0 : f32 loc("model.mlir":4760:12)
          %13 = arith.truncf %12 : f32 to f16 loc("model.mlir":4760:12)
          linalg.yield %13 : f16 loc("model.mlir":4760:12)
        } -> tensor<8x128x128xf16> loc("model.mlir":4760:12)
        flow.dispatch.tensor.store %11, %3, offsets = [0, 0, 0], sizes = [8, 128, 128], strides = [1, 1, 1] : tensor<8x128x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<8x128x128xf16>> loc("model.mlir":4760:12)
        return loc("model.mlir":4760:12)
      } loc("model.mlir":4760:12)
    } loc("model.mlir":4760:12)
  } loc("model.mlir":4760:12)
} loc("model.mlir":4760:12)
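To make the indexing maps easier to read, here is a plain-Python model of what `%10` and `%11` compute: `out[d0][d1][d2] = f16(sum over d3 of f32(x[d3][d1][d2]) * f32(w[d0][d3]) + bias[d0])`, with f32 accumulation starting from zero. It is a hypothetical reference for checking numerics on small inputs, not IREE code. It assumes round-to-nearest-even for both f32 and the final `arith.truncf`, and uses Python's `struct` half-precision (`e`) format to emulate f16. In the dispatch, `w` is the `[0:8, 0:8]` slice of the 16x8 binding.

```python
import struct


def to_f32(x):
    """Round a Python float (IEEE f64) to the nearest f32 value.
    Raises OverflowError outside the f32 range."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_f16(x):
    """Round to the nearest f16 value (round-half-to-even).
    struct refuses values that overflow f16, so map those to +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def dispatch_126_reference(x, w, bias):
    """Model of %10 followed by %11 (hypothetical reference).

    x    : K x H x W f16 values (%4, map (d3, d1, d2))
    w    : M x K f16 values     (%8, map (d0, d3))
    bias : M f32 values         (%5, map (d0))
    returns M x H x W f16 values (%11, map (d0, d1, d2))
    """
    K, H, W = len(x), len(x[0]), len(x[0][0])
    M = len(w)
    out = [[[0.0] * W for _ in range(H)] for _ in range(M)]
    for m in range(M):              # d0 (parallel)
        for h in range(H):          # d1 (parallel)
            for j in range(W):      # d2 (parallel)
                acc = 0.0           # linalg.fill with 0.0 : f32
                for k in range(K):  # d3 (reduction)
                    # arith.extf is exact; mulf and addf round to f32.
                    prod = to_f32(x[k][h][j] * w[m][k])
                    acc = to_f32(acc + prod)
                # %11: add the per-d0 bias in f32, then truncate to f16.
                out[m][h][j] = to_f16(to_f32(acc + bias[m]))
    return out


x = [[[1.0, 2.0]], [[3.0, 4.0]]]   # K=2, H=1, W=2
w = [[1.0, 0.5], [2.0, -1.0]]      # M=2, K=2
bias = [0.25, -0.5]
print(dispatch_126_reference(x, w, bias))  # [[[2.75, 4.25]], [[-1.5, -0.5]]]
```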


nirvedhmeshram commented 1 day ago

You can compile the dispatch directly with the following command:

iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --compile-from=executable-sources input_ir.mlir  \
--mlir-print-ir-after-all &> output_dump.mlir

Also, it would be helpful to include dispatch-specific IR dumps in issues when applicable. Here is the dump for this one.

@MaheshRavishankar @qedawkins This is another case of what I have also observed with the ONNX models: GEMMs that are not aligned to the intrinsic shapes and have fused elementwise ops go down the LLVMGPUSIMT pipeline and then blow up shared memory usage.

Task 2 in this issue should solve this: https://github.com/iree-org/iree/issues/19121
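The "not aligned to intrinsic GEMMs" point can be checked by hand for this dispatch. The sketch below classifies the loop dimensions of `%10` the usual contraction way (M appears in the LHS and output, N in the RHS and output, K in both inputs but not the output), then tests divisibility against the two f16 MFMA intrinsics in the gfx942 target attribute, reading their names as MxNxK. This is a simplified divisibility check assumed for illustration; IREE's actual pipeline selection is more involved, and the helper names are hypothetical.

```python
# (M, N, K) of the f16 MFMA intrinsics listed in the gfx942 target attribute.
F16_INTRINSICS = {
    "MFMA_F32_16x16x16_F16": (16, 16, 16),
    "MFMA_F32_32x32x8_F16": (32, 32, 8),
}


def contraction_sizes(lhs_dims, rhs_dims, out_dims, loop_sizes):
    """Return (M, N, K) problem sizes from the operands' indexing maps."""
    def prod(dims):
        p = 1
        for d in dims:
            p *= loop_sizes[d]
        return p

    m = [d for d in out_dims if d in lhs_dims and d not in rhs_dims]
    n = [d for d in out_dims if d in rhs_dims and d not in lhs_dims]
    k = [d for d in lhs_dims if d in rhs_dims and d not in out_dims]
    return prod(m), prod(n), prod(k)


def aligned(problem, intrinsic):
    """True if every problem dimension is a multiple of the intrinsic's."""
    return all(p % i == 0 for p, i in zip(problem, intrinsic))


# Indexing maps of %10: lhs %4 (d3, d1, d2), rhs %8 (d0, d3), out (d0, d1, d2).
sizes = {"d0": 8, "d1": 128, "d2": 128, "d3": 8}
problem = contraction_sizes(("d3", "d1", "d2"), ("d0", "d3"), ("d0", "d1", "d2"), sizes)
print("M, N, K =", problem)  # M, N, K = (16384, 8, 8)
for name, shape in F16_INTRINSICS.items():
    print(name, "aligned" if aligned(problem, shape) else "not aligned")
```

Both N = 8 and K = 8 fall short of the 16- or 32-wide intrinsic dimensions, which is consistent with the dispatch missing the intrinsic path.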

MaheshRavishankar commented 1 day ago

Interesting. This is a bit blocking, though: it used to work before and now it doesn't.

nithinsubbiah commented 14 hours ago

The failure started after the change that bubbles up extract_slice in linalg (https://github.com/iree-org/iree/pull/19174). CI didn't catch it because the IR it tested was outdated.