iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

IR explosion on big `linalg.generic` reduction #16995

Open Hardcode84 opened 7 months ago

Hardcode84 commented 7 months ago

What happened?

The following code freezes iree-compile for multiple minutes, generates a huge dump with --mlir-print-ir-after-all, and eventually fails to compile:

#trait2 = {
  indexing_maps = [
    affine_map<(i, j) -> (i, j)>,  // a
    affine_map<(i, j) -> (i, j)>,  // b
    affine_map<(i, j) -> (0, 0)>   // x (out)
  ],
  iterator_types = ["reduction", "reduction"]
}

func.func @test() {
  %lhs = util.unfoldable_constant dense<1.00> : tensor<3456x2048xf16>
  %rhs1 = util.unfoldable_constant dense<0.01> : tensor<2048x1024xf16>
  %rhs2 = util.unfoldable_constant dense<0.01> : tensor<2048x1024xf16>
  %res = util.unfoldable_constant dense<0.01> : tensor<1x1xf16>
  %c0 = arith.constant 0.0 : f16
  %rhs = linalg.generic #trait2
     ins(%rhs1, %rhs2: tensor<2048x1024xf16>, tensor<2048x1024xf16>)
    outs(%res: tensor<1x1xf16>) {
      ^bb(%a: f16, %b: f16, %x: f16):
        %0 = arith.mulf %a, %b : f16
        %1 = arith.addf %0, %x : f16
        linalg.yield %1 : f16
  } -> tensor<1x1xf16>
  check.expect_almost_eq_const(%rhs, dense<20.2822> : tensor<1x1xf16>) : tensor<1x1xf16>
  return
}

Steps to reproduce your issue

iree-compile --iree-hal-target-backends=vulkan-spirv test-reduction.mlir --mlir-print-ir-after-all --iree-vulkan-target-triple=rdna3-unknown-unknown -o test.vmfb > out.txt 2>&1

What component(s) does this issue relate to?

No response

Version information

bb7e536ecd31ea24007aa7202f3a0c41b897e05c

Additional context

The SPIRVInitialVectorLowering pass causes an IR explosion, producing a large number of repeated lines:

...
      %11439 = arith.addf %5296, %11438 : vector<4xf16>
      %11440 = vector.insert_strided_slice %11439, %11437 {offsets = [132], strides = [1]} : vector<4xf16> into vector<8192xf16>
      %11441 = vector.extract_strided_slice %arg1 {offsets = [136], sizes = [4], strides = [1]} : vector<8192xf16> to vector<4xf16>
      %11442 = arith.addf %5299, %11441 : vector<4xf16>
      %11443 = vector.insert_strided_slice %11442, %11440 {offsets = [136], strides = [1]} : vector<4xf16> into vector<8192xf16>
      %11444 = vector.extract_strided_slice %arg1 {offsets = [140], sizes = [4], strides = [1]} : vector<8192xf16> to vector<4xf16>
      %11445 = arith.addf %5302, %11444 : vector<4xf16>
      %11446 = vector.insert_strided_slice %11445, %11443 {offsets = [140], strides = [1]} : vector<4xf16> into vector<8192xf16>
      %11447 = vector.extract_strided_slice %arg1 {offsets = [144], sizes = [4], strides = [1]} : vector<8192xf16> to vector<4xf16>
      %11448 = arith.addf %5305, %11447 : vector<4xf16>
      %11449 = vector.insert_strided_slice %11448, %11446 {offsets = [144], strides = [1]} : vector<4xf16> into vector<8192xf16>
      %11450 = vector.extract_strided_slice %arg1 {offsets = [148], sizes = [4], strides = [1]} : vector<8192xf16> to vector<4xf16>
      %11451 = arith.addf %5308, %11450 : vector<4xf16>
      %11452 = vector.insert_strided_slice %11451, %11449 {offsets = [148], strides = [1]} : vector<4xf16> into vector<8192xf16>
      %11453 = vector.extract_strided_slice %arg1 {offsets = [152], sizes = [4], strides = [1]} : vector<8192xf16> to vector<4xf16>
      %11454 = arith.addf %5311, %11453 : vector<4xf16>
      %11455 = vector.insert_strided_slice %11454, %11452 {offsets = [152], strides = [1]} : vector<4xf16> into vector<8192xf16>
      %11456 = vector.extract_strided_slice %arg1 {offsets = [156], sizes = [4], strides = [1]} : vector<8192xf16> to vector<4xf16>
      %11457 = arith.addf %5314, %11456 : vector<4xf16>
...
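
To get a rough sense of scale for the dump above, here is a back-of-the-envelope count. It assumes (based only on the visible excerpt) that each elementwise add on a vector<8192xf16> is unrolled into 4-wide slices, with one extract_strided_slice, one arith.addf, and one insert_strided_slice per slice. The function name and the three-ops-per-slice figure are illustrative assumptions, not something taken from the pass itself.

```python
def unrolled_op_count(total_lanes: int, native_lanes: int, ops_per_slice: int = 3) -> int:
    """Estimate ops emitted when one wide vector op is unrolled into slices.

    Assumption: each slice costs `ops_per_slice` ops (extract, add, insert),
    matching the repeating pattern visible in the dump excerpt.
    """
    slices = total_lanes // native_lanes
    return slices * ops_per_slice


# vector<8192xf16> processed as vector<4xf16> slices, as in the excerpt.
print(unrolled_op_count(8192, 4))  # 2048 slices * 3 ops = 6144
```

Under these assumptions a single unrolled add already accounts for several thousand lines of IR, which is consistent with SSA value numbers in the tens of thousands.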

msustik commented 5 months ago

@Hardcode84 I am looking at this issue. I am confused about %lhs: it does not appear to be used. Where is this MLIR code coming from?

Second, I reduced the tensor dimensions to find a smaller example that demonstrates the failure. When I rescaled 2048 to 128 (scaling the other dimensions accordingly), compilation finished and a vmfb file was created. With the slightly larger 2048 -> 144 scaling, the error still showed up.
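
The size reduction described above could be automated with a bisection between a passing size (128) and a failing size (144). The following is a minimal sketch, not the procedure actually used. It assumes the failure is monotonic in the first dimension, that the second dimension is half of the first (mirroring 2048x1024), and that a timeout counts as a failure. The helper names (`render`, `compile_fails`, `smallest_failing`) are hypothetical; the iree-compile flags are the ones from the reproduction command in this issue, minus --mlir-print-ir-after-all. The unused %lhs and %c0 values are omitted from the template.

```python
import subprocess
import tempfile
from pathlib import Path

# Reproducer from this issue with the tensor shape parameterized.
TEMPLATE = """\
#trait2 = {{
  indexing_maps = [
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> (0, 0)>
  ],
  iterator_types = ["reduction", "reduction"]
}}

func.func @test() {{
  %rhs1 = util.unfoldable_constant dense<0.01> : tensor<{rows}x{cols}xf16>
  %rhs2 = util.unfoldable_constant dense<0.01> : tensor<{rows}x{cols}xf16>
  %res = util.unfoldable_constant dense<0.01> : tensor<1x1xf16>
  %rhs = linalg.generic #trait2
     ins(%rhs1, %rhs2: tensor<{rows}x{cols}xf16>, tensor<{rows}x{cols}xf16>)
    outs(%res: tensor<1x1xf16>) {{
      ^bb(%a: f16, %b: f16, %x: f16):
        %0 = arith.mulf %a, %b : f16
        %1 = arith.addf %0, %x : f16
        linalg.yield %1 : f16
  }} -> tensor<1x1xf16>
  check.expect_almost_eq_const(%rhs, dense<20.2822> : tensor<1x1xf16>) : tensor<1x1xf16>
  return
}}
"""


def render(rows: int) -> str:
    """Return the reproducer text for a rows x (rows // 2) input (assumed scaling)."""
    return TEMPLATE.format(rows=rows, cols=max(1, rows // 2))


def compile_fails(rows: int, timeout_s: float = 120.0) -> bool:
    """Run iree-compile on the scaled reproducer; a timeout counts as a failure."""
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "test-reduction.mlir"
        src.write_text(render(rows))
        cmd = [
            "iree-compile",
            "--iree-hal-target-backends=vulkan-spirv",
            "--iree-vulkan-target-triple=rdna3-unknown-unknown",
            str(src),
            "-o",
            str(Path(tmp) / "test.vmfb"),
        ]
        try:
            result = subprocess.run(cmd, capture_output=True, timeout=timeout_s)
        except subprocess.TimeoutExpired:
            return True
        return result.returncode != 0


def smallest_failing(lo: int, hi: int, fails) -> int:
    """Bisect for the smallest failing size, given fails(lo) is False and fails(hi) is True."""
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if fails(mid):
            hi = mid
        else:
            lo = mid
    return hi


if __name__ == "__main__":
    # 128 compiled and 144 failed, per the comment above.
    print(smallest_failing(128, 144, compile_fails))
```

Passing the predicate as an argument keeps the search logic testable without iree-compile installed. If the failure turns out not to be monotonic in size, the result is only one failing size next to a passing one, not necessarily the smallest.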

Hardcode84 commented 5 months ago

%lhs is probably the result of a copy-paste from another test. You can remove it; that shouldn't affect the reproducer. Also, the expect_almost_eq_const value is arbitrary; I only added it here to prevent DCE.