openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

WhileOp prettyprinting is incorrect for some dynamic programs #871

Open · burmako opened 1 year ago

burmako commented 1 year ago
func.func public @main(%arg0: tensor<5xf32>) -> (tensor<?xf32>, tensor<i64>) {
  %0 = stablehlo.constant dense<0> : tensor<i64>
  %1:2 = "stablehlo.while"(%arg0, %0) ({
  ^bb0(%arg2: tensor<?xf32>, %arg3: tensor<i64>):
    %2 = stablehlo.constant dense<5> : tensor<i64>
    %3 = stablehlo.compare LT, %arg3, %2, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %3 : tensor<i1>
  }, {
  ^bb0(%arg2: tensor<?xf32>, %arg3: tensor<i64>):
    %2 = stablehlo.constant dense<1> : tensor<i64>
    %3 = stablehlo.add %arg3, %2 : tensor<i64>
    stablehlo.return %arg2, %3 : tensor<?xf32>, tensor<i64>
  }) : (tensor<5xf32>, tensor<i64>) -> (tensor<?xf32>, tensor<i64>)
  return %1#0, %1#1 : tensor<?xf32>, tensor<i64>
}

is incorrectly prettyprinted as:

func.func public @main(%arg0: tensor<5xf32>) -> (tensor<?xf32>, tensor<i64>) {
  %0 = stablehlo.constant dense<0> : tensor<i64>
  %1:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<5xf32>, tensor<i64>
   cond {
    %2 = stablehlo.constant dense<5> : tensor<i64>
    %3 = stablehlo.compare  LT, %iterArg_0, %2,  SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %3 : tensor<i1>
  } do {
    %2 = stablehlo.constant dense<1> : tensor<i64>
    %3 = stablehlo.add %iterArg_0, %2 : tensor<i64>
    stablehlo.return %iterArg, %3 : tensor<?xf32>, tensor<i64>
  }
  return %1#0, %1#1 : tensor<?xf32>, tensor<i64>
}

Note that the prettyprinted version loses the fact that the WhileOp operand type (tensor<5xf32>) differs from the argument type of the cond and body regions (tensor<?xf32>). As a result, the prettyprinted version fails to parse:

~/example.mlir:27:17: error: use of value '%iterArg' expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'
    mhlo.return %iterArg, %3 : tensor<?xf32>, tensor<i64>
                ^
~/example.mlir:19:21: note: prior use here
  %1:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<5xf32>, tensor<i64>
                    ^
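To make the failure mode concrete, here is a toy Python model of the round trip. It is a sketch under a stated assumption, not StableHLO code. The assumption, consistent with the error above, is that the compact syntax prints only the operand types and that the parser gives each `%iterArg` the type of its operand. All names (`WhileOpModel`, `print_pretty`, `parse_pretty`, `round_trips`) are hypothetical.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class WhileOpModel:
    # Hypothetical model: only the types that matter for this bug.
    operand_types: Tuple[str, ...]     # types after ":" in the compact form
    region_arg_types: Tuple[str, ...]  # block argument types of cond/body


def print_pretty(op: WhileOpModel) -> Tuple[str, ...]:
    # Assumption: the compact form emits only the operand types;
    # the region argument types are not printed at all.
    return op.operand_types


def parse_pretty(printed_types: Tuple[str, ...]) -> WhileOpModel:
    # Assumption: the parser types each %iterArg like its operand,
    # so region argument types are reconstructed from operand types.
    return WhileOpModel(printed_types, printed_types)


def round_trips(op: WhileOpModel) -> bool:
    # The compact form is lossless only if parsing recovers the same op.
    return parse_pretty(print_pretty(op)) == op


# The program from this issue: static operand, dynamic region argument.
dynamic = WhileOpModel(("tensor<5xf32>", "tensor<i64>"),
                       ("tensor<?xf32>", "tensor<i64>"))
# A program whose operand and region argument types agree.
static = WhileOpModel(("tensor<5xf32>", "tensor<i64>"),
                      ("tensor<5xf32>", "tensor<i64>"))
```

In this model, `round_trips(static)` holds while `round_trips(dynamic)` does not. That matches the observed behavior: once operand and region argument types diverge, the compact syntax has no place to record the region argument types.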
burmako commented 1 year ago

Unassigning in favor of higher-priority work. This issue made developing --stablehlo-shape-refinement somewhat inconvenient, but it no longer represents an acute problem.