tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

[StableHLO] gather op failure for unsupported data types #1336

Status: Closed (mmanzoorTT closed this issue 1 day ago)

mmanzoorTT commented 2 days ago

StableHLO->TTIR conversion of the gather op fails for data types the hardware does not support (e.g., boolean, i64, f64). stablehlo.gather is lowered to ttnn.embedding, which only supports the bfloat16 data type. Proposed fix: add a verifier for the input data type and update the test cases.
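As an illustration of what such a verifier could check, here is a minimal Python sketch. It is hypothetical and not tt-mlir code: element types are modeled as MLIR-style strings, the function name `verify_gather_types` is made up, and the rejected set is only the examples the issue lists (boolean as `i1`, `i64`, `f64`), so the real list may be longer. It checks both gather operands (`operand` and `start_indices`, as named in the StableHLO spec), since the sample graph below pairs a bf16 table with i64 indices.

```python
# Hypothetical sketch, not tt-mlir code. Element types are MLIR-style strings.
# Assumption: the rejected set is only the examples named in the issue
# (boolean -> i1, i64, f64); the issue says "e.g.", so it may be incomplete.
UNSUPPORTED_HW_TYPES = frozenset({"i1", "i64", "f64"})


def verify_gather_types(operand_type, indices_type, unsupported=UNSUPPORTED_HW_TYPES):
    """Return a list of diagnostics; an empty list means the op passes."""
    errors = []
    for name, ty in (("operand", operand_type), ("start_indices", indices_type)):
        if ty in unsupported:
            # Report each offending operand separately so the user sees all of them.
            errors.append(
                f"stablehlo.gather: {name} element type '{ty}' is not supported on hardware"
            )
    return errors


if __name__ == "__main__":
    # Element types from the sample graph: bf16 table, i64 indices.
    for msg in verify_gather_types("bf16", "i64"):
        print(msg)
```

Rejecting the op up front with a readable diagnostic is the point of the proposed verifier: the user learns which operand's type is the problem instead of getting a generic "failed to legalize" error.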

Sample graph with an unsupported data type:

module {
  func.func @main(%arg0: tensor<250002x768xbf16>, %arg1: tensor<1x10xi64>) -> tensor<1x10x768xbf16> {
    %0 = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 768>}> : (tensor<250002x768xbf16>, tensor<1x10xi64>) -> tensor<1x10x768xbf16>
    %1 = stablehlo.convert %0 : tensor<1x10x768xbf16>
    return %1 : tensor<1x10x768xbf16>
  }
}
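To see why this particular gather maps onto an embedding lookup, here is a pure-Python reference sketch of its semantics. The function name `gather_rows` is hypothetical, and the sketch covers only this configuration (offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2, slice_sizes = [1, D]), not general gather. Each scalar index selects one row of the table. The size-1 slice dimension is collapsed, and the row's D values become the last output dimension. The result is out[i][j] = table[indices[i][j]], with shape 1x10x768 in the sample.

```python
def gather_rows(table, indices):
    """Reference semantics for the sample's stablehlo.gather configuration.

    table:   list of N rows, each a list of D values (operand, shape N x D)
    indices: nested list of shape B0 x B1 (start_indices; index_vector_dim = 2
             equals its rank, so each scalar is an implicit 1-element index)
    returns: nested list of shape B0 x B1 x D
    """
    num_rows = len(table)
    out = []
    for row_of_ids in indices:  # batch dim 0 of indices
        out_row = []
        for idx in row_of_ids:  # batch dim 1 of indices
            # Gather clamps start indices so the 1-row slice stays in bounds.
            start = min(max(idx, 0), num_rows - 1)
            # Slice [start:start+1, 0:D], then drop the collapsed dim 0.
            out_row.append(list(table[start]))
        out.append(out_row)
    return out


table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
print(gather_rows(table, [[2, 0, 1]]))
# → [[[2.0, 2.1], [0.0, 0.1], [1.0, 1.1]]]
```

The clamping step follows StableHLO gather semantics. An embedding kernel does not clamp, so the two agree only when every index is in range, which is the normal case for token IDs.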

Error message:

results/mlir_tests/stable_hlo/aten::embedding_0.mlir:3:10: error: failed to legalize operation 'stablehlo.gather'
    %0 = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 768>}> : (tensor<250002x768xbf16>, tensor<1x10xi64>) -> tensor<1x10x768xbf16>
         ^
results/mlir_tests/stable_hlo/aten::embedding_0.mlir:3:10: note: see current operation: %2 = "stablehlo.gather"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 768>}> : (tensor<250002x768xbf16>, tensor<1x10xi64>) -> tensor<1x10x768xbf16>