tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

[StableHLO] failed to legalize operation ttir.convolution #1342

Open mmanzoorTT opened 2 days ago

mmanzoorTT commented 2 days ago

The stablehlo.convolution op is lowered to ttir.convolution, but lowering to TTNN fails because the ttir.convolution op is marked illegal.

module {
  func.func @main(%arg0: tensor<1x256x512xbf16>, %arg1: tensor<1024x256x1xbf16>, %arg2: tensor<1024xbf16>) -> tensor<1x1024x512xbf16> {
    %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0]x[o, i, 0]->[b, f, 0], window = {stride = [1], pad = [[0, 0]], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x256x512xbf16>, tensor<1024x256x1xbf16>) -> tensor<1x1024x512xbf16>
    %1 = stablehlo.reshape %arg2 : (tensor<1024xbf16>) -> tensor<1024x1xbf16>
    %2 = stablehlo.broadcast_in_dim %0, dims = [0, 1, 2] : (tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
    %3 = stablehlo.broadcast_in_dim %1, dims = [1, 2] : (tensor<1024x1xbf16>) -> tensor<1x1024x512xbf16>
    %4 = stablehlo.add %2, %3 : tensor<1x1024x512xbf16>
    return %4 : tensor<1x1024x512xbf16>
  }
}

TTIR graph

#any_device_tile = #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>
module {
  func.func @main(%arg0: tensor<1x256x512xbf16>, %arg1: tensor<1024x256x1xbf16>, %arg2: tensor<1024xbf16>) -> tensor<1x1024x512xbf16> {
    %0 = tensor.empty() : tensor<1x1024x512xbf16>
    %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir<convolution_layout input_batch = 0, input_feature = 1, input_spatial_dimensions = 2, kernel_output_feature = 0, kernel_input_feature = 1, kernel_spatial_dimensions = 2, output_batch = 0, output_feature = 1, output_spatial_dimensions = 2>, feature_group_count = 1 : i64, input_dilation = array<i64: 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array<i64: 0, 0>, weight_dilation = array<i64: 1>, window_reversal = array<i1: false>, window_strides = array<i64: 1>}> : (tensor<1x256x512xbf16>, tensor<1024x256x1xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
    %2 = tensor.empty() : tensor<1024x1xbf16>
    %3 = "ttir.reshape"(%arg2, %2) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1024 : i32, 1 : i32]}> : (tensor<1024xbf16>, tensor<1024x1xbf16>) -> tensor<1024x1xbf16>
    %4 = tensor.empty() : tensor<1x1024x512xbf16>
    %5 = "ttir.broadcast"(%3, %4) <{dimension = [1, 2], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1024x1xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
    %6 = tensor.empty() : tensor<1x1024x512xbf16>
    %7 = "ttir.add"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<1x1024x512xbf16>, tensor<1x1024x512xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
    return %7 : tensor<1x1024x512xbf16>
  }
}

Error

results/mlir_tests/ttir/aten::convolution_0.mlir:5:10: error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal
    %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir<convolution_layout input_batch = 0, input_feature = 1, input_spatial_dimensions = 2, kernel_output_feature = 0, kernel_input_feature = 1, kernel_spatial_dimensions = 2, output_batch = 0, output_feature = 1, output_spatial_dimensions = 2>, feature_group_count = 1 : i64, input_dilation = array<i64: 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array<i64: 0, 0>, weight_dilation = array<i64: 1>, window_reversal = array<i1: false>, window_strides = array<i64: 1>}> : (tensor<1x256x512xbf16>, tensor<1024x256x1xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
         ^
results/mlir_tests/ttir/aten::convolution_0.mlir:5:10: note: see current operation: %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir<convolution_layout input_batch = 0, input_feature = 1, input_spatial_dimensions = 2, kernel_output_feature = 0, kernel_input_feature = 1, kernel_spatial_dimensions = 2, output_batch = 0, output_feature = 1, output_spatial_dimensions = 2>, feature_group_count = 1 : i64, input_dilation = array<i64: 1>, operand_constraints = [#tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>, #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>, #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>], padding = array<i64: 0, 0>, weight_dilation = array<i64: 1>, window_reversal = array<i1: false>, window_strides = array<i64: 1>}> : (tensor<1x256x512xbf16>, tensor<1024x256x1xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
ajakovljevicTT commented 4 hours ago

In talks with @LPanosTT, we agreed that it makes sense to reshape the 1D tensors to 2D in order to do the convolution, and then reshape the output back. This matches what tt-metal currently does for its ttnn.Conv1d on the Python side.
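The reshape trick can be sketched in plain NumPy (an illustration only, not the tt-mlir or tt-metal implementation; the helper names and the small shapes are assumptions): inserting a dummy spatial dimension turns the 1D convolution into a 2D one, and squeezing it out afterwards recovers the original 1D output shape.

```python
import numpy as np

def conv2d(x, w):
    """Naive NCHW conv2d, stride 1, no padding (reference only)."""
    n, cin, h, width = x.shape
    cout, _, kh, kw = w.shape
    oh, ow = h - kh + 1, width - kw + 1
    out = np.zeros((n, cout, oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            patch = x[:, :, i:i + kh, j:j + kw]        # (n, cin, kh, kw)
            out[:, :, i, j] = np.einsum("ncij,ocij->no", patch, w)
    return out

def conv1d_via_conv2d(x, w):
    """x: (N, Cin, L), w: (Cout, Cin, K) -> (N, Cout, Lout)."""
    x2 = x[:, :, None, :]   # add a dummy height dim: (N, Cin, 1, L)
    w2 = w[:, :, None, :]   # (Cout, Cin, 1, K)
    y2 = conv2d(x2, w2)     # (N, Cout, 1, Lout)
    return y2[:, :, 0, :]   # squeeze the dummy dim back out

# Shapes scaled down from the issue's (1, 256, 512) x (1024, 256, 1):
x = np.random.rand(1, 4, 8).astype(np.float32)
w = np.random.rand(6, 4, 1).astype(np.float32)
y = conv1d_via_conv2d(x, w)
assert y.shape == (1, 6, 8)
# A kernel-size-1 conv1d is a per-position matmul over channels:
ref = np.einsum("ncl,oc->nol", x, w[:, :, 0])
assert np.allclose(y, ref, atol=1e-5)
```

The same dimension bookkeeping would apply at the IR level: reshape the input and weight to add a unit spatial dimension, run the 2D convolution, then reshape the result back.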

In addition, I will open an issue with the tt-metal folks to see if they can provide a C++ API for conv1d, which would make things easier on our side by shifting the reshapes into tt-metal.