iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

failed to legalize operation 'stream.cmd.dispatch' that was explicitly marked illegal #18631

Open pdhirajkumarprasad opened 1 month ago

pdhirajkumarprasad commented 1 month ago

What happened?

For the given IR:

#map = affine_map<() -> ()>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> ()>
#map3 = affine_map<()[s0] -> (s0 floordiv 12)>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?xi64>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi1>, %arg3: tensor<?x12x64x?xf32>, %arg4: tensor<?x?x768xf32>, %arg5: tensor<4xi64>, %arg6: tensor<?x?x12x64xf32>) -> tensor<?x12x?x?xf32> {
    %cst = arith.constant dense<8.000000e+00> : tensor<f32>
    %c1_i64 = arith.constant 1 : i64
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c12 = arith.constant 12 : index
    %c3 = arith.constant 3 : index
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c768_i64 = arith.constant 768 : i64
    %cst_1 = arith.constant dense<-3.40282347E+38> : tensor<f32>
    %cst_2 = arith.constant dense<768> : tensor<i64>
    %extracted_slice = tensor.extract_slice %arg5[0] [1] [1] : tensor<4xi64> to tensor<1xi64>
    %extracted = tensor.extract %extracted_slice[%c0] : tensor<1xi64>
    %0 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim = tensor.dim %arg4, %c0 : tensor<?x?x768xf32>
    %1 = arith.index_cast %dim : index to i64
    %2 = tensor.empty() : tensor<i1>
    %3 = linalg.fill ins(%0 : i1) outs(%2 : tensor<i1>) -> tensor<i1>
    %4 = tensor.empty() : tensor<i64>
    %5 = linalg.fill ins(%1 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %6 = linalg.fill ins(%extracted : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %7 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%3, %5, %6 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_3 = tensor.extract %7[] : tensor<i64>
    %extracted_slice_4 = tensor.extract_slice %arg5[1] [1] [1] : tensor<4xi64> to tensor<1xi64>
    %extracted_5 = tensor.extract %extracted_slice_4[%c0] : tensor<1xi64>
    %8 = arith.cmpi eq, %extracted_5, %c0_i64 : i64
    %dim_6 = tensor.dim %arg4, %c1 : tensor<?x?x768xf32>
    %9 = arith.index_cast %dim_6 : index to i64
    %10 = linalg.fill ins(%8 : i1) outs(%2 : tensor<i1>) -> tensor<i1>
    %11 = linalg.fill ins(%9 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %12 = linalg.fill ins(%extracted_5 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %13 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%10, %11, %12 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_7 = tensor.extract %13[] : tensor<i64>
    %extracted_slice_8 = tensor.extract_slice %arg5[2] [1] [1] : tensor<4xi64> to tensor<1xi64>
    %extracted_9 = tensor.extract %extracted_slice_8[%c0] : tensor<1xi64>
    %14 = arith.cmpi eq, %extracted_9, %c0_i64 : i64
    %15 = linalg.fill ins(%14 : i1) outs(%2 : tensor<i1>) -> tensor<i1>
    %16 = linalg.fill ins(%extracted_9 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %17 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%15, %cst_2, %16 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_10 = tensor.extract %17[] : tensor<i64>
    %extracted_slice_11 = tensor.extract_slice %arg5[3] [1] [1] : tensor<4xi64> to tensor<1xi64>
    %extracted_12 = tensor.extract %extracted_slice_11[%c0] : tensor<1xi64>
    %18 = arith.cmpi slt, %extracted_3, %c0_i64 : i64
    %19 = arith.select %18, %c1_i64, %extracted_3 : i64
    %20 = arith.extui %18 : i1 to i64
    %21 = arith.muli %19, %extracted_7 : i64
    %22 = arith.addi %20, %c1_i64 : i64
    %23 = arith.cmpi slt, %extracted_7, %c0_i64 : i64
    %24 = arith.select %23, %19, %21 : i64
    %25 = arith.select %23, %22, %20 : i64
    %26 = arith.muli %24, %extracted_10 : i64
    %27 = arith.addi %25, %c1_i64 : i64
    %28 = arith.cmpi slt, %extracted_10, %c0_i64 : i64
    %29 = arith.select %28, %24, %26 : i64
    %30 = arith.select %28, %27, %25 : i64
    %31 = arith.muli %29, %extracted_12 : i64
    %32 = arith.addi %30, %c1_i64 : i64
    %33 = arith.cmpi slt, %extracted_12, %c0_i64 : i64
    %34 = arith.select %33, %29, %31 : i64
    %35 = arith.select %33, %32, %30 : i64
    %36 = arith.cmpi sle, %35, %c1_i64 : i64
    cf.assert %36, "must have at most one inferred (negative) dimension"
    %37 = arith.muli %1, %9 : i64
    %38 = arith.muli %37, %c768_i64 : i64
    %39 = arith.divsi %38, %34 : i64
    %40 = arith.select %18, %39, %extracted_3 : i64
    %41 = arith.select %23, %39, %extracted_7 : i64
    %42 = arith.select %28, %39, %extracted_10 : i64
    %43 = arith.select %33, %39, %extracted_12 : i64
    %from_elements = tensor.from_elements %40, %41, %42, %43 : tensor<4xi64>
    %reshape = tensor.reshape %arg4(%from_elements) : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
    %44 = arith.index_cast %40 : i64 to index
    %45 = arith.index_cast %41 : i64 to index
    %46 = tensor.empty(%44, %45) : tensor<?x12x?x64xf32>
    %transposed = linalg.transpose ins(%reshape : tensor<?x?x12x64xf32>) outs(%46 : tensor<?x12x?x64xf32>) permutation = [0, 2, 1, 3] 
    %47 = linalg.generic {indexing_maps = [#map1, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%transposed, %cst : tensor<?x12x?x64xf32>, tensor<f32>) outs(%46 : tensor<?x12x?x64xf32>) {
    ^bb0(%in: f32, %in_26: f32, %out: f32):
      %108 = arith.divf %in, %in_26 : f32
      linalg.yield %108 : f32
    } -> tensor<?x12x?x64xf32>
    %dim_13 = tensor.dim %arg3, %c0 : tensor<?x12x64x?xf32>
    %48 = arith.maxui %44, %dim_13 : index
    %dim_14 = tensor.dim %arg3, %c3 : tensor<?x12x64x?xf32>
    %collapsed = tensor.collapse_shape %47 [[0, 1], [2], [3]] : tensor<?x12x?x64xf32> into tensor<?x?x64xf32>
    %collapsed_15 = tensor.collapse_shape %arg3 [[0, 1], [2], [3]] : tensor<?x12x64x?xf32> into tensor<?x64x?xf32>
    %49 = arith.muli %48, %c12 : index
    %50 = tensor.empty(%49, %45, %dim_14) : tensor<?x?x?xf32>
    %51 = linalg.fill ins(%cst_0 : f32) outs(%50 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    %52 = linalg.batch_matmul ins(%collapsed, %collapsed_15 : tensor<?x?x64xf32>, tensor<?x64x?xf32>) outs(%51 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    %53 = arith.divui %49, %c12 : index
    %expanded = tensor.expand_shape %52 [[0, 1], [2], [3]] output_shape [%53, 12, %45, %dim_14] : tensor<?x?x?xf32> into tensor<?x12x?x?xf32>
    %dim_16 = tensor.dim %arg2, %c0 : tensor<?x?xi1>
    %54 = arith.index_cast %dim_16 : index to i64
    %55 = linalg.fill ins(%54 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %56 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%3, %55, %6 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_17 = tensor.extract %56[] : tensor<i64>
    %dim_18 = tensor.dim %arg2, %c1 : tensor<?x?xi1>
    %57 = arith.index_cast %dim_18 : index to i64
    %58 = linalg.fill ins(%57 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %59 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%10, %58, %12 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_19 = tensor.extract %59[] : tensor<i64>
    %60 = arith.cmpi slt, %extracted_17, %c0_i64 : i64
    %61 = arith.select %60, %c1_i64, %extracted_17 : i64
    %62 = arith.extui %60 : i1 to i64
    %63 = arith.muli %61, %extracted_19 : i64
    %64 = arith.addi %62, %c1_i64 : i64
    %65 = arith.cmpi slt, %extracted_19, %c0_i64 : i64
    %66 = arith.select %65, %61, %63 : i64
    %67 = arith.select %65, %64, %62 : i64
    %68 = arith.muli %66, %extracted_9 : i64
    %69 = arith.addi %67, %c1_i64 : i64
    %70 = arith.cmpi slt, %extracted_9, %c0_i64 : i64
    %71 = arith.select %70, %66, %68 : i64
    %72 = arith.select %70, %69, %67 : i64
    %73 = arith.muli %71, %extracted_12 : i64
    %74 = arith.addi %72, %c1_i64 : i64
    %75 = arith.select %33, %71, %73 : i64
    %76 = arith.select %33, %74, %72 : i64
    %77 = arith.cmpi sle, %76, %c1_i64 : i64
    cf.assert %77, "must have at most one inferred (negative) dimension"
    %78 = arith.muli %54, %57 : i64
    %79 = arith.divsi %78, %75 : i64
    %80 = arith.select %60, %79, %extracted_17 : i64
    %81 = arith.select %65, %79, %extracted_19 : i64
    %82 = arith.select %70, %79, %extracted_9 : i64
    %83 = arith.select %33, %79, %extracted_12 : i64
    %from_elements_20 = tensor.from_elements %80, %81, %82, %83 : tensor<4xi64>
    %reshape_21 = tensor.reshape %arg2(%from_elements_20) : (tensor<?x?xi1>, tensor<4xi64>) -> tensor<?x1x1x?xi1>
    %84 = affine.apply #map3()[%49]
    %85 = arith.index_cast %84 : index to i64
    %86 = arith.index_cast %dim_14 : index to i64
    %87 = tensor.empty() : tensor<1xi64>
    %collapsed_22 = tensor.collapse_shape %87 [] : tensor<1xi64> into tensor<i64>
    %88 = linalg.fill ins(%85 : i64) outs(%collapsed_22 : tensor<i64>) -> tensor<i64>
    %extracted_23 = tensor.extract %88[] : tensor<i64>
    %89 = arith.maxsi %extracted_23, %80 : i64
    %90 = linalg.fill ins(%41 : i64) outs(%collapsed_22 : tensor<i64>) -> tensor<i64>
    %extracted_24 = tensor.extract %90[] : tensor<i64>
    %91 = arith.maxsi %extracted_24, %c1_i64 : i64
    %92 = linalg.fill ins(%86 : i64) outs(%collapsed_22 : tensor<i64>) -> tensor<i64>
    %extracted_25 = tensor.extract %92[] : tensor<i64>
    %93 = arith.maxsi %extracted_25, %83 : i64
    %94 = arith.index_cast %89 : i64 to index
    %95 = arith.cmpi sge, %89, %c0_i64 : i64
    cf.assert %95, "unimplemented: dynamic negative broadcast sizes"
    %96 = arith.cmpi slt, %91, %c0_i64 : i64
    %97 = arith.index_cast %91 : i64 to index
    %98 = arith.select %96, %c1, %97 : index
    %99 = arith.index_cast %93 : i64 to index
    %100 = arith.cmpi sge, %93, %c0_i64 : i64
    cf.assert %100, "unimplemented: dynamic negative broadcast sizes"
    %101 = tensor.empty(%94, %98, %99) : tensor<?x12x?x?xi1>
    %102 = linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%101 : tensor<?x12x?x?xi1>) {
    ^bb0(%out: i1):
      %108 = linalg.index 0 : index
      %109 = linalg.index 3 : index
      %110 = arith.index_cast %80 : i64 to index
      %111 = arith.cmpi eq, %110, %c1 : index
      %112 = arith.select %111, %c0, %108 : index
      %113 = arith.index_cast %83 : i64 to index
      %114 = arith.cmpi eq, %113, %c1 : index
      %115 = arith.select %114, %c0, %109 : index
      %extracted_26 = tensor.extract %reshape_21[%112, %c0, %c0, %115] : tensor<?x1x1x?xi1>
      linalg.yield %extracted_26 : i1
    } -> tensor<?x12x?x?xi1>
    %103 = arith.cmpi eq, %94, %84 : index
    cf.assert %103, "mismatched size for broadcast"
    %104 = arith.cmpi eq, %98, %45 : index
    cf.assert %104, "mismatched size for broadcast"
    %105 = arith.cmpi eq, %99, %dim_14 : index
    cf.assert %105, "mismatched size for broadcast"
    %106 = tensor.empty(%94, %98, %99) : tensor<?x12x?x?xf32>
    %107 = linalg.generic {indexing_maps = [#map1, #map2, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%102, %cst_1, %expanded : tensor<?x12x?x?xi1>, tensor<f32>, tensor<?x12x?x?xf32>) outs(%106 : tensor<?x12x?x?xf32>) {
    ^bb0(%in: i1, %in_26: f32, %in_27: f32, %out: f32):
      %108 = arith.select %in, %in_26, %in_27 : f32
      linalg.yield %108 : f32
    } -> tensor<?x12x?x?xf32>
    return %107 : tensor<?x12x?x?xf32>
  }
}

I am getting the following error:

tt.mlir:200:12: error: failed to legalize operation 'stream.cmd.dispatch' that was explicitly marked illegal
    %107 = linalg.generic {indexing_maps = [#map1, #map2, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%102, %cst_1, %expanded : tensor<?x12x?x?xi1>, tensor<f32>, tensor<?x12x?x?xf32>) outs(%106 : tensor<?x12x?x?xf32>) {

This linalg IR was generated from the following ONNX IR:

module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],i1>, %arg3: !torch.vtensor<[?,12,64,?],f32>, %arg4: !torch.vtensor<[?,?,768],f32>, %arg5: !torch.vtensor<[4],si64> , %arg6: !torch.vtensor<[?,?,12,64],f32>) -> !torch.vtensor<[?,12,?,?],f32>  attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %1 = torch.operator "onnx.Reshape"(%arg4, %arg5) : (!torch.vtensor<[?,?,768],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,12,64],f32> 
    %2 = torch.operator "onnx.Transpose"(%1) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[?,?,12,64],f32>) -> !torch.vtensor<[?,12,?,64],f32> 
    %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__18> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %4 = torch.operator "onnx.Div"(%2, %3) : (!torch.vtensor<[?,12,?,64],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,12,?,64],f32> 
    %5 = torch.operator "onnx.MatMul"(%4, %arg3) : (!torch.vtensor<[?,12,?,64],f32>, !torch.vtensor<[?,12,64,?],f32>) -> !torch.vtensor<[?,12,?,?],f32> 
    %6 = torch.operator "onnx.Reshape"(%arg2, %arg5) : (!torch.vtensor<[?,?],i1>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,1,1,?],i1> 
    %7 = torch.operator "onnx.Shape"(%5) : (!torch.vtensor<[?,12,?,?],f32>) -> !torch.vtensor<[4],si64> 
    %8 = torch.operator "onnx.Expand"(%6, %7) : (!torch.vtensor<[?,1,1,?],i1>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,12,?,?],i1> 
    %9 = torch.operator "onnx.Cast"(%8) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,12,?,?],i1>) -> !torch.vtensor<[?,12,?,?],i1> 
    %10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__22> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %11 = torch.operator "onnx.Where"(%9, %10, %5) : (!torch.vtensor<[?,12,?,?],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[?,12,?,?],f32>) -> !torch.vtensor<[?,12,?,?],f32> 
    return %11 : !torch.vtensor<[?,12,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      __18: "0x0800000000000041",
      __22: "0x08000000FFFF7FFF"
    }
  }
#-}

If I pass the ONNX IR directly, IREE compiles fine, but when I pass the linalg IR (obtained by lowering through torch-mlir), it fails with the above error.

Steps to reproduce your issue

command:

iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false temp.mlir

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

pashu123 commented 1 month ago

The problematic part is

 %102 = linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%101 : tensor<?x12x?x?xi1>) {
    ^bb0(%out: i1):
      %108 = linalg.index 0 : index
      %109 = linalg.index 3 : index
      %110 = arith.index_cast %80 : i64 to index
      %111 = arith.cmpi eq, %110, %c1 : index
      %112 = arith.select %111, %c0, %108 : index
      %113 = arith.index_cast %83 : i64 to index
      %114 = arith.cmpi eq, %113, %c1 : index
      %115 = arith.select %114, %c0, %109 : index
      %extracted_26 = tensor.extract %reshape_21[%112, %c0, %c0, %115] : tensor<?x1x1x?xi1>
      linalg.yield %extracted_26 : i1
    } -> tensor<?x12x?x?xi1> 

This looks like a broadcast to me. We could pass %reshape_21 directly as an input to the generic rather than accessing it from outside the region. Could someone look at the onnx.Expand op lowering?
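
For illustration, here is a minimal sketch of that idea (hypothetical; it reuses the value names from the IR above and assumes the dynamic dims of %reshape_21 at positions 0 and 3 exactly match the output, so the runtime dim == 1 selects from the current lowering are not needed):

// Hypothetical: pass %reshape_21 as an input with a broadcasting indexing map,
// so the access pattern is visible to the op instead of hidden in a tensor.extract.
#bcast = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
%102 = linalg.generic {indexing_maps = [#bcast, #ident], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%reshape_21 : tensor<?x1x1x?xi1>) outs(%101 : tensor<?x12x?x?xi1>) {
^bb0(%in: i1, %out: i1):
  linalg.yield %in : i1
} -> tensor<?x12x?x?xi1>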

pdhirajkumarprasad commented 1 month ago

Complete list of models failing due to this:

dispatch_list.txt

zjgarvey commented 1 month ago

I think I've figured out a problematic component of the Expand lowering. I'm going to make sure this is an appropriate fix and will post updates.

zjgarvey commented 1 month ago

So removing some logic that takes the max between the provided dim size and the input dim for broadcasting seems to unblock this issue; however, that logic is necessary for producing correct results in other cases.

onnx.Expand allows the provided shape to be smaller than the input shape at a given dim (in which case it doesn't broadcast there, and we would produce incorrect results if we didn't take the max).
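
As a small hypothetical example of that case (made-up shapes, written in the same ONNX-in-torch form as above): the requested shape has 1 at dim 0 while the input has 4, and per the broadcast rule the output keeps the larger input dim, which is why the lowering takes the max:

%shape = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[1, 5]> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64>
%out = torch.operator "onnx.Expand"(%input, %shape) : (!torch.vtensor<[4,5],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5],f32>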

Going to keep digging a bit.

zjgarvey commented 1 month ago

I got it to work with some extra shape help in torch-mlir.

Will update soon.

zjgarvey commented 1 month ago

Small reproducers are passing, but full models still fail on a further node.

nirvedhmeshram commented 1 month ago

If this can be fixed in the front-end, that's great, but I think we should also be able to support this in the compiler. In that regard, I think the problem starts in iree-codegen-tile-and-distribute-to-workgroups. I am sanity-checking the input we have at this point; the output we get is not something that can be legalized. Both are provided in this gist.

Here is the command I used

iree-opt tile_and_distribute_repro.mlir \
-pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-tile-and-distribute-to-workgroups, canonicalize)), cse)))' \
&> output.mlir

CC @MaheshRavishankar to take a look at the IR as well.

zjgarvey commented 1 month ago

https://github.com/llvm/torch-mlir/pull/3756 addresses the compile failure by substantially simplifying the IR generated for broadcast, but it slightly reduces coverage for some rare onnx.Expand cases.

If you guys think that the old IR should just be supported in IREE anyway, then it might not be worth landing that PR.

zjgarvey commented 1 month ago

With https://github.com/llvm/torch-mlir/pull/3757, a PR to clean up some of the gross shape computations that aren't simplifying at the torch level, I was able to compile the failing models when returning on the first "Where" node, but, interestingly, they then failed to compile again when returning on the next, nearly identical "Where" node.

With the two PRs I've posted so far, I think I've explored simplifying the broadcast shapes as much as possible. I'm not sure what else we could do from the front-end.

MaheshRavishankar commented 1 month ago

@zjgarvey you are probably already looking at it, but this kind of IR is really strange

 %dim = tensor.dim %arg4, %c0 : tensor<?x?x768xf32>
    %1 = arith.index_cast %dim : index to i64
    %2 = tensor.empty() : tensor<i1>
    %3 = linalg.fill ins(%0 : i1) outs(%2 : tensor<i1>) -> tensor<i1>
    %4 = tensor.empty() : tensor<i64>
    %5 = linalg.fill ins(%1 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %6 = linalg.fill ins(%extracted : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %7 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%3, %5, %6 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_3 = tensor.extract %7[] : tensor<i64>

This is taking a dim of a tensor, creating a new tensor and filling it with the value, then performing a linalg.generic operation just to do a select, and then extracting the result back out.

It's basically shape computation that is artificially expressed as tensor math. IREE tries to do all tensor math on the device, so all of this computation, which is really shape computation that should be done on the host, gets transferred to the device, and then things go haywire because it artificially looks like an indirect-dispatch problem where the shape computation depends on previous computation on the device. The easiest fix is for the front end to not express shape computation as tensor math.
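
As a rough sketch of what that means (hypothetical, reusing the value names from the snippet and the full IR above, where %0 is the i1 comparison and %extracted is the i64 pulled out of the shape tensor): the same select can stay as host-side scalar arithmetic, with no tensors or dispatches involved at all:

// Hypothetical scalar equivalent: the dim, the cast, and the select all stay
// on the host as index/i64 math instead of round-tripping through tensor<i64>.
%dim = tensor.dim %arg4, %c0 : tensor<?x?x768xf32>
%dim_i64 = arith.index_cast %dim : index to i64
%sel = arith.select %0, %dim_i64, %extracted : i64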

nirvedhmeshram commented 1 month ago

We need to generate new linalg IR for this issue with https://github.com/llvm/torch-mlir/pull/3762, since the ONNX IR was already working when run directly through IREE.

zjgarvey commented 1 month ago

This is also resolved with https://github.com/llvm/torch-mlir/pull/3756, so it might no longer reproduce as an issue in IREE.