Open mmanzoorTT opened 2 days ago
`tt-metal` only supports the embedding op for the `bfloat16` data type. We have a use case in our `tt-torch` models where the input operand is `float32`, which causes a failure. The StableHLO graph is below:
```mlir
module {
  func.func @main(%arg0: tensor<2048x32xf32>, %arg1: tensor<1x5xi64>) -> tensor<1x5x32xf32> {
    %0 = stablehlo.reshape %arg1 : (tensor<1x5xi64>) -> tensor<1x5x1xi64>
    %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 32>}> : (tensor<2048x32xf32>, tensor<1x5x1xi64>) -> tensor<1x5x32xf32>
    return %1 : tensor<1x5x32xf32>
  }
}
```
We may add a typecast to convert the input operand to `bfloat16`.
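A minimal sketch of what the converted graph might look like, assuming the fix is implemented by wrapping the gather in `stablehlo.convert` ops (the exact lowering pass and op placement are not decided here): the `f32` embedding table is cast to `bf16` before the gather so it satisfies tt-metal's constraint, and the result is cast back to `f32` to preserve the function signature.

```mlir
module {
  func.func @main(%arg0: tensor<2048x32xf32>, %arg1: tensor<1x5xi64>) -> tensor<1x5x32xf32> {
    // Hypothetical typecast: convert the f32 table to bf16 before the gather.
    %0 = stablehlo.convert %arg0 : (tensor<2048x32xf32>) -> tensor<2048x32xbf16>
    %1 = stablehlo.reshape %arg1 : (tensor<1x5xi64>) -> tensor<1x5x1xi64>
    %2 = "stablehlo.gather"(%0, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 32>}> : (tensor<2048x32xbf16>, tensor<1x5x1xi64>) -> tensor<1x5x32xbf16>
    // Convert the gathered rows back to f32 so the module's signature is unchanged.
    %3 = stablehlo.convert %2 : (tensor<1x5x32xbf16>) -> tensor<1x5x32xf32>
    return %3 : tensor<1x5x32xf32>
  }
}
```

Note the round-trip through `bf16` loses mantissa precision relative to running the gather in `f32`, which may matter for accuracy-sensitive models.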