tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

TT-xla not supporting scalar values #1306

Open ajakovljevicTT opened 1 week ago

ajakovljevicTT commented 1 week ago

Because TTIR does not support scalar values and internally translates them to one-dimensional arrays, some of the stablehlo.reshape ops that the tt-xla frontend generates are unnecessary. These ops then fail because their arguments are scalars, so we cannot directly test any function with a scalar return in tt-xla. To fix this, we need to change how we lower stablehlo.reshape (and possibly omit it when it is not needed).
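To make the "omit it if it is not needed" idea concrete, here is a rough Python sketch of the check a lowering could perform. It is not tt-mlir code: `promote_scalar` and `reshape_is_redundant` are hypothetical names, and the sketch assumes the only promotion TTIR applies is rank-0 to shape (1,).

```python
from typing import Tuple

Shape = Tuple[int, ...]


def promote_scalar(shape: Shape) -> Shape:
    """Model the assumed TTIR convention: a rank-0 (scalar) shape becomes (1,)."""
    return (1,) if len(shape) == 0 else shape


def reshape_is_redundant(in_shape: Shape, out_shape: Shape) -> bool:
    """Return True when the reshape is an identity once scalars are promoted.

    Hypothetical helper: a lowering could skip emitting the reshape in this
    case and forward the operand instead.
    """
    return promote_scalar(in_shape) == promote_scalar(out_shape)
```

With this check, a reshape from tensor<i32> to tensor<1xi32> (shapes `()` and `(1,)`) would count as redundant, while a reshape from `(2, 3)` to `(6,)` would still be lowered as usual.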

Here is an example of StableHLO code containing such a reshape, as produced by tt-xla:

module @jit_ravel attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<1xi32> {jax.result_info = ""}) {
    %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
    return %0 : tensor<1xi32>
  }
}

Compiling it currently fails with the following error:

test_constant.mlir:2:26: error: failed to legalize unresolved materialization from ('tensor<1xi32>') to 'tensor<i32>' that remained live after conversion
  func.func public @main(%arg0: tensor<i32>) -> (tensor<1xi32> {jax.result_info = ""}) {