tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

[StableHLO] Unhandled use cases for stablehlo.gather op #1350

Open mmanzoorTT opened 2 days ago

mmanzoorTT commented 2 days ago

Following are StableHLO graphs for the stablehlo.gather op (originating from PyTorch models through tt-torch) which tt-mlir cannot handle and explicitly marks as illegal.

Example 1:

  func.func @test1(%arg0: tensor<1x7x2xbf16>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor<1x2xbf16> {
    %0 = stablehlo.reshape %arg1 : (tensor<1xi64>) -> tensor<1x1xi64>
    %1 = stablehlo.reshape %arg2 : (tensor<1xi64>) -> tensor<1x1xi64>
    %2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x2xi64>
    %3 = "stablehlo.gather"(%arg0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 2>}> : (tensor<1x7x2xbf16>, tensor<1x2xi64>) -> tensor<1x2xbf16>
    return %3 : tensor<1x2xbf16>
  }

Example 2:

  func.func @test2(%arg0: tensor<2x7x512xbf16>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>) -> tensor<2x512xbf16> {
    %0 = stablehlo.reshape %arg1 : (tensor<2xi64>) -> tensor<2x1xi64>
    %1 = stablehlo.reshape %arg2 : (tensor<2xi64>) -> tensor<2x1xi64>
    %2 = "stablehlo.concatenate" (%0, %1) {dimension = 1 : i64 } : (tensor<2x1xi64>, tensor<2x1xi64>) -> tensor<2x2xi64>
    %3 = "stablehlo.gather"(%arg0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 512>}> : (tensor<2x7x512xbf16>, tensor<2x2xi64>) -> tensor<2x512xbf16>
    return %3 : tensor<2x512xbf16>
  }
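For intuition, the gathers in Examples 1 and 2 are pointwise lookups along the first two operand dimensions: each row of the concatenated start indices selects one (batch, position) pair and copies out the full last-dimension slice. A minimal NumPy sketch of Example 2's semantics (index values are illustrative stand-ins for %arg1/%arg2):

```python
import numpy as np

# Mirrors Example 2: operand of shape (2, 7, 512), two index vectors of shape (2,).
operand = np.random.rand(2, 7, 512).astype(np.float32)
idx0 = np.array([0, 1], dtype=np.int64)  # illustrative stand-in for %arg1
idx1 = np.array([3, 5], dtype=np.int64)  # illustrative stand-in for %arg2

# start_indices = concatenate(reshape(idx0), reshape(idx1)) -> shape (2, 2)
start_indices = np.stack([idx0, idx1], axis=1)

# With start_index_map = [0, 1], collapsed_slice_dims = [0, 1], and
# slice_sizes = [1, 1, 512], each index row (i, j) selects operand[i, j, :].
result = operand[start_indices[:, 0], start_indices[:, 1], :]
assert result.shape == (2, 512)
```

This is the per-row "take one full slice" pattern that a plain embedding lowering does not cover, since the indices address two operand dimensions at once.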

Example 3:

  func.func @test3(%arg0: tensor<732x12xbf16>, %arg1: tensor<38809xi64>) -> tensor<38809x12xbf16> {
    %0 = stablehlo.reshape %arg1 : (tensor<38809xi64>) -> tensor<38809x1xi64>
    %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 12>}> : (tensor<732x12xbf16>, tensor<38809x1xi64>) -> tensor<38809x12xbf16>
    return %1 : tensor<38809x12xbf16>
  }

Example 4:

  func.func @test4(%arg0: tensor<732x16xbf16>, %arg1: tensor<38809xi64>) -> tensor<38809x16xbf16> {
    %0 = stablehlo.reshape %arg1 : (tensor<38809xi64>) -> tensor<38809x1xi64>
    %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 16>}> : (tensor<732x16xbf16>, tensor<38809x1xi64>) -> tensor<38809x16xbf16>
    return %1 : tensor<38809x16xbf16>
  }
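Examples 3 and 4, by contrast, index a single operand dimension: with collapsed_slice_dims = [0], start_index_map = [0], and slice_sizes = [1, D], each index picks one full row of the table. That is exactly an embedding lookup, as a NumPy sketch of Example 3 shows (index contents are illustrative):

```python
import numpy as np

# Mirrors Example 3: a (732, 12) table gathered by 38809 row indices.
table = np.random.rand(732, 12).astype(np.float32)
indices = np.random.randint(0, 732, size=(38809,)).astype(np.int64)

# slice_sizes = [1, 12] with collapsed_slice_dims = [0] means each index
# selects one complete row -- i.e. an embedding-style lookup.
result = table[indices]
assert result.shape == (38809, 12)
```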
ddilbazTT commented 2 days ago

These might be cases where we convert gather to reshape + slice + concatenate.
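One way to read that suggestion, sketched in NumPy rather than the actual lowering: take one (slice + reshape) per index row, then concatenate the results along the batch dimension. Using Example 2's shapes with illustrative index values:

```python
import numpy as np

# Illustrative decomposition of Example 2's gather into per-row
# slice + reshape, followed by a single concatenate.
operand = np.random.rand(2, 7, 512).astype(np.float32)
start_indices = np.array([[0, 3], [1, 5]], dtype=np.int64)  # illustrative values

slices = []
for k in range(start_indices.shape[0]):
    i, j = start_indices[k]
    # Slice out operand[i, j, :] as a (1, 512) row.
    slices.append(operand[i:i+1, j:j+1, :].reshape(1, 512))
decomposed = np.concatenate(slices, axis=0)

# Matches the gather result computed via direct indexing.
direct = operand[start_indices[:, 0], start_indices[:, 1], :]
assert np.array_equal(decomposed, direct)
```

Note the slice offsets here come from runtime index values, so an MLIR-level version of this would need dynamic slicing rather than static slice ops.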

These examples violate the constraints we check for lowering to embedding: