tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

[Bug] Reduce mean (tt.mean): Issues with non-tile aligned dims and default TILE_LAYOUT setup #499

Open · nvukobratTT opened this issue 3 weeks ago

nvukobratTT commented 3 weeks ago

Summary

The reduce mean op fails when tensor dims aren't aligned to the tile size (32).

Repro instructions

  1. Remove the dim == 2 skip condition from test_mean:
    if dim == 2:
        pytest.skip("TTNN: Tensor layout bugs")
  2. Run the following command:
    pytest -svv pybuda/test/mlir/test_ops.py::test_mean

Details

Exact error msg:

Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first.

Lowered TTIR:

#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#loc = loc("Mean":4294967295:0)
module @"tt-forge-graph" attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
  func.func @forward(%arg0: tensor<1x41x41xf32> {ttir.name = "x"} loc("Mean":4294967295:0)) -> tensor<1x41x1xf32> {
    %0 = tensor.empty() : tensor<1x41x1xf32> loc(#loc3)
    %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1 : i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x41x41xf32>, tensor<1x41x1xf32>) -> tensor<1x41x1xf32> loc(#loc3)
    return %1 : tensor<1x41x1xf32> loc(#loc2)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("Mean":4294967295:35)
#loc2 = loc(unknown)
#loc3 = loc("reduce_avg_16"(#loc1))

Trace sample:

#l1_ = #tt.memory_space<l1>
#loc = loc("Mean":4294967295:0)
#system = #tt.memory_space<system>
#layout = #tt.layout<(d0, d1, d2) -> (d0 * 41 + d1, d2), undef, <1x1>, memref<41x41xf32, #system>>
#layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 41 + d1, d2), undef, <8x8>, memref<6x1xf32, #system>>
#layout2 = #tt.layout<(d0, d1, d2) -> (d0 * 41 + d1, d2), undef, <8x8>, memref<6x6xf32, #l1_>>
#layout3 = #tt.layout<(d0, d1, d2) -> (d0 * 41 + d1, d2), undef, <8x8>, memref<6x1xf32, #l1_>>
module @"tt-forge-graph" attributes {tt.device = #tt.device<#tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, [0]>, tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
  func.func @forward(%arg0: tensor<1x41x41xf32, #layout> {ttir.name = "x"} loc("Mean":4294967295:0)) -> tensor<1x41x1xf32, #layout1> {
    %0 = "ttnn.open_device"() : () -> !tt.device<<#tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, [0]>> loc(#loc)
    %1 = "ttnn.full"(%0) <{fillValue = 0.000000e+00 : f32}> : (!tt.device<<#tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, [0]>>) -> tensor<1x41x41xf32, #layout2> loc(#loc3)
    %2 = "ttnn.to_memory_config"(%arg0, %1) : (tensor<1x41x41xf32, #layout>, tensor<1x41x41xf32, #layout2>) -> tensor<1x41x41xf32, #layout2> loc(#loc3)
    %3 = "ttnn.full"(%0) <{fillValue = 0.000000e+00 : f32}> : (!tt.device<<#tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, [0]>>) -> tensor<1x41x1xf32, #layout3> loc(#loc3)
    %4 = "ttnn.mean"(%2, %3) <{dim_arg = [-1 : i32], keep_dim = true}> : (tensor<1x41x41xf32, #layout2>, tensor<1x41x1xf32, #layout3>) -> tensor<1x41x1xf32, #layout3> loc(#loc3)
    %5 = "ttnn.full"(%0) <{fillValue = 0.000000e+00 : f32}> : (!tt.device<<#tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, [0]>>) -> tensor<1x41x1xf32, #layout1> loc(#loc2)
    %6 = "ttnn.to_memory_config"(%4, %5) : (tensor<1x41x1xf32, #layout3>, tensor<1x41x1xf32, #layout1>) -> tensor<1x41x1xf32, #layout1> loc(#loc2)
    "ttnn.close_device"(%0) : (!tt.device<<#tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, [0]>>) -> () loc(#loc)
    return %6 : tensor<1x41x1xf32, #layout1> loc(#loc2)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("Mean":4294967295:35)
#loc2 = loc(unknown)
#loc3 = loc("reduce_avg_16"(#loc1))
sdjordjevicTT commented 2 weeks ago

I looked into this; it's a bug in the TTNN implementation of the reduce op.

The TTNN implementation of the reduction op contains the following code, which calculates the output tensor shape:

    std::vector<uint32_t> output_shape;
    std::vector<uint32_t> padded_output_shape;
    for (int axis = 0; axis < input_shape.size(); axis++) {
        if (std::find(dim.begin(), dim.end(), axis) != dim.end()) {
            if (keepdim) {
                output_shape.push_back(1);
                padded_output_shape.push_back(axis >= rank - 2 ? ttnn::TILE_SIZE : 1);
            }
        } else {
            output_shape.push_back(input_shape[axis]);
            padded_output_shape.push_back(input_shape[axis]);
        }
    }

The issue is in this line: padded_output_shape.push_back(input_shape[axis]);

We iterate through all the dims of the tensor. For each dim we are not reducing, we want to record both the unpadded and the padded size of the input, but the padded size is recorded incorrectly because input_shape[axis] returns the unpadded size. Instead, we should read the padded size with something like: padded_output_shape.push_back(input_shape.value[axis]);
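To make the shape arithmetic concrete, here is a minimal pure-Python model of the C++ loop above, applied to the 1x41x41 input from the repro. It is a sketch, not TTNN code: round_up_to_tile, padded_input_shape and reduce_output_shapes are hypothetical helpers, and it assumes TILE_LAYOUT pads only the last two dims up to multiples of 32 and that negative dims are normalized before the loop runs. The use_padded flag switches between the current behavior (unpadded size) and the proposed fix (padded size).

```python
TILE_SIZE = 32  # Tenstorrent tiles are 32x32, so tiled dims are padded to multiples of 32


def round_up_to_tile(n: int) -> int:
    """Round n up to the next multiple of TILE_SIZE (e.g. 41 -> 64)."""
    return -(-n // TILE_SIZE) * TILE_SIZE


def padded_input_shape(shape: list[int]) -> list[int]:
    """Assumed TILE_LAYOUT padding: only the last two dims (height, width) are padded."""
    return list(shape[:-2]) + [round_up_to_tile(d) for d in shape[-2:]]


def reduce_output_shapes(shape, dims, keepdim=True, use_padded=True):
    """Model of the C++ loop; use_padded=False reproduces the bug."""
    rank = len(shape)
    reduced = {d % rank for d in dims}  # assumption: negative dims normalized, -1 -> rank - 1
    padded = padded_input_shape(shape)
    output_shape, padded_output_shape = [], []
    for axis in range(rank):
        if axis in reduced:
            if keepdim:
                output_shape.append(1)
                padded_output_shape.append(TILE_SIZE if axis >= rank - 2 else 1)
        else:
            output_shape.append(shape[axis])
            # Bug: input_shape[axis] is the unpadded size; the fix reads the padded size instead
            padded_output_shape.append(padded[axis] if use_padded else shape[axis])
    return output_shape, padded_output_shape


print(reduce_output_shapes([1, 41, 41], [-1], use_padded=False))  # ([1, 41, 1], [1, 41, 32])
print(reduce_output_shapes([1, 41, 41], [-1]))                    # ([1, 41, 1], [1, 64, 32])
```

With the unpadded value the padded height stays at 41, which is not a multiple of 32; reading the padded input size gives 64 instead.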

Later in the reduction op, a reshape is executed:

output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{tt::tt_metal::Shape{output_shape, padded_output_shape}});

This reshape fails because a tensor in TILE_LAYOUT needs a tile-aligned padded shape, but the reduction op supplies the unpadded size for the non-reduced dims, so the following error is raised:

Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first.
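The precondition behind this error can be sketched as follows. This is a simplified model inferred from the error message, not TTNN's actual check: check_tile_reshape is a hypothetical helper, and it assumes a tiled tensor is stored as whole 32x32 tiles, so the padded height and width of the target shape must be multiples of 32. For the 1x41x41 input reduced on the last dim, the current code produces a padded output shape of [1, 41, 32], while reading the padded input size would give [1, 64, 32].

```python
TILE_SIZE = 32  # Tenstorrent tiles are 32x32


def check_tile_reshape(padded_shape):
    """Hypothetical model of the TILE_LAYOUT reshape precondition.

    The padded height and width (last two dims) must be whole tiles.
    """
    height, width = padded_shape[-2], padded_shape[-1]
    if height % TILE_SIZE != 0 or width % TILE_SIZE != 0:
        raise RuntimeError(
            "Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! "
            "Please convert the tensor to ROW_MAJOR_LAYOUT first."
        )


buggy = [1, 41, 32]  # padded output shape built from the unpadded height (41)
fixed = [1, 64, 32]  # padded output shape built from the padded height (64)

check_tile_reshape(fixed)  # tile-aligned, passes
try:
    check_tile_reshape(buggy)  # 41 % 32 != 0, raises
except RuntimeError as err:
    print(err)
```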

I was able to reproduce this easily with plain TTNN:

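import torch
import ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

try:
    # 1x41x41: the last two dims are not multiples of the 32x32 tile
    torch_tensor = torch.rand(1, 41, 41)
    ttnn_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=device)

    # Reducing over the last dim fails in the internal reshape with the TILE_LAYOUT error
    ttnn_tensor = ttnn.mean(ttnn_tensor, -1)
finally:
    # Release the device even though ttnn.mean raises
    ttnn.close_device(device)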

I created the following issue on the Metal side to track this: https://github.com/tenstorrent/tt-metal/issues/12101

nvukobratTT commented 2 weeks ago

I see, thanks for the details @sdjordjevicTT!

Since we're not sure when this one will be fixed, is it possible to have a workaround for it at the MLIR level? For example, if the reduce op supports ROW_MAJOR_LAYOUT, would enforcing that layout help? Or could we push for an explicit pad op that would help here, e.g. decomposing reduce into a pad + reduce + unpad sequence of ops?

Let me know your thoughts! :))
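The pad + reduce + unpad decomposition suggested above can be sketched in pure Python on a 2D 41x41 slice (the batch dim of 1 is dropped). This only illustrates the arithmetic: pad_2d and mean_last_dim_via_pad are hypothetical helpers, not TTIR or TTNN ops, and zero padding to 32-multiples is an assumption. One detail it highlights: if the reduce is a mean, the padded zeros must not count toward the divisor, so the sketch sums on the padded tensor and divides by the original width.

```python
TILE_SIZE = 32  # assumed tile height/width


def round_up(n: int, multiple: int = TILE_SIZE) -> int:
    """Round n up to the next multiple (e.g. 41 -> 64)."""
    return -(-n // multiple) * multiple


def pad_2d(rows, fill=0.0):
    """Hypothetical 'pad' step: zero-pad a 2D list to tile-aligned height and width."""
    height, width = len(rows), len(rows[0])
    padded_w = round_up(width)
    padded = [list(row) + [fill] * (padded_w - width) for row in rows]
    padded += [[fill] * padded_w for _ in range(round_up(height) - height)]
    return padded


def mean_last_dim_via_pad(rows):
    """Hypothetical pad + reduce + unpad decomposition of mean over the last dim (keep_dim=True)."""
    height, width = len(rows), len(rows[0])
    padded = pad_2d(rows)
    # Reduce on the tile-aligned tensor: zero padding leaves a sum unchanged...
    sums = [sum(row) for row in padded]
    # ...so divide by the ORIGINAL width; dividing by the padded width (64) would scale each mean by 41/64.
    means = [s / width for s in sums]
    # Unpad: drop the padded rows and keep the reduced dim as size 1.
    return [[m] for m in means[:height]]


x = [[float(c) for c in range(41)] for _ in range(41)]  # each row is 0..40, mean 20.0
out = mean_last_dim_via_pad(x)
print(len(out), len(out[0]), out[0][0])  # 41 1 20.0
```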

sdjordjevicTT commented 2 weeks ago

I'd like to understand how this will be fixed on the TTNN side before attempting any quick fixes. We need to investigate further what can be fixed quickly. You gave great ideas, and we can try them as quick workarounds.

nvukobratTT commented 2 weeks ago

Thanks for pushing this further @sdjordjevicTT! Let me know if you need any more details from my end on this one :)))