tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

[StableHLO] gather op with float32 operand #1349

Open mmanzoorTT opened 2 days ago

mmanzoorTT commented 2 days ago

tt-metal supports embedding only for the bfloat16 data type. We have a use case in our tt-torch models where the input operand is float32, which causes a failure. The StableHLO graph is below:

module {
  func.func @main(%arg0: tensor<2048x32xf32>, %arg1: tensor<1x5xi64>) -> tensor<1x5x32xf32> {
    %0 = stablehlo.reshape %arg1 : (tensor<1x5xi64>) -> tensor<1x5x1xi64>
    %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 32>}> : (tensor<2048x32xf32>, tensor<1x5x1xi64>) -> tensor<1x5x32xf32>
    return %1 : tensor<1x5x32xf32>
  }
}

One option is to add a typecast that converts the input operand to bfloat16.
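To make the typecast idea concrete, here is a rough sketch of what the graph above could look like with the casts written out at the StableHLO level. This is only an illustration under assumptions: it uses `stablehlo.convert` for the casts, and it adds a second convert back to float32 so the function signature stays unchanged. The actual fix may well belong in a later lowering stage of the compiler rather than in the StableHLO graph itself.

```mlir
module {
  func.func @main(%arg0: tensor<2048x32xf32>, %arg1: tensor<1x5xi64>) -> tensor<1x5x32xf32> {
    %0 = stablehlo.reshape %arg1 : (tensor<1x5xi64>) -> tensor<1x5x1xi64>
    // Assumed step: cast the float32 operand down to bfloat16 before the gather.
    %1 = stablehlo.convert %arg0 : (tensor<2048x32xf32>) -> tensor<2048x32xbf16>
    // Same gather as in the original graph, now operating on bfloat16 data.
    %2 = "stablehlo.gather"(%1, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 32>}> : (tensor<2048x32xbf16>, tensor<1x5x1xi64>) -> tensor<1x5x32xbf16>
    // Assumed step: cast the result back so the return type is still float32.
    %3 = stablehlo.convert %2 : (tensor<1x5x32xbf16>) -> tensor<1x5x32xf32>
    return %3 : tensor<1x5x32xf32>
  }
}
```

Note that the float32 to bfloat16 cast is lossy: bfloat16 keeps the float32 exponent range but stores only 7 explicit mantissa bits, so the gathered values will generally not match the original float32 values bit for bit.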