tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

ttir.reshape failure due to broadcast op #1345

Open mmanzoorTT opened 5 days ago

mmanzoorTT commented 5 days ago

The example below reshapes a 6x2 tensor to a 2400x2 tensor through multiple reshape ops and a broadcast op. We fold the broadcast op during the TTIR->TTNN conversion, so the third reshape op ends up consuming the result of the second reshape op directly, bypassing the broadcast; the reshape then fails because the number of elements in its input and output no longer match.

module {
  func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
    %0 = stablehlo.reshape %arg0 : (tensor<6x2xf32>) -> tensor<1x6x2xf32>
    %1 = stablehlo.reshape %0 : (tensor<1x6x2xf32>) -> tensor<1x6x1x2xf32>
    %2 = stablehlo.broadcast_in_dim %1, dims = [0, 1, 2, 3] : (tensor<1x6x1x2xf32>) -> tensor<400x6x1x2xf32>
    %3 = stablehlo.reshape %2 : (tensor<400x6x1x2xf32>) -> tensor<2400x1x2xf32>
    %4 = stablehlo.reshape %3 : (tensor<2400x1x2xf32>) -> tensor<2400x2xf32>
    return %4 : tensor<2400x2xf32>
  }
}
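
For reference, the element counts balance only while the broadcast is in place: 400*6*1*2 = 4800 = 2400*1*2 = 2400*2, whereas the pre-broadcast value %1 carries only 1*6*1*2 = 12 elements, so a reshape to 2400x1x2 that consumes it directly cannot preserve the element count.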

TTIR Graph

"#any_device_tile = #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>
module {
  func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
    %0 = tensor.empty() : tensor<1x6x2xf32>
    %1 = ""ttir.reshape""(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
    %2 = tensor.empty() : tensor<1x6x1x2xf32>
    %3 = ""ttir.reshape""(%1, %2) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
    %4 = tensor.empty() : tensor<400x6x1x2xf32>
    %5 = ""ttir.broadcast""(%3, %4) <{dimension = [0, 1, 2, 3], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
    %6 = tensor.empty() : tensor<2400x1x2xf32>
    %7 = ""ttir.reshape""(%5, %6) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
    %8 = tensor.empty() : tensor<2400x2xf32>
    %9 = ""ttir.reshape""(%7, %8) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
    return %9 : tensor<2400x2xf32>
  }
}
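
For intuition, the fold behaves roughly like the rewrite-pattern sketch below (hypothetical C++, not the actual tt-mlir conversion code; ttir::BroadcastOp and its getInput() accessor are assumed names, and it is written as a plain rewrite pattern rather than the actual TTIR->TTNN conversion pattern). Replacing every use of the broadcast result with its input is only sound when each consumer re-derives the broadcast, and ttir.reshape does not.

// Hypothetical sketch of the broadcast fold (assumed names, not the
// actual tt-mlir source).
#include "mlir/IR/PatternMatch.h"

struct FoldBroadcast : mlir::OpRewritePattern<ttir::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ttir::BroadcastOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Rewires every consumer to the pre-broadcast value. In the graph
    // above, the reshape %7 ends up consuming the 1x6x1x2 value (12
    // elements) instead of the 400x6x1x2 broadcast result (4800
    // elements), while still requesting a 2400x1x2 output, so the
    // 'same number of elements' verifier on ttnn.reshape fires.
    rewriter.replaceOp(op, op.getInput());
    return mlir::success();
  }
};

The fold is only valid when every consumer of the broadcast result is broadcast-compatible; this example shows a shape-sensitive consumer (ttir.reshape) for which that assumption breaks.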

Error message

results/mlir_tests/ttir/aten::repeat_0.mlir:11:10: error: 'ttnn.reshape' op Input and output tensors must have the same number of elements
    %7 = "ttir.reshape"(%5, %6) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
         ^
results/mlir_tests/ttir/aten::repeat_0.mlir:11:10: note: see current operation: %8 = "ttnn.reshape"(%5) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x1x2xf32, #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 6 + d1 + d2, d3), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #ttnn.buffer_type<dram>>, interleaved>>) -> tensor<2400x1x2xf32, #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 + d1, d2), <1x1>, memref<75x1x!tt.tile<32x32, f32>, #ttnn.buffer_type<dram>>, interleaved>>