tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

ttnn.concat fails when the concat dimension of the input tensors is not tile-aligned #795

Status: Open. chandrasekaranpradeep opened this issue 3 weeks ago

chandrasekaranpradeep commented 3 weeks ago

Summary: In the Llama 3B rotary embedding, ttnn.concat fails when the concat dimension of the input tensors is not tile-aligned.

For more context, here is the exact error message:

E       RuntimeError: TT_FATAL @ /proj_sw/user_dev/pchandrasekaran/Forge/tt-forge-fe/third_party/tt-mlir/third_party/tt-metal/src/tt-metal/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp:62: !in_ref.get_shape().has_tile_padding(this->dim)
E       info:
E       Tile padding along concatenated dim (3) not supported for concat yet (tensor: 0).

Details: Two input tensors of the same shape (1, 32, 12, 50) are passed as a tuple to the ttnn.concat op with concat dim = -1, which is dim 3 for these rank-4 tensors. While the inputs are validated in the ttnn::operations::data_movement::ConcatDeviceOperation::validate function in TTNN, the error "Tile padding along concatenated dim (3) not supported for concat yet (tensor: 0)" is thrown.
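
The shape arithmetic behind the error can be reproduced in a few lines of plain Python. This is a minimal sketch, assuming the default 32x32 tile of ttnn.TILE_LAYOUT and that only the last two dimensions are tiled; normalize_dim, padded_tile_shape and has_tile_padding_along are hypothetical helpers written for illustration, not TTNN functions.

```python
# Hypothetical helpers (not part of the TTNN API) that mirror the shape
# arithmetic behind the error, assuming 32x32 tiles on the last two dims.
TILE_HEIGHT = 32
TILE_WIDTH = 32


def normalize_dim(dim: int, rank: int) -> int:
    """Map a possibly negative dim (e.g. -1) to its non-negative index."""
    if not -rank <= dim < rank:
        raise ValueError(f"dim {dim} out of range for rank {rank}")
    return dim % rank


def round_up(value: int, multiple: int) -> int:
    """Round value up to the next multiple of `multiple`."""
    return -(-value // multiple) * multiple


def padded_tile_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Pad the last two (tiled) dims up to whole tiles; leave the rest as-is."""
    *batch, height, width = shape
    return (*batch, round_up(height, TILE_HEIGHT), round_up(width, TILE_WIDTH))


def has_tile_padding_along(shape: tuple[int, ...], dim: int) -> bool:
    """True if the logical and tile-padded sizes differ along `dim`."""
    dim = normalize_dim(dim, len(shape))
    return padded_tile_shape(shape)[dim] != shape[dim]


shape = (1, 32, 12, 50)
print(normalize_dim(-1, len(shape)))       # 3
print(padded_tile_shape(shape))            # (1, 32, 32, 64)
print(has_tile_padding_along(shape, -1))   # True: the concat dim is padded
```

Under these assumptions, an input whose last dimension is a multiple of 32 (for example 64) would report no padding along dim -1.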

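ttnn.concat expects the concatenated dimension of every input tensor to be tile-aligned, because tile padding along the concat dimension is not supported yet. Here the concat dimension has size 50, which is not a multiple of the tile width of 32, so in TILE_LAYOUT it is padded to 64 and the first input (tensor: 0) already fails the check.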
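
One way to see why padding along the concat dimension is awkward: if each input row is stored padded to whole tiles, copying the padded rows back to back leaves padding in the middle of the result and places b at the wrong offset. The sketch below uses plain Python lists and an assumed padding value of 0.0; it only illustrates the layout problem and is not a description of how the TTNN concat kernel works.

```python
# Illustration only (plain Python lists, not TTNN code): assume each row of
# the last dim is stored padded out to a whole number of 32-wide tiles.
TILE_WIDTH = 32
PAD = 0.0  # hypothetical padding value


def pad_row(row: list[float], tile_width: int = TILE_WIDTH) -> list[float]:
    """Append padding until the row length is a multiple of tile_width."""
    padded_len = -(-len(row) // tile_width) * tile_width
    return row + [PAD] * (padded_len - len(row))


a = [1.0] * 50  # logical width 50, like the inputs in this issue
b = [2.0] * 50

logical = a + b                  # what torch.concat produces: 100 values
naive = pad_row(a) + pad_row(b)  # back-to-back copy of padded rows: 128 values

print(len(logical), len(naive))                # 100 128
print(logical.index(2.0), naive.index(2.0))    # 50 64
```

In the naive result, positions 50 to 63 hold padding and b starts at offset 64 instead of 50, so a correct concat would have to shift b's data by a non-tile-aligned amount.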

Repro:

Concat TTIR:

module @Concat attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (6, 0), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (6, 6), (6, 7), (7, 0), (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6), (7, 7)] dram = [(8, 0), (9, 0), (10, 0), (8, 1), (9, 1), (10, 1), (8, 2), (9, 2), (10, 2), (8, 3), (9, 3), (10, 3)]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [(4 x 16), (16 x 16), (32 x 16), (4 x 32), (16 x 32), (32 x 32)]}], [0], [3 : i32], [<0, 0, 0, 0>]>} {
  func.func @forward(%arg0: tensor<1x32x12x50xf32> {ttir.name = "a"}, %arg1: tensor<1x32x12x50xf32> {ttir.name = "b"}) -> (tensor<1x32x12x100xf32> {ttir.name = "Concat.output_concatenate_0"}) {
    %0 = tensor.empty() : tensor<1x32x12x100xf32>
    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -1 : si32, operand_constraints = [#tt.operand_constraint<dram|l1|scalar|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device|any_device_tile|l1_block_sharded>, #tt.operand_constraint<dram|l1|scalar|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device|any_device_tile|l1_block_sharded>, #tt.operand_constraint<dram|l1|scalar|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device|any_device_tile|l1_block_sharded>]}> : (tensor<1x32x12x50xf32>, tensor<1x32x12x50xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32>
    return %1 : tensor<1x32x12x100xf32>
  }
}
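
For reference, the result type in the TTIR above follows from ordinary concat shape rules: all dimensions except the concat dimension must match, and the sizes along the concat dimension add up. The helper concat_output_shape below is hypothetical, written only to check this arithmetic for the shapes in this repro.

```python
def concat_output_shape(shapes: list[tuple[int, ...]], dim: int) -> tuple[int, ...]:
    """Return the shape of concatenating `shapes` along `dim` (negative dims allowed)."""
    rank = len(shapes[0])
    if not -rank <= dim < rank:
        raise ValueError(f"dim {dim} out of range for rank {rank}")
    dim %= rank  # e.g. -1 -> 3 for rank-4 inputs
    for shape in shapes:
        if len(shape) != rank:
            raise ValueError("all inputs must have the same rank")
        # Every dimension other than the concat dim must agree across inputs.
        if any(shape[i] != shapes[0][i] for i in range(rank) if i != dim):
            raise ValueError(f"non-concat dims differ: {shape} vs {shapes[0]}")
    out = list(shapes[0])
    out[dim] = sum(shape[dim] for shape in shapes)  # sizes add up along dim
    return tuple(out)


print(concat_output_shape([(1, 32, 12, 50), (1, 32, 12, 50)], -1))  # (1, 32, 12, 100)
```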

TTNN test case:

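import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc

# `device` is the pytest fixture supplied by the TTNN test suite's conftest.
def test_concat_rotary_embedding(device):
    # Same configuration as the Llama 3B rotary embedding: two (1, 32, 12, 50)
    # tensors concatenated along the last dim (dim 3, equivalent to -1).
    concat_dim = 3
    torch_input_tensor_a = torch.rand((1, 32, 12, 50), dtype=torch.float32)
    torch_input_tensor_b = torch.rand((1, 32, 12, 50), dtype=torch.float32)
    # Golden reference computed on the host with PyTorch.
    torch_output_tensor = torch.concat([torch_input_tensor_a, torch_input_tensor_b], dim=concat_dim)

    # TILE_LAYOUT pads the last two dims to whole tiles, so the width of 50
    # is stored with padding along the concat dim.
    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

    # Currently fails in ConcatDeviceOperation::validate with the TT_FATAL above.
    output = ttnn.concat([input_tensor_a, input_tensor_b], dim=concat_dim)
    output = ttnn.to_torch(output)

    assert_with_pcc(torch_output_tensor, output, 0.9999)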

TT-Forge-fe Concat test:

git checkout pchandrasekaran/concat
pytest forge/test/mlir/test_ops.py::test_concat_rotary_embedding -vss

nvukobratTT commented 1 week ago

@chandrasekaranpradeep, do we have a matching issue in the TTNN/Metal repo as well?

chandrasekaranpradeep commented 1 week ago

@nvukobratTT I have not created an issue in the TT-Metal repo. Do I need to create one there?

nvukobratTT commented 1 week ago

Since we already have a TTNN repro, let's create an issue on TT-Metal as well :))

chandrasekaranpradeep commented 1 week ago

Sure, I will create an issue on TT-Metal.

chandrasekaranpradeep commented 1 week ago

@nvukobratTT I created an issue in the TT-Metal repo: https://github.com/tenstorrent/tt-metal/issues/13667. A concat op with this configuration (i.e. input shape and dim) is present in the Llama 3B rotary embedding. Can we raise the priority of the issue?