monorimet opened this issue 9 months ago
The attached out.txt shows that TileAndDecomposeAttention was the last successful pass and that EliminateEmptyTensors is the first pass to fail.
The IR after TileAndDecomposeAttention shows an address being reused many times in the innermost decomposed block of computation, which might be triggering a failure of OneShotAnalysis:
// -----// IR Dump After TileAndDecomposeWinogradTransform (iree-linalg-ext-tile-and-decompose-winograd) //----- //
func.func @main_dispatch_0_attention_20x4096x64xf16() {
%c4096 = arith.constant 4096 : index
%c20 = arith.constant 20 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c20 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c4096 step %7 {
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>> ->
tensor<20x64x64xf16>
%9 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> ->
tensor<20x64x64xf16>
%10 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> ->
tensor<20x4096x64xf16>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> ->
tensor<20x4096x64xf16>
%12 = tensor.empty() : tensor<64x64xf32>
%extracted_slice = tensor.extract_slice %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
%cst = arith.constant 0.000000e+00 : f32
%13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<64x64xf32>) -> tensor<64x64xf32>
%cst_0 = arith.constant -1.000000e+30 : f32
%14 = tensor.empty() : tensor<64xf32>
%15 = linalg.fill ins(%cst_0 : f32) outs(%14 : tensor<64xf32>) -> tensor<64xf32>
%16 = tensor.empty() : tensor<64xf32>
%17 = linalg.fill ins(%cst : f32) outs(%16 : tensor<64xf32>) -> tensor<64xf32>
%c0_1 = arith.constant 0 : index
%c4096_2 = arith.constant 4096 : index
%c64 = arith.constant 64 : index
%18:3 = scf.for %arg2 = %c0_1 to %c4096_2 step %c64 iter_args(%arg3 = %13, %arg4 = %15, %arg5 = %17) -> (tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>) {
%extracted_slice_3 = tensor.extract_slice %10[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
%extracted_slice_4 = tensor.extract_slice %11[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
%extracted_slice_5 = tensor.extract_slice %9[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
%cst_6 = arith.constant 0.000000e+00 : f32
%21 = tensor.empty() : tensor<64x64xf32>
%22 = linalg.fill ins(%cst_6 : f32) outs(%21 : tensor<64x64xf32>) -> tensor<64x64xf32>
%23 = linalg.matmul_transpose_b ins(%extracted_slice_5, %extracted_slice_3 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%22 : tensor<64x64xf32>) -> tensor<64x64xf32>
%24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%23 :
tensor<64x64xf32>) outs(%arg4 : tensor<64xf32>) {
^bb0(%in: f32, %out: f32):
%33 = arith.maximumf %in, %out : f32
linalg.yield %33 : f32
} -> tensor<64xf32>
%25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%24 : tensor<64xf32>)
outs(%23 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
%33 = arith.subf %out, %in : f32
%34 = math.exp2 %33 : f32
linalg.yield %34 : f32
} -> tensor<64x64xf32>
%26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<64xf32>) outs(%arg4 :
tensor<64xf32>) {
^bb0(%in: f32, %out: f32):
%33 = arith.subf %out, %in : f32
%34 = math.exp2 %33 : f32
linalg.yield %34 : f32
} -> tensor<64xf32>
%27 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%26 : tensor<64xf32>) outs(%arg5 :
tensor<64xf32>) {
^bb0(%in: f32, %out: f32):
%33 = arith.mulf %in, %out : f32
linalg.yield %33 : f32
} -> tensor<64xf32>
%28 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%25 :
tensor<64x64xf32>) outs(%27 : tensor<64xf32>) {
^bb0(%in: f32, %out: f32):
%33 = arith.addf %in, %out : f32
linalg.yield %33 : f32
} -> tensor<64xf32>
%29 = tensor.empty() : tensor<64x64xf16>
%30 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%25 :
tensor<64x64xf32>) outs(%29 : tensor<64x64xf16>) {
^bb0(%in: f32, %out: f16):
%33 = arith.truncf %in : f32 to f16
linalg.yield %33 : f16
} -> tensor<64x64xf16>
%31 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%26 : tensor<64xf32>)
outs(%arg3 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
%33 = arith.mulf %in, %out : f32
linalg.yield %33 : f32
} -> tensor<64x64xf32>
%32 = linalg.matmul ins(%30, %extracted_slice_4 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%31 : tensor<64x64xf32>) -> tensor<64x64xf32>
scf.yield %32, %24, %28 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>
}
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%18#2 : tensor<64xf32>)
outs(%18#0 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
%cst_3 = arith.constant 1.000000e+00 : f32
%21 = arith.divf %cst_3, %in : f32
%22 = arith.mulf %21, %out : f32
linalg.yield %22 : f32
} -> tensor<64x64xf32>
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%19 :
tensor<64x64xf32>) outs(%extracted_slice : tensor<64x64xf16>) {
^bb0(%in: f32, %out: f16):
%21 = arith.truncf %in : f32 to f16
linalg.yield %21 : f16
} -> tensor<64x64xf16>
%inserted_slice = tensor.insert_slice %20 into %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<64x64xf16> into tensor<20x64x64xf16>
flow.dispatch.tensor.store %inserted_slice, %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : tensor<20x64x64xf16> ->
!flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
}
}
return
}
Turns out attention decomposition is failing during bufferization, and we probably don't have any end-to-end tests exercising this (AFAIK all attention-op-related work has been done only using the transform dialect, with correctness checks out of tree), so this is probably failing silently. I disabled the decomposition for now. This will unblock other work while I look into this issue (most likely by filing an upstream bug).
I am attaching the current decomposition (repro.mlir) and what I manually fixed (fixed.mlir). The latter fixes the failure in the EliminateEmptyTensors pass. The current decomposition is doing something funky with destinations and probably needs a look. It could also probably use some elementwise-op fusion to make the IR more manageable. All in all, this needs a fresh look.
Tagging @harsh-nod: do you think we can iterate on this to get to a better final place?
@MaheshRavishankar - definitely. I will be out next week, but @erman-gurses can help pick this up in the meantime. Some notes on the problem:
%24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%23 : tensor<64x64xf32>) outs(%arg4 : tensor<64xf32>) {
^bb0(%in: f32, %out: f32):
%33 = arith.maximumf %in, %out : f32
linalg.yield %33 : f32
} -> tensor<64xf32>
where we use the old max to compute the new max (%24). This is returned at the end of the loop:
scf.yield %32, %24, %28 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>
The problem comes when we use this value again here:
%26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<64xf32>) outs(%arg4 : tensor<64xf32>) {
^bb0(%in: f32, %out: f32):
%33 = arith.subf %out, %in : f32
%34 = math.exp2 %33 : f32
linalg.yield %34 : f32
} -> tensor<64xf32>
The idea behind this second usage was the following: say we allocate some memory for %arg4; then, if we also allocate memory for %24, we would like to reuse the memory allocated for %arg4 when we run the second generic, rather than allocating even more. The goal was to minimize the memory allocations inside the for loop, on the assumption that each tensor.empty() might materialize into an allocation.
Another note: the transform-dialect bufferization op seems to be able to bufferize the original graph, so it might be worthwhile to look into the differences between the TD bufferization op and what's in the pass pipeline. Specifically, here: https://github.com/openxla/iree/blob/6560f8600c6e6a79eaed3ca3b26af2f8e2b900ae/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp#L760, where some additional patterns such as EmptyTensorLoweringPattern are being used.
The first thing to look at is the output of this command, and whether it avoids the additional stack allocations:
iree-opt --iree-eliminate-empty-tensors --empty-tensor-to-alloc-tensor --iree-codegen-iree-comprehensive-bufferize fixed.mlir
i.e., without trying to manually reuse the buffers, let the analysis eliminate the empty tensors for you. If that works, then just changing the tile-and-decompose attention pass to not reuse buffers this way would fix the issue.
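For illustration, here is a rough sketch (not actual pass output; SSA names follow the dump above) of what avoiding the destination reuse could look like for the second generic: keep %arg4 as the init of the max reduction, since that is needed for the running max, but give the correction computation its own tensor.empty destination and read both the old max (%arg4) and the new max (%24) as inputs, so %out is never read in the body:
// Hypothetical rewrite: fresh destination instead of writing into the loop-carried %arg4.
%correction_init = tensor.empty() : tensor<64xf32>
%correction = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg4, %24 : tensor<64xf32>, tensor<64xf32>) outs(%correction_init : tensor<64xf32>) {
^bb0(%old_max: f32, %new_max: f32, %out: f32):
  // correction = exp2(old_max - new_max), the same math as the original %26.
  %diff = arith.subf %old_max, %new_max : f32
  %weight = math.exp2 %diff : f32
  linalg.yield %weight : f32
} -> tensor<64xf32>
The bufferization analysis would then be free to place %correction into the buffer of %arg4 when that is safe, instead of the decomposition forcing the reuse.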
@MaheshRavishankar FWIW:
I followed the instructions and ran iree-opt --iree-eliminate-empty-tensors --empty-tensor-to-alloc-tensor --iree-codegen-iree-comprehensive-bufferize fixed.mlir. The output of that command: minimal_attn_elim.mlir.txt
iree-compile .\minimal_attn_elim.mlir --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-embedded-linker-path=C:\V\iree\build\compiler\bindings\python\iree\compiler\tools\..\_mlir_libs\iree-lld.exe --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=./shark_tmp/core-reproducer.mlir --iree-input-type=torch --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-llvmcpu-enable-ukernels=all -o minimal_attn.vmfb
Assertion failed: index < size() && "invalid index into type range", file C:\V\iree\third_party\llvm-project\mlir\include\mlir/IR/TypeRange.h, line 140
.\minimal_attn_elim.mlir:6:1: error: Failures have been detected while processing an MLIR pass pipeline
module {
^
.\minimal_attn_elim.mlir:6:1: note: Pipeline failed while executing [`FormDispatchRegions` on 'util.func' operation: @main_dispatch_0_attention_20x4096x64xf16]: reproducer generated at `./shark_tmp/core-reproducer.mlir`
Noting that I've included the empty-tensor-to-alloc-tensor pass ad hoc -- this is my diff to init_mlir_passes.h:
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_passes.h b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
index 4ebbdafe0..2d1d17897 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
@@ -19,6 +19,7 @@
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -59,6 +60,9 @@ inline void registerMlirPasses() {
// Arm SME
arm_sme::registerArmSMEPasses();
+ // Bufferization
+ bufferization::registerEmptyTensorToAllocTensor();
+
// Linalg
registerLinalgPasses();
I just verified that fixed.mlir can pass bufferization, though there are some big buffers. With https://github.com/openxla/iree/pull/16524, we generate the IR below. There are some stack allocations, which are bounded by 128 and 64; that is okay if they come from tile sizes. Without the fix (i.e., using repro.mlir.txt), I'm able to see the error. So we probably want to fix the decomposition logic.
#map = affine_map<()[s0] -> (s0 * 128)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0)>
#map3 = affine_map<(d0) -> (d0)>
module {
func.func @main_dispatch_0_attention_20x4096x64xf16() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant -1.000000e+30 : f32
%c4096 = arith.constant 4096 : index
%cst_1 = arith.constant 1.000000e+00 : f32
%c128 = arith.constant 128 : index
%alloca = memref.alloca() {alignment = 64 : i64} : memref<128x128xf16>
%alloca_2 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
%alloca_3 = memref.alloca() {alignment = 64 : i64} : memref<128x128xf32>
%alloca_4 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
%alloca_5 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
%alloca_6 = memref.alloca() {alignment = 64 : i64} : memref<128x64xf32>
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %3, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%4 = affine.apply #map()[%workgroup_id_x]
%subview = memref.subview %3[%workgroup_id_y, %4, 0] [1, 128, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_7 = memref.subview %0[%workgroup_id_y, %4, 0] [1, 128, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_8 = memref.subview %1[%workgroup_id_y, 0, 0] [1, 4096, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_9 = memref.subview %2[%workgroup_id_y, 0, 0] [1, 4096, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_10 = memref.subview %subview[0, 0, 0] [1, 128, 64] [1, 1, 1] : memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.fill ins(%cst : f32) outs(%alloca_6 : memref<128x64xf32>)
linalg.fill ins(%cst_0 : f32) outs(%alloca_4 : memref<128xf32>)
linalg.fill ins(%cst : f32) outs(%alloca_5 : memref<128xf32>)
scf.for %arg0 = %c0 to %c4096 step %c128 {
%subview_11 = memref.subview %subview_8[0, %arg0, 0] [1, 128, 64] [1, 1, 1] : memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.fill ins(%cst : f32) outs(%alloca_3 : memref<128x128xf32>)
%subview_12 = memref.subview %subview_7[0, 0, 0] [1, 128, 64] [1, 1, 1] : memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.matmul_transpose_b ins(%subview_12, %subview_11 : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca_3 : memref<128x128xf32>)
linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca_4 : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.maximumf %in, %out : f32
linalg.yield %5 : f32
}
linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_4 : memref<128xf32>) outs(%alloca_3 : memref<128x128xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.subf %out, %in : f32
%6 = math.exp2 %5 : f32
linalg.yield %6 : f32
}
linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]} ins(%alloca_4 : memref<128xf32>) outs(%alloca_2 : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.subf %out, %in : f32
%6 = math.exp2 %5 : f32
linalg.yield %6 : f32
}
linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]} ins(%alloca_2 : memref<128xf32>) outs(%alloca_5 : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.mulf %in, %out : f32
linalg.yield %5 : f32
}
linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca_5 : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.addf %in, %out : f32
linalg.yield %5 : f32
}
linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca : memref<128x128xf16>) {
^bb0(%in: f32, %out: f16):
%5 = arith.truncf %in : f32 to f16
linalg.yield %5 : f16
}
linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_2 : memref<128xf32>) outs(%alloca_6 : memref<128x64xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.mulf %in, %out : f32
linalg.yield %5 : f32
}
%subview_13 = memref.subview %subview_9[0, %arg0, 0] [1, 128, 64] [1, 1, 1] : memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.matmul ins(%alloca, %subview_13 : memref<128x128xf16>, memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca_6 : memref<128x64xf32>)
}
linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_5 : memref<128xf32>) outs(%alloca_6 : memref<128x64xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.divf %cst_1, %in : f32
%6 = arith.mulf %5, %out : f32
linalg.yield %6 : f32
}
linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_6 : memref<128x64xf32>) outs(%subview_10 : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
^bb0(%in: f32, %out: f16):
%5 = arith.truncf %in : f32 to f16
linalg.yield %5 : f16
}
return
}
}
I have a fix for attention on the CPU side (it generates the IR that @MaheshRavishankar suggested). Let me prepare a PR.
There are big stack allocations, but you can bypass the failure with --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false. Otherwise, you'll see the error below. I think they are bounded by the distribution tile sizes (tile_sizes = [[20, 64]] in this case), so it is fine for now.
/Users/hanchung/z.mlir:3:8: error: 'func.func' op exceeded stack allocation limit of 32768 bytes for function. Got 41728 bytes
%0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
...
%24 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64x64xf16>
%25 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64xf32>
%26 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64x64xf32>
%27 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64xf32>
%28 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64xf32>
%29 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64x64xf32>
I'm able to compile the op end-to-end on CPU. To repro, run: iree-compile --output-format=vm-bytecode --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ~/z.mlir -o /tmp/a.vmfb --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false
Can we make it [[20, 32]] so it fits within the limit?
Yes, we can do it with the --iree-llvmcpu-distribution-size=32 flag. I think we need a specialized setRootConfig entry for the op (or for all the LinalgExt ops), because all of them currently go through the CPUDefault pipeline, which only applies distribution and bufferization.
@monorimet https://github.com/openxla/iree/issues/16421#issuecomment-1958397924 should unblock running the model on CPU backends (using the flag to disable the failure on exceeding the stack allocation limit). Can you verify? (Though on GPU, if it isn't within the limit, it won't run.)
@erman-gurses after Hanhan's fixes, can you add an e2e test in IREE with the attention op, so that we can catch this issue?
@MaheshRavishankar, sure, I can work on that.
Thanks, please add e2e tests to https://github.com/openxla/iree/tree/main/tests/e2e/linalg_ext_ops
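For reference, a test there could look roughly like the sketch below (hypothetical function name, shapes, and values; choosing an all-zero V makes the expected result exactly zero regardless of the softmax weights, so no reference computation is needed):
func.func @attention_zero_values() {
  %query = util.unfoldable_constant dense<1.0> : tensor<1x4x4xf32>
  %key = util.unfoldable_constant dense<0.5> : tensor<1x4x4xf32>
  // V is all zeros, so softmax(Q K^T) * V must be exactly zero.
  %value = util.unfoldable_constant dense<0.0> : tensor<1x4x4xf32>
  %init = tensor.empty() : tensor<1x4x4xf32>
  %result = iree_linalg_ext.attention ins(%query, %key, %value : tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<1x4x4xf32>) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
  check.expect_almost_eq_const(%result, dense<0.0> : tensor<1x4x4xf32>) : tensor<1x4x4xf32>
  return
}
A second case with non-trivial values, checked against a precomputed reference, would also be useful to catch numerical issues.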
The following command successfully finishes compilation for me with these commits cherry-picked:
iree-compile .\minimal_attn.mlir --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-embedded-linker-path=C:\V\iree\build\compiler\bindings\python\iree\compiler\tools\..\_mlir_libs\iree-lld.exe --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=./shark_tmp/core-reproducer.mlir --iree-input-type=torch --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-enable-ukernels=all --iree-llvmcpu-distribution-size=32 -o minimal_attn.vmfb
Thank you @hanhanW
I will try on other backends as well
Keeping the issue open until we verify that it's fixed for the other backends as well.
ROCM runs into shared memory allocation limit with attention tiled+decomposed: https://github.com/openxla/iree/issues/16538
Vulkan is tricky. Stumbling around the SPIR-V KernelConfig, it seems we just don't have a good pipeline for this decomposition -- I haven't had any luck dropping LinalgExt::TileAndDecomposeAttentionPass anywhere in addTileAndDistributeToWorkgroupsPasses to mimic the LLVMCPU/LLVMGPU implementation; I mostly run into issues turning the result into valid MMA subgroup compute ops. Perhaps a different or slightly more bespoke pipeline needs to be set for these attention-op dispatches... Very open to suggestions here.
@erman-gurses @Eliasj42 let me know if you two have found anything workable for this (we will see it in Vulkan + SDXL UNet and VAE).
It was a wrong fix and it introduces numerical issues. I created a revert: https://github.com/openxla/iree/pull/16559. I will figure out how to fix it correctly.
I just realized that my fix is still wrong... it does not consider the updated max slice correctly. There are two issues about attention:
- We need to teach IREE bufferization about the decomposed ops.
- We need to implement LinalgExtToLoops for the op, so we always have a scalar fallback solution.
I'm gonna take a look at (1) and can probably implement (2). If anyone can help with (2), that would be great; I can point out where to add the code.
(2) makes sure that we have basic coverage for all the backends, including VMVX.
I noticed that we can pass bufferization if we vectorize all the operations. I think the final goal is to have vectorization working for all the dispatches, so I'd like to create a new pipeline for the attention op on the CPU side: https://github.com/openxla/iree/pull/16577
@harsh-nod looking at this:
%0 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<128xf32>) outs(%3 : tensor<128x128xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.subf %out, %in : f32
%6 = math.exp2 %5 : f32
linalg.yield %6 : f32
} -> tensor<128x128xf32>
This kind of linalg.generic is an anti-pattern. Ideally the %out value is never "read" from within the body when all the iterators are marked parallel. It isn't wrong per se, but such uses are better avoided. I understand why you are doing it; a better way would be to use bufferization.to_tensor or similar operations to try to reuse memory.
@hanhanW I can work on the scalar fallback solution. I think we have a similar implementation here: https://github.com/gpetters94/mlir-npcomp/blob/3e30bb06c0cd725a32f1091552d4824bd796a2d6/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp#L179
Thanks for offering to help, that's very helpful!
What happened?
We have IR for SDXL that preserves the torch.aten._scaled_dot_product_attention op as tm_tensor.attention -> iree_linalg_ext.attention, but I'm seeing issues trying to lower the resulting IR through the IREE compiler.
@rsuderman managed to narrow the issue down to comprehensive bufferization of the tiled and decomposed attention op.
It seems that a buffer/address is being reused many times in the inner loop, causing bufferization/analysis to fail, though the error message given only mentions a mismatched yield / iterArg pair.
The minimized IR (minimal_attn.mlir):
The CLI input:
The error message:
The full log with --mlir-print-ir-after-all:
out.txt
What component(s) does this issue relate to?
iree-compiler, one-shot bufferization/analysis
Version information
Reproduced on latest IREE (0c61f77) and on a source build from contents of this PR: https://github.com/openxla/iree/pull/16416
Additional context
No response