openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

ReduceOp prettyprint is different after `--strip-debuginfo` #1523

Closed: GleasonK closed this issue 9 months ago

GleasonK commented 1 year ago

Request description

This happens because of: https://github.com/openxla/stablehlo/blob/4cd6f24257a364857add0e3c1dc11b2364669d50/stablehlo/dialect/StablehloOps.cpp#L1489-L1490

Repro:

$ cat /tmp/t.mlir 
func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor<f32>) -> tensor<f32> {
  %0 = "stablehlo.reduce"(%arg0, %arg1) ({
    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
      %1 = "stablehlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "stablehlo.return"(%1) : (tensor<f32>) -> ()
  }) {
    dimensions = dense<0> : tensor<1xi64>
  } : (tensor<16xf32>, tensor<f32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

$ stablehlo-opt /tmp/t.mlir
module {
  func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor<f32>) -> tensor<f32> {
    %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<16xf32>, tensor<f32>) -> tensor<f32>
     reducer(%arg2: tensor<f32>, %arg3: tensor<f32>)  {
      %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
      stablehlo.return %1 : tensor<f32>
    }
    return %0 : tensor<f32>
  }
}

$ stablehlo-opt /tmp/t.mlir --strip-debuginfo
module {
  func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor<f32>) -> tensor<f32> {
    %0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.add across dimensions = [0] : (tensor<16xf32>, tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
}

This is probably the primary reason most ReduceOps aren't prettyprinted. It is unclear whether this behavior is correct or incorrect: it is very uncommon for an op's prettyprint to depend on debug info, but reduce is more complicated, so this will require some investigation.

fzakaria commented 9 months ago

Just to add some clarification:

It wasn't clear from the description which of the two alternatives is preferred.

fzakaria commented 9 months ago

More information after having dug in:

The check for compact pretty printing requires that the Locations of the block, its arguments, and the inner op are all the same. The codebase includes a test that validates this:

func.func @reduce_one_op_all_locs_same(%arg0: tensor<?x?xf32>, %arg1 : tensor<f32>) -> (tensor<?xf32>) {
  %0 = "stablehlo.reduce"(%arg0, %arg1) ({
  ^bb0(%arg2: tensor<f32> loc("foo"), %arg3: tensor<f32> loc("foo")):
    %1 = "stablehlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32> loc("foo")
    "stablehlo.return"(%1) : (tensor<f32>) -> () loc("foo")
  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("foo")

  func.return %0: tensor<?xf32>
}

When no explicit location information is written in the source, MLIR attaches source-level locations (file, line, and column), which are kept as long as debug info is present. For example:

#loc2 = loc("foo")
#loc4 = loc("/tmp/1523.mlir":3:10)
#loc5 = loc("/tmp/1523.mlir":3:30)
module {
  func.func @op_reduce(%arg0: tensor<16xf32> loc("foo"), %arg1: tensor<f32> loc("foo")) -> tensor<f32> {
    %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<16xf32>, tensor<f32>) -> tensor<f32>
     reducer(%arg2: tensor<f32> loc("/tmp/1523.mlir":3:10), %arg3: tensor<f32> loc("/tmp/1523.mlir":3:30))  {
      %1 = stablehlo.add %arg2, %arg3 : tensor<f32> loc(#loc6)
      stablehlo.return %1 : tensor<f32> loc(#loc7)
    } loc(#loc3)
    return %0 : tensor<f32> loc(#loc8)
  } loc(#loc1)
} loc(#loc)
#loc = loc("/tmp/1523.mlir":0:0)
#loc1 = loc("/tmp/1523.mlir":1:1)
#loc3 = loc("/tmp/1523.mlir":2:8)
#loc6 = loc("/tmp/1523.mlir":4:12)
#loc7 = loc("/tmp/1523.mlir":5:7)
#loc8 = loc("/tmp/1523.mlir":9:3)

In this case the locations can never all be equal, so the compact pretty printing never occurs.

The code explicitly checks for this, but it is not clear to me why the constraint was added.
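
One way to see why the check passes after stripping (a sketch, assuming the standard MLIR `--mlir-print-debuginfo` flag is registered in `stablehlo-opt`): after `--strip-debuginfo` every location becomes `loc(unknown)`, so the "all locations equal" condition holds trivially.

$ stablehlo-opt /tmp/t.mlir --strip-debuginfo --mlir-print-debuginfo
$ stablehlo-opt /tmp/t.mlir --mlir-print-debuginfo

The first command should show `loc(unknown)` everywhere (all locations compare equal, so the compact form is printed), while the second shows the distinct file:line:column locations listed above, which can never all be equal.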

GleasonK commented 9 months ago

There is a tradeoff: we will elide some location info from the compact prettyprint when there is a single-op reducer. This can show up in cases like the following:

import jax

# Doesn't use compact pretty print
def reduce_not_inline(x):
  reducer_fn = lambda x, y: jax.lax.add(x, y)
  return jax.lax.reduce(x, 0.0, reducer_fn, [1, 2])

# Does use compact pretty print
def reduce_inline(x):
  return jax.lax.reduce(x, 0.0, jax.lax.add, [1, 2])

In PyTorch/XLA I don't think there is currently a way to export reduce ops so that they are pretty printed; these reduce ops in PT/XLA are almost always single-op reductions as well.

My opinion is that this debuginfo is not very useful, and there is no guarantee that a prettyprint preserves all debuginfo; for example, the upstream SCF ForOp doesn't check that the region arguments' location info matches the op's location info before eliding. In practice, use cases that care about debuginfo use bytecode or, occasionally, the generic print. Given that this only impacts single-op reductions, it is unlikely to affect JAX users who have custom reduction functions, and it would benefit PT/XLA users in terms of readability.

sdasgup3 commented 9 months ago

The same-location constraint was added to make sure that the compact version of the reduce op is perfectly round-trippable (compact reduce op -> generic reduce op -> compact reduce op). Since `%retval = mhlo.reduce(%a, %b) applies mhlo.maximum across dimensions = [0] : (tensor<nxi32>, tensor<i32>) -> tensor<i32>` carries only a single location, we concluded that the compact version of the reduce op would be round-trippable only if all the location information is the same.
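
To make that concrete, here is an illustrative snippet (not taken from the codebase): the compact form has exactly one place to attach a location, whereas the generic form carries separate locations for the op, each block argument, the inner op, and the return, so distinct locations cannot survive a compact -> generic -> compact round trip.

// Illustrative only: the single trailing loc(...) is the only location the compact form can record.
%0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.add across dimensions = [0] : (tensor<16xf32>, tensor<f32>) -> tensor<f32> loc("foo")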

fzakaria commented 9 months ago

@GleasonK looking at your example and the output: JAX is getting the compact version strictly because the location info being emitted is coarser; it does not include column information.

#loc82 = loc("jit(my_reduce_add)/jit(main)/reduce[computation=<function my_reduce_add.<locals>.<lambda> at 0x7f197edba840> dimensions=(1, 2)]"(#loc81))
module @jit_my_reduce_add attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x5x5x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"} loc("x")) -> (tensor<2x3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
    %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [1, 2] : (tensor<2x5x5x3xf32>, tensor<f32>) -> tensor<2x3xf32>
     reducer(%arg1: tensor<f32> loc("jit(my_reduce_add)/jit(main)/reduce[computation=<function my_reduce_add.<locals>.<lambda> at 0x7f197edba840> dimensions=(1, 2)]"(#loc81)), %arg2: tensor<f32> loc("jit(my_reduce_add)/jit(main)/reduce[computation=<function my_reduce_add.<locals>.<lambda> at 0x7f197edba840> dimensions=(1, 2)]"(#loc81)))  {
      %2 = stablehlo.add %arg1, %arg2 : tensor<f32> loc(#loc84)
      stablehlo.return %2 : tensor<f32> loc(#loc82)
    } loc(#loc82)
    return %1 : tensor<2x3xf32> loc(#loc)
  } loc(#loc)
} loc(#loc)

The relevant line is loc("jit(my_reduce_add)/jit(main)/reduce[computation=<function my_reduce_add.<locals>.<lambda> at 0x7f197edba840> dimensions=(1, 2)]"(#loc81)) (interestingly, the alias isn't substituted here). The output does not include column information, which is why the locations are all equal.

@sdasgup3 the isomorphism requirement can only ever be satisfied if your original Location information is coarse. This seems like "working as intended", but it is an odd transformation in any case; I'm not sure you can guarantee isomorphism with debug Location information if you collect it with column values.

My recommendation is to close this as "working as intended", but it would be great to better understand whether transformations are guaranteed to preserve Location information forwards and backwards.

fzakaria commented 9 months ago

Closing this issue. Please see #1906 for the rationale.