ziereis opened 2 weeks ago
@hanhanW Seems like something with the vectorizer, potentially?
Cc @pashu123 as well. Seems like a tile size issue.
@ziereis Could you change https://github.com/iree-org/iree/blob/f42b90d23c332bee6dedd1c8f44e07b9b1a52f74/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp#L408 to
funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperands(i));
and give it a try?
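In case it helps, here is a minimal sketch of what that edit would look like in Passes.cpp. The surrounding loop, the level variable, and the pass being replaced are assumptions on my side; only the replacement call itself is taken from the suggestion above.

// compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp (context assumed):
// inside the loop that adds one tiling pass per tiling level `i`.
for (int64_t i = 1; i < numTilingLevels - 1; ++i) {
  // Before (assumed): the level is tiled without fusing producers, e.g.
  //   funcPassManager.addPass(createLLVMCPUTilePass(i));
  // Suggested: tile the root op at this level and fuse its input operands
  // (such as the producing tensor.unpack) into the generated loops.
  funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperands(i));
}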
@pashu123 I tested it with this example and a couple of other ones that failed, and they all compile with this fix.
@ziereis For context, this was introduced in https://github.com/iree-org/iree/pull/18114, but we only enabled it for the convExpert pipeline.
I cannot reproduce the issue because the target CPU is not specified. Can you provide the log with --mlir-print-ir-after-all --mlir-disable-threading?
Btw, I think the issue is not related to LLVMCPUTileRootAndFuseInputOperands? They are two reductions and they are formed into different dispatches. The issue is that we get large tile sizes in the lowering_config.
I was also surprised, but when I looked at the dispatches, the last one was a fused unpack + reduction.
I see. I think they are batch_matmuls in generic op form, so data-tiling kicks in. And I don't have cpu_features because the CPU target is not specified, so those encodings are dropped. Thus I'm not able to reproduce it. It's easier if @ziereis can provide the IR dumps.
Sorry for not providing the flags. Here is the full command:
./build/tools/iree-compile --iree-hal-target-device=llvm-cpu --iree-llvmcpu-target-cpu=znver4 reproducer.mlir -o out.vmfb
The IR dump is also attached.
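For anyone reproducing, the dump can be captured by adding the flags requested above to the same command and redirecting stderr to a file; the redirect and file name below are just illustrative:
./build/tools/iree-compile --iree-hal-target-device=llvm-cpu --iree-llvmcpu-target-cpu=znver4 --mlir-print-ir-after-all --mlir-disable-threading reproducer.mlir -o out.vmfb 2> ir_dump.txt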
Here is the IR before vectorization:
// -----// IR Dump After TensorToVectorVectorizePadPass (iree-codegen-vectorize-tensor-pad) //----- //
func.func @main_dispatch_4_unpack_generic_10x128x256_i32xf32xf32xf32() attributes {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert>} {
%c256 = arith.constant 256 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c1179648 = arith.constant 1179648 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c1179648) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<256x1x8x16x16xi32>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<10x256xf32>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<128x256xf32>>
%3 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<10x128xf32>>
%4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0], sizes = [256, 1, 8, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<256x1x8x16x16xi32>> -> tensor<256x1x8x16x16xi32>
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<10x256xf32>> -> tensor<10x256xf32>
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x256xf32>> -> tensor<128x256xf32>
%7 = tensor.empty() : tensor<10x128xf32>
%8 = scf.forall (%arg0) = (0) to (128) step (32) shared_outs(%arg1 = %7) -> (tensor<10x128xf32>) {
%9 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg0)
%extracted_slice = tensor.extract_slice %4[0, 0, %9, 0, 0] [256, 1, 2, 16, 16] [1, 1, 1, 1, 1] : tensor<256x1x8x16x16xi32> to tensor<256x1x2x16x16xi32>
%extracted_slice_0 = tensor.extract_slice %6[%arg0, 0] [32, 256] [1, 1] : tensor<128x256xf32> to tensor<32x256xf32>
%extracted_slice_1 = tensor.extract_slice %arg1[0, %arg0] [10, 32] [1, 1] : tensor<10x128xf32> to tensor<10x32xf32>
%10 = scf.for %arg2 = %c0 to %c32 step %c16 iter_args(%arg3 = %extracted_slice_1) -> (tensor<10x32xf32>) {
%11 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
%extracted_slice_2 = tensor.extract_slice %extracted_slice[0, 0, %11, 0, 0] [256, 1, 1, 16, 16] [1, 1, 1, 1, 1] : tensor<256x1x2x16x16xi32> to tensor<256x1x1x16x16xi32>
%12 = tensor.empty() : tensor<10x16x256xi32>
%unpack = tensor.unpack %extracted_slice_2 outer_dims_perm = [2, 0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %12 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16, 32, 0], [16, 16, 0], [0, 0, 0], [0, 0, 0]]>} : tensor<256x1x1x16x16xi32> -> tensor<10x16x256xi32>
%extracted_slice_3 = tensor.extract_slice %extracted_slice_0[%arg2, 0] [16, 256] [1, 1] : tensor<32x256xf32> to tensor<16x256xf32>
%extracted_slice_4 = tensor.extract_slice %arg3[0, %arg2] [10, 16] [1, 1] : tensor<10x32xf32> to tensor<10x16xf32>
%13 = scf.for %arg4 = %c0 to %c256 step %c16 iter_args(%arg5 = %extracted_slice_4) -> (tensor<10x16xf32>) {
%extracted_slice_5 = tensor.extract_slice %unpack[0, 0, %arg4] [10, 16, 16] [1, 1, 1] : tensor<10x16x256xi32> to tensor<10x16x16xi32>
%extracted_slice_6 = tensor.extract_slice %5[0, %arg4] [10, 16] [1, 1] : tensor<10x256xf32> to tensor<10x16xf32>
%extracted_slice_7 = tensor.extract_slice %extracted_slice_3[0, %arg4] [16, 16] [1, 1] : tensor<16x256xf32> to tensor<16x16xf32>
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice_5, %extracted_slice_6, %extracted_slice_7 : tensor<10x16x16xi32>, tensor<10x16xf32>, tensor<16x16xf32>) outs(%arg5 : tensor<10x16xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16, 32, 0], [16, 16, 0], [0, 0, 16], [0, 0, 0]]>} {
^bb0(%in: i32, %in_8: f32, %in_9: f32, %out: f32):
%15 = arith.sitofp %in : i32 to f32
%16 = arith.mulf %in_9, %in_8 : f32
%17 = arith.mulf %16, %15 : f32
%18 = arith.addf %17, %out : f32
linalg.yield %18 : f32
} -> tensor<10x16xf32>
scf.yield %14 : tensor<10x16xf32>
}
%inserted_slice = tensor.insert_slice %13 into %arg3[0, %arg2] [10, 16] [1, 1] : tensor<10x16xf32> into tensor<10x32xf32>
scf.yield %inserted_slice : tensor<10x32xf32>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %10 into %arg1[0, %arg0] [10, 32] [1, 1] : tensor<10x32xf32> into tensor<10x128xf32>
}
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
flow.dispatch.tensor.store %8, %3, offsets = [0, 0], sizes = [10, 128], strides = [1, 1] : tensor<10x128xf32> -> !flow.dispatch.tensor<writeonly:tensor<10x128xf32>>
return
}
I've been thinking about whether we want to fuse the unpack ops into the reduction loops; I think the answer is yes. We don't want to fuse packing ops into matmul because of data reuse, but I think it is okay for reduction ops because the unpack is the source and there are no redundant packings. So what Prashant suggested makes sense to me. @pashu123 what is the status of the work? Was this case in our initial scope of the work?
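To make that concrete, here is a hand-written sketch (not compiler output; the slice and unpack shapes are my own reconstruction from the dump above) of what the innermost reduction loop could look like with the tensor.unpack fused into it, so that only a tensor<10x16x16xi32> tile is unpacked per iteration instead of materializing the full tensor<10x16x256xi32>:

%13 = scf.for %arg4 = %c0 to %c256 step %c16 iter_args(%arg5 = %extracted_slice_4) -> (tensor<10x16xf32>) {
  // Slice the packed source down to the 16 reduction elements this iteration needs.
  %src = tensor.extract_slice %extracted_slice_2[%arg4, 0, 0, 0, 0] [16, 1, 1, 16, 16] [1, 1, 1, 1, 1] : tensor<256x1x1x16x16xi32> to tensor<16x1x1x16x16xi32>
  %init = tensor.empty() : tensor<10x16x16xi32>
  // Unpack only this tile; no 10x16x256 intermediate is created before the loop.
  %unpacked = tensor.unpack %src outer_dims_perm = [2, 0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %init : tensor<16x1x1x16x16xi32> -> tensor<10x16x16xi32>
  %extracted_slice_6 = tensor.extract_slice %5[0, %arg4] [10, 16] [1, 1] : tensor<10x256xf32> to tensor<10x16xf32>
  %extracted_slice_7 = tensor.extract_slice %extracted_slice_3[0, %arg4] [16, 16] [1, 1] : tensor<16x256xf32> to tensor<16x16xf32>
  // Same reduction body as in the dump, now reading the per-iteration %unpacked.
  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%unpacked, %extracted_slice_6, %extracted_slice_7 : tensor<10x16x16xi32>, tensor<10x16xf32>, tensor<16x16xf32>) outs(%arg5 : tensor<10x16xf32>) {
  ^bb0(%in: i32, %in_8: f32, %in_9: f32, %out: f32):
    %15 = arith.sitofp %in : i32 to f32
    %16 = arith.mulf %in_9, %in_8 : f32
    %17 = arith.mulf %16, %15 : f32
    %18 = arith.addf %17, %out : f32
    linalg.yield %18 : f32
  } -> tensor<10x16xf32>
  scf.yield %14 : tensor<10x16xf32>
}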
We are eventually shifting to the newer tile&fuse pipeline. https://github.com/iree-org/iree/pull/19163 is the first PR.
What happened?
Compilation to llvm-cpu fails with error: One or more operations with large vector sizes (8192 bytes) were found
Input IR:
This fails to compile. By changing the second dimension of the tensors (256 in this case) you can get it to compile; for example, 32 works.
Example error:
Steps to reproduce your issue
iree-compile --iree-hal-target-device=llvm-cpu input.mlir
What component(s) does this issue relate to?
Compiler
Version information
9c85e30
Additional context
No response