Closed mmanzoorTT closed 1 day ago
StableHLO->TTIR conversion fails for the gather op when it uses data types the hardware does not support (e.g. boolean, i64, f64). stablehlo.gather is lowered to ttnn.embedding, which only supports the bfloat16 data type. Add a verifier for the input data type and update the test cases.
Sample graph with an unsupported data type:
module {
  func.func @main(%arg0: tensor<250002x768xbf16>, %arg1: tensor<1x10xi64>) -> tensor<1x10x768xbf16> {
    %0 = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 768>}> : (tensor<250002x768xbf16>, tensor<1x10xi64>) -> tensor<1x10x768xbf16>
    %1 = stablehlo.convert %0 : tensor<1x10x768xbf16>
    return %1 : tensor<1x10x768xbf16>
  }
}
Error message:
error: failed to legalize operation 'stablehlo.gather'
results/mlir_tests/stable_hlo/aten::embedding_0.mlir:3:10: error: failed to legalize operation 'stablehlo.gather'
    %0 = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 768>}> : (tensor<250002x768xbf16>, tensor<1x10xi64>) -> tensor<1x10x768xbf16>
         ^
results/mlir_tests/stable_hlo/aten::embedding_0.mlir:3:10: note: see current operation: %2 = "stablehlo.gather"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 768>}> : (tensor<250002x768xbf16>, tensor<1x10xi64>) -> tensor<1x10x768xbf16>
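The kind of input data type check requested here can be sketched in a few lines. This is only an illustration in Python, not the actual verifier: the names `SUPPORTED_INPUT_TYPES`, `element_type`, and `verify_input_type` are hypothetical, and the supported set is taken from the statement above that ttnn.embedding only supports bfloat16.

```python
import re
from typing import Optional

# Assumption (from the issue text): only bf16 is accepted as the input
# data type. This set is hypothetical, not read from the compiler.
SUPPORTED_INPUT_TYPES = {"bf16"}

# Matches a ranked tensor type such as "tensor<250002x768xbf16>" and
# captures the trailing element type.
_TENSOR_RE = re.compile(r"tensor<(?:(?:\d+|\?)x)*([A-Za-z][A-Za-z0-9]*)>")


def element_type(tensor_type: str) -> str:
    """Return the element type of a textual tensor type, e.g. 'bf16'."""
    match = _TENSOR_RE.fullmatch(tensor_type.strip())
    if match is None:
        raise ValueError(f"not a ranked tensor type: {tensor_type!r}")
    return match.group(1)


def verify_input_type(tensor_type: str) -> Optional[str]:
    """Return a diagnostic string if the element type is unsupported, else None."""
    elem = element_type(tensor_type)
    if elem not in SUPPORTED_INPUT_TYPES:
        supported = ", ".join(sorted(SUPPORTED_INPUT_TYPES))
        return f"unsupported input data type '{elem}'; expected one of: {supported}"
    return None


if __name__ == "__main__":
    for ty in ("tensor<250002x768xbf16>", "tensor<250002x768xf64>"):
        print(ty, "->", verify_input_type(ty) or "ok")
```

Returning a diagnostic instead of raising mirrors how a verifier would report a readable error at the offending op rather than a generic "failed to legalize" message.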