tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

TTNN IR tensor LayoutAttr has a scalar element type (TTNN ROW_MAJOR) for every op #822

Open · odjuricicTT opened this issue 4 days ago

odjuricicTT commented 4 days ago

There seem to be inconsistencies between TTNN IR op input and output layouts. Currently, all op inputs are converted to TILE layout, while all op outputs are ROW_MAJOR. Example:

#device = #tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
#dram = #tt.memory_space<dram>
#l1_ = #tt.memory_space<l1>
#system = #tt.memory_space<system>
#system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 107616, erisc_l1_unreserved_base = 107616, dram_unreserved_base = 32, dram_unreserved_end = 1073096128, physical_cores = {worker = [ 2x1,  2x2,  2x3,  2x4,  2x6,  2x7,  2x8,  2x9,  3x1,  3x2,  3x3,  3x4,  3x6,  3x7,  3x8,  3x9,  4x1,  4x2,  4x3,  4x4,  4x6,  4x7,  4x8,  4x9,  5x1,  5x2,  5x3,  5x4,  5x6,  5x7,  5x8,  5x9,  7x1,  7x2,  7x3,  7x4,  7x6,  7x7,  7x8,  7x9,  8x1,  8x2,  8x3,  8x4,  8x6,  8x7,  8x8,  8x9,  9x1,  9x2,  9x3,  9x4,  9x6,  9x7,  9x8,  9x9,  10x1,  10x2,  10x3,  10x4,  10x6,  10x7,  10x8,  10x9] dram = [ 1x0,  1x5,  2x5,  3x5,  5x0,  5x5,  7x0,  7x5,  8x5,  9x5,  11x0,  11x5] eth_inactive = [ 0x1,  0x2,  0x3,  0x4,  0x6,  0x7,  0x8,  0x9,  6x2,  6x3,  6x6,  6x7,  6x8]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16,  16x16,  32x16,  4x32,  16x32,  32x32]}], [0], [3 : i32], [ 0x0x0x0]>
#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x784xf32, #system>>
#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #system>>
#layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<256x10xf32, #system>>
#layout3 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x256xf32, #system>>
#layout4 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<784x256xf32, #system>>
#layout5 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<784x32xf32, #l1_>, height_sharded>
#layout6 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x2xf32, #l1_>, height_sharded>
#layout7 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x98xf32, #l1_>, height_sharded>
#layout8 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, height_sharded>
#layout9 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<256x2xf32, #l1_>, height_sharded>
#layout10 = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<1x2xf32, #dram>, interleaved>
module @"tt-forge-graph" attributes {tt.device = #device, tt.system_desc = #system_desc} {
  func.func @main(%arg0: tensor<1x784xf32, #layout>, %arg1: tensor<1x10xf32, #layout1>, %arg2: tensor<256x10xf32, #layout2>, %arg3: tensor<1x256xf32, #layout3>, %arg4: tensor<784x256xf32, #layout4>) -> tensor<1x10xf32, #layout1> {
    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
    %1 = "ttnn.to_layout"(%arg4, %0) <{layout = #ttnn.layout<tile>}> : (tensor<784x256xf32, #layout4>, !tt.device<#device>) -> tensor<784x256xf32, #layout5>
    %2 = "ttnn.to_device"(%1, %0) <{memory_config = #ttnn.memory_config<<height_sharded>, <l1>>}> : (tensor<784x256xf32, #layout5>, !tt.device<#device>) -> tensor<784x256xf32, #layout5>
    %3 = "ttnn.to_layout"(%arg1, %0) <{layout = #ttnn.layout<tile>}> : (tensor<1x10xf32, #layout1>, !tt.device<#device>) -> tensor<1x10xf32, #layout6>
    %4 = "ttnn.to_device"(%3, %0) <{memory_config = #ttnn.memory_config<<height_sharded>, <l1>>}> : (tensor<1x10xf32, #layout6>, !tt.device<#device>) -> tensor<1x10xf32, #layout6>
    %5 = "ttnn.to_layout"(%arg0, %0) <{layout = #ttnn.layout<tile>}> : (tensor<1x784xf32, #layout>, !tt.device<#device>) -> tensor<1x784xf32, #layout7>
    %6 = "ttnn.to_device"(%5, %0) <{memory_config = #ttnn.memory_config<<height_sharded>, <l1>>}> : (tensor<1x784xf32, #layout7>, !tt.device<#device>) -> tensor<1x784xf32, #layout7>
    %7 = "ttnn.to_layout"(%arg3, %0) <{layout = #ttnn.layout<tile>}> : (tensor<1x256xf32, #layout3>, !tt.device<#device>) -> tensor<1x256xf32, #layout8>
    %8 = "ttnn.to_device"(%7, %0) <{memory_config = #ttnn.memory_config<<height_sharded>, <l1>>}> : (tensor<1x256xf32, #layout8>, !tt.device<#device>) -> tensor<1x256xf32, #layout8>
    %9 = "ttnn.to_layout"(%arg2, %0) <{layout = #ttnn.layout<tile>}> : (tensor<256x10xf32, #layout2>, !tt.device<#device>) -> tensor<256x10xf32, #layout9>
    %10 = "ttnn.to_device"(%9, %0) <{memory_config = #ttnn.memory_config<<height_sharded>, <l1>>}> : (tensor<256x10xf32, #layout9>, !tt.device<#device>) -> tensor<256x10xf32, #layout9>
    %11 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<height_sharded>, <l1>>, shape = #ttnn.shape<1x256>}> : (!tt.device<#device>) -> tensor<1x256xf32, #layout8>
    %12 = "ttnn.matmul"(%6, %2, %11) : (tensor<1x784xf32, #layout7>, tensor<784x256xf32, #layout5>, tensor<1x256xf32, #layout8>) -> tensor<1x256xf32, #layout8>
    %13 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<height_sharded>, <l1>>, shape = #ttnn.shape<1x256>}> : (!tt.device<#device>) -> tensor<1x256xf32, #layout8>
    %14 = "ttnn.add"(%12, %8, %13) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x256xf32, #layout8>, tensor<1x256xf32, #layout8>, tensor<1x256xf32, #layout8>) -> tensor<1x256xf32, #layout8>
    %15 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<height_sharded>, <l1>>, shape = #ttnn.shape<1x256>}> : (!tt.device<#device>) -> tensor<1x256xf32, #layout8>
    %16 = "ttnn.relu"(%14, %15) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<1x256xf32, #layout8>, tensor<1x256xf32, #layout8>) -> tensor<1x256xf32, #layout8>
    %17 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<height_sharded>, <l1>>, shape = #ttnn.shape<1x10>}> : (!tt.device<#device>) -> tensor<1x10xf32, #layout6>
    %18 = "ttnn.matmul"(%16, %10, %17) : (tensor<1x256xf32, #layout8>, tensor<256x10xf32, #layout9>, tensor<1x10xf32, #layout6>) -> tensor<1x10xf32, #layout6>
    %19 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<height_sharded>, <l1>>, shape = #ttnn.shape<1x10>}> : (!tt.device<#device>) -> tensor<1x10xf32, #layout6>
    %20 = "ttnn.add"(%18, %4, %19) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x10xf32, #layout6>, tensor<1x10xf32, #layout6>, tensor<1x10xf32, #layout6>) -> tensor<1x10xf32, #layout6>
    %21 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>>, shape = #ttnn.shape<1x10>}> : (!tt.device<#device>) -> tensor<1x10xf32, #layout10>
    %22 = "ttnn.softmax"(%20, %21) <{dimension = 1 : si32}> : (tensor<1x10xf32, #layout6>, tensor<1x10xf32, #layout10>) -> tensor<1x10xf32, #layout10>
    %23 = "ttnn.to_memory_config"(%22, %0) : (tensor<1x10xf32, #layout10>, !tt.device<#device>) -> tensor<1x10xf32, #layout1>
    return %23 : tensor<1x10xf32, #layout1>
  }
}

All input to_layout ops convert to TILE layout, while every op output is allocated by ttnn.empty as ROW_MAJOR, and all of the #tt.layout attributes keep a scalar f32 element type.
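For reference, this is roughly the contrast I mean, using the 1x256 activations as an example (a sketch only: it assumes the !tt.tile<32x32, f32> element type used for tiled memrefs in tt-mlir, and the name #layout8_tiled is made up for illustration):

```mlir
// What the compiler emits today: a scalar f32 shard per core, i.e. ROW_MAJOR.
#layout8 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, height_sharded>
// Hypothetical tiled counterpart: each core's shard expressed in 32x32 tiles instead of scalars.
#layout8_tiled = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, height_sharded>
```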

@sdjordjevicTT Opening this issue for discussion.

odjuricicTT commented 4 days ago

This seems to be preventing us from properly sharding op outputs: we end up hitting a TT-Metal assert that the shard shape is not tilized.
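Concretely (this is my reading of the assert, so take the numbers as an assumption): with #layout8 each of the 8 cores gets a 1x32 row-major f32 shard, which is not a multiple of the 32x32 tile shape, so TT-Metal rejects it when the op runs tiled. A tile-aligned per-core shard would look more like the second memref below:

```mlir
// Per-core shard as emitted today: 1 row x 32 columns of scalar f32, not tile-aligned.
memref<1x32xf32, #l1_>
// Tile-aligned sketch: one (padded) 32x32 tile per core.
memref<1x1x!tt.tile<32x32, f32>, #l1_>
```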

sdjordjevicTT commented 3 days ago

Hey @odjuricicTT, please take a look at this discussion: https://github.com/tenstorrent/tt-mlir/issues/272

Currently we don't have a way to model the layout properly, so I'm trying to understand what the compiler's default behavior should be.
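For concreteness, one possible default (just a sketch, not a settled design) would be for the op that allocates an output to carry the same layout its inputs were converted to, so the result's #tt.layout uses a tile element type as well. Reusing #l1_, #device, and %0 from the IR above, with #layout8_tiled and the !tt.tile<32x32, f32> syntax purely illustrative:

```mlir
// Hypothetical: ttnn.empty declares layout = tile and its result layout uses a tiled memref.
#layout8_tiled = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, height_sharded>
%11 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<height_sharded>, <l1>>, shape = #ttnn.shape<1x256>}> : (!tt.device<#device>) -> tensor<1x256xf32, #layout8_tiled>
```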