tenstorrent / tt-metal

TT-NN operator library and TT-Metalium low-level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

[Bug Report] ttnn.max - Fails with tensor rank is not 4 #13190

Open chandrasekaranpradeep opened 2 months ago

chandrasekaranpradeep commented 2 months ago

Describe the bug: The ttnn.max op throws a "Tensor rank is not 4" error when reducing a 3D input tensor along the batch dimension. Additional note: with a 3D input tensor of shape (2, 32, 64) and dim = 0, the call enters the if (dim[0] == rank - 3) branch of reduce_impl in ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp, but that branch appears to have been written specifically for rank = 4.

For more context, here is the exact error message:

E       RuntimeError: TT_FATAL @ ../ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp:41: rank == 4
E       info:
E       Tensor rank is not 4
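To illustrate why a 3D input with dim = 0 ends up in the rank-4 branch, here is a minimal pure-Python sketch. It is not the tt-metal implementation: the helper names takes_batch_branch and to_rank4 are hypothetical, and the only thing taken from the report is the quoted condition dim[0] == rank - 3. The sketch also shows one possible remapping, assuming the tensor were first padded with leading 1s to rank 4.

```python
def takes_batch_branch(rank, dim):
    """Mirror of the condition quoted in the report: dim[0] == rank - 3.

    Hypothetical helper for illustration only.
    """
    return dim == rank - 3


def to_rank4(shape, dim):
    """Pad a shape with leading 1s up to rank 4 and shift dim to match.

    Assumption: this is one plausible way to handle lower-rank inputs,
    not necessarily what the library does or should do.
    """
    rank = len(shape)
    if rank > 4:
        raise ValueError("this sketch only handles rank <= 4")
    if not -rank <= dim < rank:
        raise ValueError("dim out of range for the given shape")
    if dim < 0:
        dim += rank  # normalize negative dims first
    offset = 4 - rank
    return (1,) * offset + tuple(shape), dim + offset


# For the failing case, rank - 3 == 0, so dim = 0 satisfies the condition
# even though the tensor is only 3D.
print(takes_batch_branch(3, 0))
# After padding to rank 4, the same axis becomes dim = 1.
print(to_rank4((2, 32, 64), 0))
```

With rank 4 the condition selects dim = 1 (the second axis), which matches the remapped axis above; with rank 3 it selects dim = 0 on a tensor that later fails the rank == 4 check in pad.cpp.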

To Reproduce: Run the following test:

import pytest
import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import torch_random

@pytest.mark.parametrize(
    "input_shape_and_dim",
    [
        ((2, 32, 64), 0),
        ((2, 22, 37), 0),
    ],
)
def test_max_rank_issue(device, input_shape_and_dim):
    input_shape, max_dim = input_shape_and_dim

    torch_input_tensor = torch_random(input_shape, -100, 100, dtype=torch.bfloat16)
    torch_output_tensor, _ = torch.max(torch_input_tensor, dim=max_dim, keepdim=True)

    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

    output_tensor = ttnn.max(input_tensor, dim=max_dim)
    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
    output_tensor = ttnn.from_device(output_tensor)

    output_tensor = ttnn.to_torch(output_tensor)

    assert_with_pcc(torch_output_tensor, output_tensor)

Expected behavior: ttnn.max reduces the 3D input tensor along the batch dimension.
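As a small reference for the expected result, the sketch below computes a max over the batch dimension of a 3D nested list in pure Python, keeping the reduced dimension as size 1 (the same convention as keepdim=True in the torch.max call of the test). The function name max_over_batch_keepdim is hypothetical and the tiny input is made up for illustration.

```python
def max_over_batch_keepdim(x):
    """Reduce a (B, H, W) nested list along dim 0, returning shape (1, H, W)."""
    batch = len(x)
    rows = len(x[0])
    cols = len(x[0][0])
    reduced = [
        [max(x[b][r][c] for b in range(batch)) for c in range(cols)]
        for r in range(rows)
    ]
    return [reduced]  # wrap once so the batch dimension is kept as size 1


# Shape (2, 2, 2) input; the result has shape (1, 2, 2).
sample = [
    [[1, 5], [3, 0]],
    [[4, 2], [-1, 7]],
]
print(max_over_batch_keepdim(sample))
```

For the shapes in the test, the same reduction would turn (2, 32, 64) into (1, 32, 64) and (2, 22, 37) into (1, 22, 37).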

prajaramanTT commented 3 days ago

Keeping Priority label in sync with Priority field - P2