Open mmanzoorTT opened 2 days ago
The following StableHLO graphs contain stablehlo.gather ops (coming from PyTorch models through tt-torch) that tt-mlir cannot handle and explicitly marks as illegal.
Example 1:
func.func @test1(%arg0: tensor<1x7x2xbf16>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor<1x2xbf16> {
  %0 = stablehlo.reshape %arg1 : (tensor<1xi64>) -> tensor<1x1xi64>
  %1 = stablehlo.reshape %arg2 : (tensor<1xi64>) -> tensor<1x1xi64>
  %2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x2xi64>
  %3 = "stablehlo.gather"(%arg0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 2>}> : (tensor<1x7x2xbf16>, tensor<1x2xi64>) -> tensor<1x2xbf16>
  return %3 : tensor<1x2xbf16>
}
Example 2:
func.func @test2(%arg0: tensor<2x7x512xbf16>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>) -> tensor<2x512xbf16> {
  %0 = stablehlo.reshape %arg1 : (tensor<2xi64>) -> tensor<2x1xi64>
  %1 = stablehlo.reshape %arg2 : (tensor<2xi64>) -> tensor<2x1xi64>
  %2 = "stablehlo.concatenate" (%0, %1) {dimension = 1 : i64 } : (tensor<2x1xi64>, tensor<2x1xi64>) -> tensor<2x2xi64>
  %3 = "stablehlo.gather"(%arg0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 512>}> : (tensor<2x7x512xbf16>, tensor<2x2xi64>) -> tensor<2x512xbf16>
  return %3 : tensor<2x512xbf16>
}
Example 3:
func.func @test3(%arg0: tensor<732x12xbf16>, %arg1: tensor<38809xi64>) -> tensor<38809x12xbf16> {
  %0 = stablehlo.reshape %arg1 : (tensor<38809xi64>) -> tensor<38809x1xi64>
  %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 12>}> : (tensor<732x12xbf16>, tensor<38809x1xi64>) -> tensor<38809x12xbf16>
  return %1 : tensor<38809x12xbf16>
}
Example 4:
func.func @test4(%arg0: tensor<732x16xbf16>, %arg1: tensor<38809xi64>) -> tensor<38809x16xbf16> {
  %0 = stablehlo.reshape %arg1 : (tensor<38809xi64>) -> tensor<38809x1xi64>
  %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 16>}> : (tensor<732x16xbf16>, tensor<38809x1xi64>) -> tensor<38809x16xbf16>
  return %1 : tensor<38809x16xbf16>
}
These might be cases where we convert gather to reshape + slice + concatenate.
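To make the reshape + slice + concatenate idea concrete, here is a minimal pure-Python sketch of what the gather in Examples 1 and 2 computes, next to an equivalent built from per-index slices. It is only an illustration of the semantics, not tt-mlir code: the function names gather_rows and slice_reshape_concat are hypothetical, tensors are modelled as nested lists, and indices are assumed to be in bounds (the clamping that stablehlo.gather applies to start indices is not modelled).

```python
def gather_rows(operand, start_indices):
    """Reference for the gather pattern in Examples 1 and 2 (hypothetical helper).

    operand:       nested list of shape [B, S, H]
    start_indices: list of [i, j] pairs, shape [N, 2]
    With slice_sizes = [1, 1, H] and collapsed_slice_dims = [0, 1],
    each index pair selects the H-element row operand[i][j].
    Returns a nested list of shape [N, H].
    """
    return [list(operand[i][j]) for i, j in start_indices]


def slice_reshape_concat(operand, start_indices):
    """Same result expressed as slice + reshape + concatenate (hypothetical helper)."""
    pieces = []
    for i, j in start_indices:
        # slice: take a 1x1xH window starting at (i, j, 0)
        window = [[list(row) for row in batch[j:j + 1]] for batch in operand[i:i + 1]]
        # reshape: 1x1xH -> 1xH
        pieces.append([window[0][0]])
    # concatenate along dim 0: N pieces of shape 1xH -> NxH
    return [row for piece in pieces for row in piece]


# Small stand-in for the 2x7x512 operand of Example 2: shape 2x3x2.
operand = [
    [[1, 2], [3, 4], [5, 6]],
    [[7, 8], [9, 10], [11, 12]],
]
# Index rows as produced by concatenating the two reshaped index tensors.
start_indices = [[0, 2], [1, 0]]

result_gather = gather_rows(operand, start_indices)
result_decomposed = slice_reshape_concat(operand, start_indices)
```

The decomposition emits one slice per index row, so it presumably suits small index counts such as the 1 or 2 rows in Examples 1 and 2 rather than the 38809 rows in Examples 3 and 4. Because the indices are runtime values here, a real lowering would need a slice op that accepts dynamic start positions.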
These examples violate the constraints we check when lowering gather to embedding: