iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

iree_linalg_ext.attention dispatch fails to bufferize after TileAndDecomposeAttention #16421

Open monorimet opened 9 months ago

monorimet commented 9 months ago

What happened?

We have IR for SDXL that preserves the torch.aten._scaled_dot_product_attention op as tm_tensor.attention -> iree_linalg_ext.attention, but I'm seeing issues trying to lower the resulting IR through the IREE compiler.

@rsuderman managed to narrow the issue down to comprehensive bufferization of the tiled and decomposed attention op.

It seems that a buffer/address is being reused many times in the inner loop, causing bufferization/analysis to fail, though the error message given only mentions a mismatched yield / iterArg pair.

The minimized IR (minimal_attn.mlir):

func.func @main(%arg0 : tensor<20x4096x64xf16>, %arg1 : tensor<20x4096x64xf16>, %arg2 : tensor<20x4096x64xf16>, %arg3 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> {
  %empty = tensor.empty() : tensor<20x4096x64xf16>
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
  return %0 : tensor<20x4096x64xf16>
}

The CLI input:

iree-compile ./minimal_attn.mlir --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host --mlir-print-op-on-diagnostic=false --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-opt-strip-assertions=true --verify=true -o minimal_attn_cpu.vmfb --mlir-print-ir-after-all 2> out.txt

The error message:

./minimal_attn.mlir:3:8: error: Yield operand #1 is not equivalent to the corresponding iter bbArg
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
       ^
./minimal_attn.mlir:1:1: note: called from
func.func @main(%arg0 : tensor<20x4096x64xf16>, %arg1 : tensor<20x4096x64xf16>, %arg2 : tensor<20x4096x64xf16>, %arg3 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> {
^
./minimal_attn.mlir:3:8: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movdiri,-serialize,+vpclmulqdq,-avx512vl,-uintr,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,+mwaitx,-lwp,+lzcnt,+sha,-movdir64b,+wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "default"}>
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
       ^
./minimal_attn.mlir:1:1: note: called from
func.func @main(%arg0 : tensor<20x4096x64xf16>, %arg1 : tensor<20x4096x64xf16>, %arg2 : tensor<20x4096x64xf16>, %arg3 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> {
^
./minimal_attn.mlir:3:8: error: failed to translate executables
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
       ^
./minimal_attn.mlir:1:1: note: called from
func.func @main(%arg0 : tensor<20x4096x64xf16>, %arg1 : tensor<20x4096x64xf16>, %arg2 : tensor<20x4096x64xf16>, %arg3 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> {
^

The full log with --mlir-print-ir-after-all:

out.txt

What component(s) does this issue relate to?

iree-compiler, one-shot bufferization/analysis

Version information

Reproduced on latest IREE (0c61f77) and on a source build from contents of this PR: https://github.com/openxla/iree/pull/16416

Additional context

No response

monorimet commented 9 months ago

The attached out.txt shows that TileAndDecomposeAttention was the last successful pass, and EliminateEmptyTensors fails first.

The IR after TileAndDecomposeAttention shows an address being reused many times in the innermost decomposed block of computation, which might be triggering a failure of OneShotAnalysis:

// -----// IR Dump After TileAndDecomposeWinogradTransform (iree-linalg-ext-tile-and-decompose-winograd) //----- //
func.func @main_dispatch_0_attention_20x4096x64xf16() {
  %c4096 = arith.constant 4096 : index
  %c20 = arith.constant 20 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %4 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_id_y]
  %5 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_count_y]
  scf.for %arg0 = %4 to %c20 step %5 {
    %6 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
    %7 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
    scf.for %arg1 = %6 to %c4096 step %7 {
      %8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>> -> tensor<20x64x64xf16>
      %9 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x64x64xf16>
      %10 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
      %11 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
      %12 = tensor.empty() : tensor<64x64xf32>
      %extracted_slice = tensor.extract_slice %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
      %cst = arith.constant 0.000000e+00 : f32
      %13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<64x64xf32>) -> tensor<64x64xf32>
      %cst_0 = arith.constant -1.000000e+30 : f32
      %14 = tensor.empty() : tensor<64xf32>
      %15 = linalg.fill ins(%cst_0 : f32) outs(%14 : tensor<64xf32>) -> tensor<64xf32>
      %16 = tensor.empty() : tensor<64xf32>
      %17 = linalg.fill ins(%cst : f32) outs(%16 : tensor<64xf32>) -> tensor<64xf32>
      %c0_1 = arith.constant 0 : index
      %c4096_2 = arith.constant 4096 : index
      %c64 = arith.constant 64 : index
      %18:3 = scf.for %arg2 = %c0_1 to %c4096_2 step %c64 iter_args(%arg3 = %13, %arg4 = %15, %arg5 = %17) -> (tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>) {
        %extracted_slice_3 = tensor.extract_slice %10[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
        %extracted_slice_4 = tensor.extract_slice %11[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
        %extracted_slice_5 = tensor.extract_slice %9[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
        %cst_6 = arith.constant 0.000000e+00 : f32
        %21 = tensor.empty() : tensor<64x64xf32>
        %22 = linalg.fill ins(%cst_6 : f32) outs(%21 : tensor<64x64xf32>) -> tensor<64x64xf32>
        %23 = linalg.matmul_transpose_b ins(%extracted_slice_5, %extracted_slice_3 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%22 : tensor<64x64xf32>) -> tensor<64x64xf32>
        %24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%23 : tensor<64x64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.maximumf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64xf32>
        %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%24 : tensor<64xf32>) outs(%23 : tensor<64x64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.subf %out, %in : f32
          %34 = math.exp2 %33 : f32
          linalg.yield %34 : f32
        } -> tensor<64x64xf32>
        %26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.subf %out, %in : f32
          %34 = math.exp2 %33 : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>
        %27 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%26 : tensor<64xf32>) outs(%arg5 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.mulf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64xf32>
        %28 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%25 : tensor<64x64xf32>) outs(%27 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.addf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64xf32>
        %29 = tensor.empty() : tensor<64x64xf16>
        %30 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%25 : tensor<64x64xf32>) outs(%29 : tensor<64x64xf16>) {
        ^bb0(%in: f32, %out: f16):
          %33 = arith.truncf %in : f32 to f16
          linalg.yield %33 : f16
        } -> tensor<64x64xf16>
        %31 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%26 : tensor<64xf32>) outs(%arg3 : tensor<64x64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.mulf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64x64xf32>
        %32 = linalg.matmul ins(%30, %extracted_slice_4 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%31 : tensor<64x64xf32>) -> tensor<64x64xf32>
        scf.yield %32, %24, %28 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>
      }
      %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%18#2 : tensor<64xf32>) outs(%18#0 : tensor<64x64xf32>) {
      ^bb0(%in: f32, %out: f32):
        %cst_3 = arith.constant 1.000000e+00 : f32
        %21 = arith.divf %cst_3, %in : f32
        %22 = arith.mulf %21, %out : f32
        linalg.yield %22 : f32
      } -> tensor<64x64xf32>
      %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%19 : tensor<64x64xf32>) outs(%extracted_slice : tensor<64x64xf16>) {
      ^bb0(%in: f32, %out: f16):
        %21 = arith.truncf %in : f32 to f16
        linalg.yield %21 : f16
      } -> tensor<64x64xf16>
      %inserted_slice = tensor.insert_slice %20 into %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<64x64xf16> into tensor<20x64x64xf16>
      flow.dispatch.tensor.store %inserted_slice, %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : tensor<20x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
    }
  }
  return
}
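
To see why reusing %arg4 can trip the analysis, here is a toy Python model of the buffers involved. It is a simplified sketch under stated assumptions, not how One-Shot Bufferize is implemented, and `rescale_factors` and `share_buffer` are hypothetical names. %24 writes the new running max into its destination %arg4, and %26 later reads %arg4 as the old max. If %24 is updated in place in %arg4's buffer, the old max is gone by the time %26 reads it, so the bufferizer has to give %24 a separate buffer; the yielded %24 is then not equivalent to the iter_arg %arg4, which is consistent with the "Yield operand #1" error.

```python
def rescale_factors(old_max, scores, share_buffer):
    """Toy model of %24 (new running max) followed by %26 (exp2(old_max - new_max)).

    share_buffer=True models %24 being bufferized in place into %arg4's buffer;
    share_buffer=False models %24 getting its own buffer initialized from %arg4.
    """
    arg4_buf = list(old_max)  # buffer backing %arg4 (old running max)
    new_max_buf = arg4_buf if share_buffer else list(old_max)
    # %24: row-wise max reduction into its destination buffer.
    for i, row in enumerate(scores):
        new_max_buf[i] = max(new_max_buf[i], max(row))
    # %26: reads the old max from %arg4's buffer and the new max from %24.
    return [2.0 ** (arg4_buf[i] - new_max_buf[i]) for i in range(len(arg4_buf))]


old_max = [1.0, 5.0]
scores = [[3.0, 0.0], [2.0, 4.0]]
print(rescale_factors(old_max, scores, share_buffer=False))  # [0.25, 1.0]
print(rescale_factors(old_max, scores, share_buffer=True))   # [1.0, 1.0]
```

Row 0's factor should be 0.25 but comes out as 1.0 once the old max is clobbered; row 1 only looks correct because its max did not change in this step.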

MaheshRavishankar commented 9 months ago

Turns out attention decomposition is failing during bufferization, and we probably don't have any end-to-end tests exercising this (AFAIK all attention-op-related work has been done only using the transform dialect, with correctness checks out of tree). So this is probably failing silently. I disabled the decomposition for now. This will unblock other work while I look into this issue (and most likely file an upstream bug).

MaheshRavishankar commented 9 months ago

I am attaching the current decomposition (repro.mlir) and what I manually fixed (fixed.mlir). The latter fixes the failure in the eliminate empty tensors pass. The current decomposition is doing something funky with destinations and probably needs a look. It could also use some elementwise op fusion to make the IR more manageable. All in all, this needs a fresh look.

Tagging @harsh-nod: do you think we can iterate on this to get to a better final place?

fixed.mlir.txt repro.mlir.txt

harsh-nod commented 9 months ago

@MaheshRavishankar - definitely. I will be out next week, but @erman-gurses can help pick this up in the meantime. Some notes on the problem:

  1. I just tried top of master (with pass uncommented) and did not see the tensor.casts (here is my output: https://gist.github.com/harsh-nod/c8a2a9da723dc5deea77fa6f8d5ff229).
  2. I can reproduce the error though and it does fail in eliminate empty tensors.
  3. The error has to do with the use of %arg4 (in the above gist; I can see you fixed that manually by adding a tensor.empty). I can comment a little on why %arg4 is used the way it is. %arg4 represents the initial value of the accumulator for the max value (used for softmax). You can see this here:
    %24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%23 : tensor<64x64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.maximumf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64xf32>

    where we use the old max to compute the new max (%24). This is returned at the end of the loop:

    scf.yield %32, %24, %28 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>

    The problem arises when we use %arg4 again here:

    %26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.subf %out, %in : f32
          %34 = math.exp2 %33 : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>

    The idea behind this second usage was the following: Say we allocate some memory for %arg4, then if we allocate some additional memory for %24, we want to reuse the memory we allocated for %arg4 when we do the second generic (rather than allocate some more memory). The idea was to minimize the memory allocations inside the for loop with the assumption that the tensor.empty() might materialize into an allocation.
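
The role of each loop-carried value may be easier to see in a scalar reference. The Python sketch below mirrors the decomposed loop in the dump above, under these assumptions: %arg3/%arg4/%arg5 are the accumulator, running max, and running sum, initialized to 0, -1e30, and 0 as in the dump; it uses base-2 exponentials like the math.exp2 ops; it omits the f16 truncations; and it applies no 1/sqrt(d) scaling because none appears in the dump. The function names are hypothetical. Note that the correction factor (%26) must read the old max before the new max (%24) replaces it.

```python
def attention_reference(q, k, v):
    """Direct base-2 softmax(q @ k^T) @ v over all keys at once."""
    out = []
    for qi in q:
        s = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(s)
        p = [2.0 ** (x - m) for x in s]
        total = sum(p)
        out.append([sum(p[j] * v[j][c] for j in range(len(k))) / total
                    for c in range(len(v[0]))])
    return out


def attention_tiled(q, k, v, tile):
    """Online-softmax loop over key/value tiles, mirroring %23..%32 and %19."""
    n_q, d_v = len(q), len(v[0])
    acc = [[0.0] * d_v for _ in range(n_q)]  # %arg3, init 0
    run_max = [-1.0e30] * n_q                # %arg4, init -1e30
    run_sum = [0.0] * n_q                    # %arg5, init 0
    for start in range(0, len(k), tile):
        k_t, v_t = k[start:start + tile], v[start:start + tile]
        for i, qi in enumerate(q):
            s = [sum(a * b for a, b in zip(qi, kj)) for kj in k_t]  # %23 = Q K^T
            new_max = max(run_max[i], max(s))                        # %24
            p = [2.0 ** (x - new_max) for x in s]                    # %25
            corr = 2.0 ** (run_max[i] - new_max)                     # %26: needs OLD max
            run_sum[i] = corr * run_sum[i] + sum(p)                  # %27, %28
            acc[i] = [corr * acc[i][c]                               # %31
                      + sum(p[j] * v_t[j][c] for j in range(len(k_t)))  # %32
                      for c in range(d_v)]
            run_max[i] = new_max                                     # yield %24
    return [[x / run_sum[i] for x in row] for i, row in enumerate(acc)]  # %19


q = [[0.1, 0.2], [0.3, -0.1], [1.0, 0.5]]
k = [[0.5, 0.1], [-0.2, 0.4], [0.3, 0.3], [1.0, -1.0], [0.0, 0.2]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
tiled = attention_tiled(q, k, v, tile=2)
ref = attention_reference(q, k, v)
```

With three tiles (2, 2, and 1 keys), `tiled` should match `ref` to within floating-point rounding; keeping the old max in a separate value (`run_max[i]` vs `new_max`) is what the tensor-level IR expresses by reading %arg4 after %24 is defined.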

Another note is that the transform dialect bufferization op seems to bufferize the original graph, so it might be worthwhile to look into the differences between the TD bufferization op and what's in the pass pipeline, specifically here https://github.com/openxla/iree/blob/6560f8600c6e6a79eaed3ca3b26af2f8e2b900ae/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp#L760, where some additional steps such as EmptyTensorLoweringPattern are used.

MaheshRavishankar commented 9 months ago

The first thing to do is look at the output of this command and see whether it avoids additional stack allocations:

iree-opt --iree-eliminate-empty-tensors --empty-tensor-to-alloc-tensor --iree-codegen-iree-comprehensive-bufferize fixed.mlir

i.e., without trying to manually reuse the buffers, but letting the analysis eliminate them for you. If so, then just changing the tile-and-decompose attention pass to not reuse buffers this way would fix the issue.

monorimet commented 8 months ago

@MaheshRavishankar FWIW:

I followed the instructions to run iree-opt --iree-eliminate-empty-tensors --empty-tensor-to-alloc-tensor --iree-codegen-iree-comprehensive-bufferize fixed.mlir

output of iree-opt (above command): minimal_attn_elim.mlir.txt

iree-compile .\minimal_attn_elim.mlir --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-embedded-linker-path=C:\V\iree\build\compiler\bindings\python\iree\compiler\tools\..\_mlir_libs\iree-lld.exe --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=./shark_tmp/core-reproducer.mlir --iree-input-type=torch --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-llvmcpu-enable-ukernels=all -o minimal_attn.vmfb
Assertion failed: index < size() && "invalid index into type range", file C:\V\iree\third_party\llvm-project\mlir\include\mlir/IR/TypeRange.h, line 140
.\minimal_attn_elim.mlir:6:1: error: Failures have been detected while processing an MLIR pass pipeline
module {
^
.\minimal_attn_elim.mlir:6:1: note: Pipeline failed while executing [`FormDispatchRegions` on 'util.func' operation: @main_dispatch_0_attention_20x4096x64xf16]: reproducer generated at `./shark_tmp/core-reproducer.mlir`

monorimet commented 8 months ago

Note that I registered the empty-tensor-to-alloc-tensor pass ad hoc; this is my diff to init_mlir_passes.h:

diff --git a/compiler/src/iree/compiler/Tools/init_mlir_passes.h b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
index 4ebbdafe0..2d1d17897 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -59,6 +60,9 @@ inline void registerMlirPasses() {
   // Arm SME
   arm_sme::registerArmSMEPasses();

+  // Bufferization
+  bufferization::registerEmptyTensorToAllocTensor();
+
   // Linalg
   registerLinalgPasses();

hanhanW commented 8 months ago

I just verified that fixed.mlir can pass bufferization, though there are some big buffers. With https://github.com/openxla/iree/pull/16524, we generate the IR below. There are some stack allocations, bounded by 128 and 64; that is okay if they come from tile sizes. Without the fix (i.e., using repro.mlir.txt), I'm able to see the error, so we probably want to fix the decomposition logic.

#map = affine_map<()[s0] -> (s0 * 128)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0)>
#map3 = affine_map<(d0) -> (d0)>
module {
  func.func @main_dispatch_0_attention_20x4096x64xf16() {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant -1.000000e+30 : f32
    %c4096 = arith.constant 4096 : index
    %cst_1 = arith.constant 1.000000e+00 : f32
    %c128 = arith.constant 128 : index
    %alloca = memref.alloca() {alignment = 64 : i64} : memref<128x128xf16>
    %alloca_2 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
    %alloca_3 = memref.alloca() {alignment = 64 : i64} : memref<128x128xf32>
    %alloca_4 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
    %alloca_5 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
    %alloca_6 = memref.alloca() {alignment = 64 : i64} : memref<128x64xf32>
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %0, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %1, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %2, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %3, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %4 = affine.apply #map()[%workgroup_id_x]
    %subview = memref.subview %3[%workgroup_id_y, %4, 0] [1, 128, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_7 = memref.subview %0[%workgroup_id_y, %4, 0] [1, 128, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_8 = memref.subview %1[%workgroup_id_y, 0, 0] [1, 4096, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_9 = memref.subview %2[%workgroup_id_y, 0, 0] [1, 4096, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_10 = memref.subview %subview[0, 0, 0] [1, 128, 64] [1, 1, 1] : memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    linalg.fill ins(%cst : f32) outs(%alloca_6 : memref<128x64xf32>)
    linalg.fill ins(%cst_0 : f32) outs(%alloca_4 : memref<128xf32>)
    linalg.fill ins(%cst : f32) outs(%alloca_5 : memref<128xf32>)
    scf.for %arg0 = %c0 to %c4096 step %c128 {
      %subview_11 = memref.subview %subview_8[0, %arg0, 0] [1, 128, 64] [1, 1, 1] : memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      linalg.fill ins(%cst : f32) outs(%alloca_3 : memref<128x128xf32>)
      %subview_12 = memref.subview %subview_7[0, 0, 0] [1, 128, 64] [1, 1, 1] : memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      linalg.matmul_transpose_b ins(%subview_12, %subview_11 : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca_3 : memref<128x128xf32>)
      linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca_4 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.maximumf %in, %out : f32
        linalg.yield %5 : f32
      }
      linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_4 : memref<128xf32>) outs(%alloca_3 : memref<128x128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.subf %out, %in : f32
        %6 = math.exp2 %5 : f32
        linalg.yield %6 : f32
      }
      linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]} ins(%alloca_4 : memref<128xf32>) outs(%alloca_2 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.subf %out, %in : f32
        %6 = math.exp2 %5 : f32
        linalg.yield %6 : f32
      }
      linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]} ins(%alloca_2 : memref<128xf32>) outs(%alloca_5 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.mulf %in, %out : f32
        linalg.yield %5 : f32
      }
      linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca_5 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.addf %in, %out : f32
        linalg.yield %5 : f32
      }
      linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca : memref<128x128xf16>) {
      ^bb0(%in: f32, %out: f16):
        %5 = arith.truncf %in : f32 to f16
        linalg.yield %5 : f16
      }
      linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_2 : memref<128xf32>) outs(%alloca_6 : memref<128x64xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.mulf %in, %out : f32
        linalg.yield %5 : f32
      }
      %subview_13 = memref.subview %subview_9[0, %arg0, 0] [1, 128, 64] [1, 1, 1] : memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      linalg.matmul ins(%alloca, %subview_13 : memref<128x128xf16>, memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca_6 : memref<128x64xf32>)
    }
    linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_5 : memref<128xf32>) outs(%alloca_6 : memref<128x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %5 = arith.divf %cst_1, %in : f32
      %6 = arith.mulf %5, %out : f32
      linalg.yield %6 : f32
    }
    linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_6 : memref<128x64xf32>) outs(%subview_10 : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
    ^bb0(%in: f32, %out: f16):
      %5 = arith.truncf %in : f32 to f16
      linalg.yield %5 : f16
    }
    return
  }
}

hanhanW commented 8 months ago

I have a fix for attention on the CPU side (it generates the IR that @MaheshRavishankar suggested). Let me prepare a PR.

hanhanW commented 8 months ago

There are big stack allocation issues, but you can bypass them with --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false. Otherwise, you'll see the error below. I think they are bounded by the distribution tile sizes (tile_sizes = [[20, 64]] in this case), so it is fine for now.

/Users/hanchung/z.mlir:3:8: error: 'func.func' op exceeded stack allocation limit of 32768 bytes for function. Got 41728 bytes
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
...
  %24 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64x64xf16>
  %25 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64xf32>
  %26 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64x64xf32>
  %27 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64xf32>
  %28 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64xf32>
  %29 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64x64xf32>

I'm able to compile the op e2e on CPU. To repro, run iree-compile --output-format=vm-bytecode --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ~/z.mlir -o /tmp/a.vmfb --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false.

MaheshRavishankar commented 8 months ago

Can we make it [[20, 32]] so that it fits within the limit?

hanhanW commented 8 months ago

Yes, we can do it with the --iree-llvmcpu-distribution-size=32 flag. I think we need a specialized setRootConfig entry function for the op (or for all the LinalgExt ops), because all of them go through the CPUDefault pipeline, which only applies distribution and bufferization.
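
The 41728 bytes in the stack-limit error can be reproduced by summing the six allocas listed there (64x64xf16, two 64x64xf32, and three 64xf32). The Python sketch below turns that into a rough estimate for other tile sizes. It assumes, as in both dumps in this thread (inner-loop step 64 with tile_sizes [[20, 64]], step 128 in the bufferized IR from the PR), that the key tile follows the distribution tile and that the head dimension is 64; `attention_stack_bytes` is a hypothetical helper, and the tile-32 figure is an estimate under that assumption, not a measurement.

```python
F16, F32 = 2, 4   # element sizes in bytes
LIMIT = 32768     # stack limit reported in the error message


def attention_stack_bytes(tile, head_dim=64):
    """Estimated memref.alloca bytes for the bufferized attention dispatch.

    Assumed allocas (matching the dumps above):
      P:   tile x tile,     f16  (truncated probabilities)
      S:   tile x tile,     f32  (Q K^T scores)
      acc: tile x head_dim, f32  (output accumulator)
      three per-row vectors of length tile, f32 (max, sum, correction)
    """
    return (tile * tile * F16
            + tile * tile * F32
            + tile * head_dim * F32
            + 3 * tile * F32)


for t in (128, 64, 32):
    b = attention_stack_bytes(t)
    print(t, b, "over" if b > LIMIT else "within")
# 128 132608 over
# 64 41728 over
# 32 14720 within
```

The tile-64 result matches the reported 41728 bytes, and the tile-128 case corresponds to the six allocas in the bufferized IR from the PR; the estimate suggests [[20, 32]] would fit under the 32768-byte limit if the key tile shrinks along with it.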

MaheshRavishankar commented 8 months ago

@monorimet https://github.com/openxla/iree/issues/16421#issuecomment-1958397924 should unblock running the model on CPU backends (using the flag to disable failure on exceeding the stack allocation limit). Can you verify? (Though on GPU, if it isn't within the limit, it won't run.)

@erman-gurses, after Hanhan's fixes, can you add an e2e test in IREE with the attention op, so that we can catch this issue?

erman-gurses commented 8 months ago

@MaheshRavishankar, sure, I can work on that.

hanhanW commented 8 months ago

Thanks! Please add the e2e tests to https://github.com/openxla/iree/tree/main/tests/e2e/linalg_ext_ops

monorimet commented 8 months ago

The following command successfully finishes compilation for me with these commits cherry-picked:

iree-compile .\minimal_attn.mlir --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-embedded-linker-path=C:\V\iree\build\compiler\bindings\python\iree\compiler\tools\..\_mlir_libs\iree-lld.exe --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=./shark_tmp/core-reproducer.mlir --iree-input-type=torch --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-enable-ukernels=all --iree-llvmcpu-distribution-size=32 -o minimal_attn.vmfb

Thank you, @hanhanW!

I will try on other backends as well.

hanhanW commented 8 months ago

Reopening the issue until we verify that it's fixed on other backends as well.

monorimet commented 8 months ago

ROCm runs into the shared memory allocation limit with attention tiled and decomposed: https://github.com/openxla/iree/issues/16538

monorimet commented 8 months ago

Vulkan is tricky.

Looking through the SPIRV KernelConfig, it seems that we just don't have a good pipeline for this decomposition. I haven't had any luck dropping LinalgExt::TileAndDecomposeAttentionPass anywhere into addTileAndDistributeToWorkgroupsPasses to mimic the LLVMCPU/LLVMGPU implementation; I mostly run into issues turning the result into valid MMA Subgroup Compute ops. Perhaps these attention op dispatches need a different, more bespoke pipeline. Very open to suggestions here.

@erman-gurses @Eliasj42, let me know if you two have found anything workable for this (we will see it in Vulkan with the SDXL UNet and VAE).

hanhanW commented 8 months ago

It's a wrong fix that introduces numerical issues. I created a revert: https://github.com/openxla/iree/pull/16559

I will figure out how to fix it correctly.

hanhanW commented 8 months ago

I just realized that my fix is still wrong: it does not account for the max slice update correctly. There are two issues with attention:

  1. We need to teach IREE bufferization about the decomposed ops.
  2. We need to implement LinalgExtToLoops for the op, so that we always have at least a scalar fallback.

I'm gonna take a look at (1) and can probably implement (2). If anyone can help with (2), that would be great. I can point out where to add the code.
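This sketch assumes the decomposition follows the usual flash-attention-style scheme: K/V is processed in tiles, and a running row max and running sum are carried across tiles. Under that scheme, the step that is easy to get wrong is the max update. Whenever a new tile raises the row max, the sum and accumulator built so far must be rescaled by exp(m_old - m_new). The minimal single-batch Python sketch below shows that bookkeeping. `tiled_attention` and `tile` are illustrative names, not IREE APIs, and this is not the actual TileAndDecomposeAttention output:

```python
import math


def tiled_attention(q, k, v, scale, tile):
    """Illustrative sketch, not an IREE API. One batch: q is [M][D], k is [N][D], v is [N][E].

    K/V rows are visited `tile` at a time. Per query row we carry a running
    max `m`, a running softmax denominator `l`, and an unnormalized output
    accumulator `acc`.
    """
    out = []
    for qi in q:
        m = -math.inf            # running max of all scores seen so far
        l = 0.0                  # running sum of exp(score - m)
        acc = [0.0] * len(v[0])  # running sum of exp(score - m) * v_row
        for start in range(0, len(k), tile):
            k_tile = k[start:start + tile]
            v_tile = v[start:start + tile]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_tile]
            m_new = max(m, max(scores))
            # Everything accumulated so far is relative to the old max, so
            # rescale it to the new max before adding this tile.
            # On the first tile m is -inf, and exp(-inf) == 0.0.
            correction = math.exp(m - m_new)
            p = [math.exp(s - m_new) for s in scores]
            l = l * correction + sum(p)
            acc = [a * correction for a in acc]
            for pj, vj in zip(p, v_tile):
                for c, x in enumerate(vj):
                    acc[c] += pj * x
            m = m_new
        out.append([a / l for a in acc])
    return out
```

Omitting `correction`, or computing it from a stale max, still yields finite numbers but gives wrong results whenever a later tile raises the row max. That is one way a wrong max update can show up as a silent numerical error.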

hanhanW commented 8 months ago

(2) makes sure that we have basic coverage for all the backends, including VMVX.
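For point (2), a scalar fallback only needs the textbook definition of attention written as plain loops. The minimal Python sketch below shows that math. It assumes the batched layout of the original op (query, key, value and result all [batch][seq][head_dim], e.g. 20x4096x64). It also takes an explicit `scale` argument, because whether the op applies 1/sqrt(head_dim) internally is not stated in this thread. `attention_loops` is a hypothetical name, and this is not the LinalgExtToLoops lowering itself:

```python
import math


def attention_loops(query, key, value, scale):
    """Scalar reference attention. Hypothetical helper, not an IREE API.

    Assumed layout (nested lists):
      query: [B][M][K1], key: [B][N][K1], value: [B][N][K2] -> result: [B][M][K2]
    Every element is produced by plain loops: no tiling, no vectors.
    """
    result = []
    for b in range(len(query)):
        rows = []
        for i in range(len(query[b])):
            # scores[j] = scale * dot(query[b][i], key[b][j])
            scores = [
                scale * sum(qv * kv for qv, kv in zip(query[b][i], key[b][j]))
                for j in range(len(key[b]))
            ]
            # Softmax over j, subtracting the row max for numerical stability.
            row_max = max(scores)
            weights = [math.exp(s - row_max) for s in scores]
            denom = sum(weights)
            # result[b][i][c] = sum_j softmax(scores)[j] * value[b][j][c]
            rows.append([
                sum(weights[j] * value[b][j][c] for j in range(len(value[b]))) / denom
                for c in range(len(value[b][0]))
            ])
        result.append(rows)
    return result
```

This version is intended as a readable reference to compare faster lowerings against, not something to run at 20x4096x64 in Python.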

hanhanW commented 8 months ago

I noticed that we can pass bufferization if we vectorize all the operations. I think the final goal is to have vectorization working for all the dispatches, so I'd like to create a new pipeline for the attention op on the CPU side: https://github.com/openxla/iree/pull/16577

MaheshRavishankar commented 8 months ago

@harsh-nod looking at this

%0 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<128xf32>) outs(%3 : tensor<128x128xf32>) {
^bb0(%in: f32, %out: f32):
  %5 = arith.subf %out, %in : f32
  %6 = math.exp2 %5 : f32
  linalg.yield %6 : f32
} -> tensor<128x128xf32>

This kind of linalg.generic is a bit of an anti-pattern. Ideally, the %out value is never "read" from within the body when all the iterators are marked parallel. It isn't wrong per se, but such uses are better avoided. I understand why you are doing it; a better way would be to use bufferization.to_tensor or similar operations to try to reuse memory.
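Because the body reads %out, the contents of the init tensor %3 are a genuine data input rather than just a destination. Each result element is exp2(%3[i][j] - %4[i]). The Python sketch below computes those values. It assumes #map2 = (d0, d1) -> (d0) and #map1 = (d0, d1) -> (d0, d1), since the map definitions are not shown in the snippet:

```python
def sub_exp2_generic(row_vals, init):
    """Values produced by the all-parallel linalg.generic in the snippet.

    Assumes #map2 = (d0, d1) -> (d0) for the ins operand (%4, shape [128])
    and #map1 = (d0, d1) -> (d0, d1) for the outs operand (%3, [128][128]).
    The body reads %out, so the init contents are a data input:
    result[i][j] = exp2(init[i][j] - row_vals[i]).
    """
    return [[2.0 ** (x - row_vals[i]) for x in row] for i, row in enumerate(init)]


# A 2x2 stand-in for the 128x128 case.
print(sub_exp2_generic([1.0, 2.0], [[1.0, 2.0], [3.0, 2.0]]))
# [[1.0, 2.0], [2.0, 1.0]]
```

No element depends on any other, which is why the iterators can all be parallel. The concern is only that the init operand doubles as an input, so its contents must be preserved instead of the buffer being treated as a write-only destination.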

pashu123 commented 8 months ago

@hanhanW I can work on the scalar fallback solution. I think we have a similar implementation here: https://github.com/gpetters94/mlir-npcomp/blob/3e30bb06c0cd725a32f1091552d4824bd796a2d6/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp#L179

hanhanW commented 8 months ago

Thanks for offering to help; that's very helpful!