tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

[StableHLO] Reshape op failure for unsupported data types #1317

Closed: mmanzoorTT closed this issue 3 days ago

mmanzoorTT commented 4 days ago

The StableHLO-to-TTIR conversion fails for the reshape op when it has to convert unsupported data types (e.g., boolean, i64, f64). Some example StableHLO graphs:

func.func @main(%arg0: tensor<1x1xi64>) -> tensor<1xi64> {
  %0 = stablehlo.reshape %arg0 : (tensor<1x1xi64>) -> tensor<1xi64>
  return %0 : tensor<1xi64>
}

func.func @main(%arg0: tensor<1x45xi64>) -> tensor<45xi64> {
  %0 = stablehlo.reshape %arg0 : (tensor<1x45xi64>) -> tensor<45xi64>
  return %0 : tensor<45xi64>
}

Error message

error: failed to legalize unresolved materialization from ('tensor<1x1xi32>') to 'tensor<1x1xi64>' that remained live after conversion
  func.func @main(%arg0: tensor<1x1xi64>) -> tensor<1xi64> {
                  ^
tt-mlir/test/ttmlir/Conversion/StableHLOToTTIR/test.mlir:4:19: note: see current operation: %0 = "builtin.unrealized_conversion_cast"(%arg0) : (tensor<1x1xi32>) -> tensor<1x1xi64>
tt-mlir/test/ttmlir/Conversion/StableHLOToTTIR/test.mlir:5:10: note: see existing live user here: %2 = "ttir.reshape"(%0, %1) <{operand_constraints = [#tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>, #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>], shape = [1 : i32]}> : (tensor<1x1xi64>, tensor<1xi32>) -> tensor<1xi32>
    %0 = stablehlo.reshape %arg0 : (tensor<1x1xi64>) -> tensor<1xi64>
         ^
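One way to see what goes wrong is to compare the tensor types involved. A StableHLO reshape must keep both the element type and the total element count, yet the `ttir.reshape` in the log takes `tensor<1x1xi64>` and returns `tensor<1xi32>`. The Python sketch below is a hypothetical helper (not part of tt-mlir) that parses static MLIR tensor type strings and reports both kinds of mismatch. It treats `i1` (boolean), `i64` and `f64` as unsupported, following the list in this issue. It only inspects type strings and does not run the conversion.

```python
import re
from math import prod

# Hypothetical set, taken from the issue's list of unsupported types
# (boolean is spelled i1 in MLIR).
UNSUPPORTED = {"i1", "i64", "f64"}

# Matches static ranked tensor types such as tensor<1x45xi64> or tensor<i64>.
_TENSOR_RE = re.compile(r"tensor<((?:\d+x)*)(\w+)>")


def parse_tensor_type(text):
    """Split e.g. 'tensor<1x45xi64>' into ((1, 45), 'i64')."""
    m = _TENSOR_RE.fullmatch(text.strip())
    if m is None:
        raise ValueError(f"not a static ranked tensor type: {text!r}")
    shape = tuple(int(d) for d in m.group(1).split("x") if d)
    return shape, m.group(2)


def check_reshape(operand_type, result_type):
    """Return a list of problems for a reshape from operand_type to result_type."""
    in_shape, in_elem = parse_tensor_type(operand_type)
    out_shape, out_elem = parse_tensor_type(result_type)
    problems = []
    # A reshape may change the shape but never the number of elements.
    if prod(in_shape) != prod(out_shape):
        problems.append(f"element count changes: {prod(in_shape)} -> {prod(out_shape)}")
    # A reshape must not change the element type either.
    if in_elem != out_elem:
        problems.append(f"element type changes: {in_elem} -> {out_elem}")
    # Flag element types the conversion is reported not to handle.
    for elem in sorted({in_elem, out_elem} & UNSUPPORTED):
        problems.append(f"unsupported element type: {elem}")
    return problems


if __name__ == "__main__":
    # The two stablehlo.reshape ops from the issue: valid reshapes, but i64.
    print(check_reshape("tensor<1x1xi64>", "tensor<1xi64>"))
    # -> ['unsupported element type: i64']
    print(check_reshape("tensor<1x45xi64>", "tensor<45xi64>"))
    # -> ['unsupported element type: i64']
    # The ttir.reshape from the error log: operand still i64, result i32.
    print(check_reshape("tensor<1x1xi64>", "tensor<1xi32>"))
    # -> ['element type changes: i64 -> i32', 'unsupported element type: i64']
```

For the issue's two examples the helper flags only the i64 element type. For the logged `ttir.reshape` it also reports the i64 -> i32 change, which lines up with the `unrealized_conversion_cast` from i32 back to i64 that the error reports as still live.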