tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

[Bug Report] Reduce mean (ttnn.mean): Issues with non-tile aligned dims and TILE_LAYOUT mode #12101

Closed: sdjordjevicTT closed this 4 days ago

sdjordjevicTT commented 2 weeks ago

Describe the bug
While developing the TT-MLIR compiler, we encountered an issue with the ttnn.mean op: when the input tensor dims aren't tile-aligned, the op errors out with the following message:

Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first.
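For context, and as I understand it (not verified against the docs): TILE_LAYOUT stores tensors in 32x32 tiles, so the last two dims are padded up to the next multiple of the tile size (the ttnn::TILE_SIZE referenced in the reduction code further down). The (1, 41, 41) tensor in the repro below is not tile-aligned, which is what triggers this path. A rough sketch of the padding arithmetic:

    # Sketch of the tile-padding arithmetic, assuming a 32x32 tile.
    import math

    TILE_SIZE = 32

    def tile_padded(dim: int) -> int:
        """Round a dimension up to the next multiple of the tile size."""
        return math.ceil(dim / TILE_SIZE) * TILE_SIZE

    # Last two dims of the repro tensor (1, 41, 41):
    print(tile_padded(41))  # 64 -> stored as (1, 64, 64) in TILE_LAYOUT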

To Reproduce
I managed to reproduce this easily with plain TTNN:

import torch
import ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

# 41 is not a multiple of the 32-wide tile, so the tensor gets tile padding.
torch_tensor = torch.rand(1, 41, 41)
ttnn_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=device)

# Fails: the internal reshape of the reduced output throws the TILE_LAYOUT error.
ttnn_tensor = ttnn.mean(ttnn_tensor, -1)
ttnn.close_device(device)

The output of this test:

Always | FATAL | Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first.
Traceback (most recent call last):
  File "/localdev/sdjordjevic/src/tt-metal/ttnn/examples/usage/program_cache.py", line 16, in <module>
    ttnn_tensor = ttnn.mean(ttnn_tensor, -1)
  File "/localdev/sdjordjevic/src/tt-metal/ttnn/ttnn/decorators.py", line 327, in __call__
    return self.function(*function_args, **function_kwargs)
RuntimeError: TT_THROW @ /localdev/sdjordjevic/src/tt-metal/ttnn/cpp/ttnn/operations/core/core.cpp:49: tt::exception
info:
Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first.
backtrace:
 --- ttnn::operations::core::reshape(tt::tt_metal::Tensor const&, ttnn::types::Shape const&)
 --- /localdev/sdjordjevic/src/tt-metal/ttnn/ttnn/_ttnn.so(+0xd7c3a3) [0x7f19e5bc13a3]
 --- /localdev/sdjordjevic/src/tt-metal/ttnn/ttnn/_ttnn.so(_ZN4ttnn10operations9reduction6ReduceILNS1_10ReduceTypeE1EE6invokeERKN2tt8tt_metal6TensorERKNSt3__18optionalINSA_7variantIJiNSA_6vectorIiNSA_9allocatorIiEEEEEEEEEbRKNSB_INS6_12MemoryConfigEEERKNSB_INSC_IJNS_28GrayskullComputeKernelConfigENS_27WormholeComputeKernelConfigEEEEEEf+0x14) [0x7f19e5bcb824]
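For comparison (a sketch, not run as-is), an input whose last two dims are already multiples of the tile size should not hit this error, since its padded and unpadded shapes coincide and the problem described below is masked:

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # 64 is already a multiple of the 32-wide tile, so the padded and logical
    # shapes coincide and the reduction's internal reshape goes through.
    torch_tensor = torch.rand(1, 64, 64)
    ttnn_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=device)
    ttnn_tensor = ttnn.mean(ttnn_tensor, -1)

    ttnn.close_device(device)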

Expected behavior
ttnn.mean should not produce errors when a non-tile-aligned tensor is provided as its operand.



Additional context
I looked into the issue a bit. In the TTNN implementation of the reduction op, there is code that calculates the output tensor shape:

    std::vector<uint32_t> output_shape;
    std::vector<uint32_t> padded_output_shape;
    for (int axis = 0; axis < input_shape.size(); axis++) {
        if (std::find(dim.begin(), dim.end(), axis) != dim.end()) {
            if (keepdim) {
                output_shape.push_back(1);
                padded_output_shape.push_back(axis >= rank - 2 ? ttnn::TILE_SIZE : 1);
            }
        } else {
            output_shape.push_back(input_shape[axis]);
            padded_output_shape.push_back(input_shape[axis]);
        }
    }

The issue is in this line: padded_output_shape.push_back(input_shape[axis]);

We iterate through all the dims of the tensor. For each dim that is not being reduced, we want to capture both the unpadded and the padded extents of the input shape, but the padded extent isn't captured correctly because input_shape[axis] returns the unpadded value. Instead, we should retrieve something like this: padded_output_shape.push_back(input_shape.value[axis]);

Later, as part of the reduction op, the reshape op is executed:

output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{tt::tt_metal::Shape{output_shape, padded_output_shape}});

This reshape operation fails because it expects a padded shape for a tiled layout, but the reduction operation instead supplies the unpadded shape, hence the failure.
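To make the mismatch concrete, here is a small sketch (not library code, names are illustrative) that mirrors the loop above for the repro case, assuming TILE_SIZE is 32, keepdim defaults to true, and the tile-padded input shape is (1, 64, 64):

    # Shape computation for ttnn.mean(tensor, -1) on a (1, 41, 41) input.
    TILE_SIZE = 32
    logical_shape = [1, 41, 41]   # what input_shape[axis] returns
    padded_shape = [1, 64, 64]    # what input_shape.value[axis] would return
    dims_to_reduce, keepdim, rank = [2], True, 3

    output_shape, padded_output_shape = [], []
    for axis in range(rank):
        if axis in dims_to_reduce:
            if keepdim:
                output_shape.append(1)
                padded_output_shape.append(TILE_SIZE if axis >= rank - 2 else 1)
        else:
            output_shape.append(logical_shape[axis])
            # Current behavior: the unpadded extent is used here.
            padded_output_shape.append(logical_shape[axis])
            # Proposed fix: use the tile-padded extent instead.
            # padded_output_shape.append(padded_shape[axis])

    print(output_shape)          # [1, 41, 1]
    print(padded_output_shape)   # [1, 41, 32] -> 41 is not a multiple of 32,
                                 # so the TILE_LAYOUT reshape above throws

With the fix, the padded output shape would be [1, 64, 32], which the TILE_LAYOUT reshape can handle.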

sdjordjevicTT commented 2 weeks ago

@tarafdarTT @sjameelTT It seems that the reduction operation only works when the dimensions of the input tensors are tile-aligned.

sdjordjevicTT commented 1 week ago

@tarafdarTT @sjameelTT I created the following PR to fix this: https://github.com/tenstorrent/tt-metal/pull/12274

Can you please review it? :)

sdjordjevicTT commented 4 days ago

Fixed with: https://github.com/tenstorrent/tt-metal/pull/12274