iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

StableHLOToLinalgConvolution hits an assert (StableHLO -> Linalg failure) #17714

Open banach-space opened 1 month ago

banach-space commented 1 month ago

IREE SHA: d792d2483a9d5d875a0f03b57c9d4f98e072c300

Reproducer

Input extracted from ResNet. I imported it to StableHLO following https://openxla.org/stablehlo/tutorials/jax-export:

module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x3x224x224xf32> {mhlo.layout_mode = "default"}, %filter: tensor<7x7x3x64xf32>) -> tensor<1x112x112x64xf32> {
    %0 = stablehlo.transpose %arg0, dims = [0, 2, 3, 1] : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32>
    %1 = stablehlo.convolution(%0, %filter) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [2, 2], pad = [[3, 3], [3, 3]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x224x224x3xf32>, tensor<7x7x3x64xf32>) -> tensor<1x112x112x64xf32>
    return %1 : tensor<1x112x112x64xf32>
  }
}
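
As a quick sanity check on the reproducer, the declared result shape follows from the standard convolution output-size formula, so the input itself looks well-formed. A minimal sketch, assuming no input (lhs) dilation; `convOutputSize` is a hypothetical helper, not an IREE or StableHLO API:

```cpp
#include <cstdint>

// Hypothetical helper (not part of IREE or StableHLO): output extent of one
// spatial dimension of a convolution, assuming no input (lhs) dilation.
constexpr int64_t convOutputSize(int64_t input, int64_t kernel, int64_t stride,
                                 int64_t padLow, int64_t padHigh,
                                 int64_t dilation = 1) {
  // A dilated kernel covers (kernel - 1) * dilation + 1 input elements.
  const int64_t effectiveKernel = (kernel - 1) * dilation + 1;
  return (input + padLow + padHigh - effectiveKernel) / stride + 1;
}

// Reproducer: 224x224 input, 7x7 filter, stride 2, padding 3 on each side.
static_assert(convOutputSize(224, 7, 2, 3, 3) == 112,
              "matches tensor<1x112x112x64xf32>");
```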

Run:

$ iree-compile conv_repro_stablehlo.mlir

Analysis

Based on the stack trace, this example hits an assert when running mlir::linalg::Conv2DNhwcHwcfOp::build. Assert message:

OperationSupport.h:1044: void mlir::OperationState::addAttribute(StringAttr, Attribute): Assertion `attr && "attribute cannot be null"' failed.

I'm not familiar with StableHLO, so I'm not sure whether:

Reporting in case this is a genuine IREE issue worth addressing.

My "daft" Workaround

--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgConvolution.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgConvolution.cpp
@@ -241,6 +241,10 @@ struct NormalConvolutionOpConversion final
       break;
     }
     case 4: {
+      if (!dilations)
+        dilations = rewriter.getI64TensorAttr({1, 1});
+      if (!strides)
+        strides = rewriter.getI64TensorAttr({1, 1});
       res = rewriter.create<linalg::Conv2DNhwcHwcfOp>(
           loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
           strides, dilations, linalg::getPrunedAttributeList(op));
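
The workaround hard-codes `{1, 1}` for the 2-D case. The reproducer sets a stride but no dilation, so presumably it is the dilation attribute that is null here. The same idea (treat an absent stride or dilation as all ones, one entry per spatial dimension) can be modelled outside MLIR. A minimal standalone sketch, with `std::optional` standing in for a possibly-null attribute; `WindowAttr` and `onesIfAbsent` are hypothetical names and none of this is the actual IREE code:

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a possibly-null strides/dilations attribute.
using WindowAttr = std::optional<std::vector<int64_t>>;

// Returns the attribute's values, or all ones (one per spatial dimension)
// when the attribute is absent.
std::vector<int64_t> onesIfAbsent(const WindowAttr &attr,
                                  std::size_t numSpatialDims) {
  if (attr)
    return *attr;
  return std::vector<int64_t>(numSpatialDims, 1);
}
```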

jpienaar commented 1 month ago

If it matches https://github.com/openxla/xla/blob/f292b892846960fff42e0e9c11f9a96f22723bd7/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc#L2924 then it may be an export error (it could still be a failure in the lowering to Linalg; I'm also surprised the named Linalg op doesn't have proper defaults for these).

banach-space commented 1 month ago

The logic in IREE is almost identical:

surprised the named Linalg op doesn't have proper defaults for these

indeed