tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

[Bug Report] ttnn.reshape produces mismatching results when input is TILE vs ROW_MAJOR #15524

Open kevinwuTT opened 1 week ago

kevinwuTT commented 1 week ago

Describe the bug

When calling ttnn.reshape with certain input and output shapes, the output values differ depending on whether the input tensor was in TILE_LAYOUT or ROW_MAJOR_LAYOUT.

To Reproduce

import ttnn
import torch

input_shapes = ((1, 56, 56, 64), (1, 28, 28, 16))
output_shapes = ((1, 1, 3136, 64), (1, 1, 784, 16))

def test(input_shape, output_shape):
    arg0_1 = torch.rand(input_shape, dtype=torch.bfloat16)

    with ttnn.manage_device(device_id=0) as device:
        # Path 1: tilized input, reshape runs on device.
        ttnn_from_torch_1 = ttnn.from_torch(arg0_1, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
        ttnn_reshape_1 = ttnn.reshape(ttnn_from_torch_1, output_shape)
        ttnn_reshape_1 = ttnn.to_torch(ttnn_reshape_1)

        # Path 2: row-major input kept on host.
        ttnn_from_torch_2 = ttnn.from_torch(arg0_1, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16)
        ttnn_reshape_2 = ttnn.reshape(ttnn_from_torch_2, output_shape)
        ttnn_reshape_2 = ttnn.to_torch(ttnn_reshape_2)

    print(f"input_shape: {input_shape}, output_shape: {output_shape}")
    print("TILE_LAYOUT on device:")
    print(ttnn_reshape_1)
    print("ROW_MAJOR_LAYOUT:")
    print(ttnn_reshape_2)
    print("torch.allclose:", torch.allclose(ttnn_reshape_1, ttnn_reshape_2))

for input_shape, output_shape in zip(input_shapes, output_shapes):
    test(input_shape, output_shape)

Expected behavior

The reshaped output should be identical regardless of the input layout. Instead, calling ttnn.reshape with input shape (1, 56, 56, 64) and output shape (1, 1, 3136, 64) produces mismatching outputs for TILE_LAYOUT vs ROW_MAJOR_LAYOUT. Notice the trailing rows of 0s in the output tensor from the TILE_LAYOUT path.
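A possible connection (my reading, not confirmed in this thread): TILE_LAYOUT pads the last two dims up to multiples of the 32x32 tile size, so (1, 56, 56, 64) is stored on device as (1, 56, 64, 64), and the zeroed rows look like that padding surfacing in the reshaped data. Note that the second shape pair below also requires padding yet matches, so padding alone does not explain it. A quick check of the padded extents:

import math

TILE = 32  # TILE_LAYOUT pads the last two dims up to multiples of the tile size

def padded_last_two(shape, tile=TILE):
    *rest, h, w = shape
    return (*rest, math.ceil(h / tile) * tile, math.ceil(w / tile) * tile)

print(padded_last_two((1, 56, 56, 64)))   # (1, 56, 64, 64)  -> padded along dim -2
print(padded_last_two((1, 1, 3136, 64)))  # (1, 1, 3136, 64) -> already tile-aligned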

input_shape: (1, 56, 56, 64), output_shape: (1, 1, 3136, 64)
TILE_LAYOUT on device:
TorchTensor([[[[0.6328, 0.6836, 0.0312,  ..., 0.0234, 0.4258, 0.5859],
               [0.6406, 0.0664, 0.2109,  ..., 0.3047, 0.8828, 0.4766],
               [0.6133, 0.4844, 0.1523,  ..., 0.8047, 0.7422, 0.2578],
               ...,
               [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
               [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
               [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]],
            dtype=torch.bfloat16)
ROW_MAJOR_LAYOUT:
TorchTensor([[[[0.6328, 0.6836, 0.0312,  ..., 0.0234, 0.4258, 0.5859],
               [0.6406, 0.0664, 0.2109,  ..., 0.3047, 0.8828, 0.4766],
               [0.6133, 0.4844, 0.1523,  ..., 0.8047, 0.7422, 0.2578],
               ...,
               [0.9219, 0.4570, 0.9961,  ..., 0.3203, 0.4336, 0.9961],
               [0.8984, 0.8828, 0.5547,  ..., 0.4219, 0.3672, 0.6172],
               [0.8438, 0.1406, 0.1484,  ..., 0.1094, 0.2539, 0.5078]]]],
            dtype=torch.bfloat16)
torch.allclose: False
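As a quick diagnostic (not part of the original report), the following lines, added inside test() after the two ttnn.to_torch calls, show where the two results first diverge and how many rows of the TILE_LAYOUT output came back as all zeros:

    mismatch = (~torch.isclose(ttnn_reshape_1, ttnn_reshape_2)).any(dim=-1)
    print("first mismatching row:", mismatch.nonzero()[0].tolist() if mismatch.any() else None)
    print("all-zero rows in TILE_LAYOUT output:", (ttnn_reshape_1 == 0).all(dim=-1).sum().item())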

For the following input and output shapes, the output values match as expected.

input_shape: (1, 28, 28, 16), output_shape: (1, 1, 784, 16)
TILE_LAYOUT on device:
TorchTensor([[[[0.5859, 0.9492, 0.6172,  ..., 0.0078, 0.5078, 0.2812],
               [0.2539, 0.6992, 0.4453,  ..., 0.1172, 0.9648, 0.1094],
               [0.9883, 0.1914, 0.3633,  ..., 0.2148, 0.2266, 0.2930],
               ...,
               [0.5156, 0.8477, 0.5117,  ..., 0.0039, 0.1094, 0.6328],
               [0.5469, 0.9375, 0.4688,  ..., 0.8750, 0.9492, 0.0586],
               [0.4727, 0.8945, 0.5586,  ..., 0.3125, 0.9688, 0.6250]]]],
            dtype=torch.bfloat16)
ROW_MAJOR_LAYOUT:
TorchTensor([[[[0.5859, 0.9492, 0.6172,  ..., 0.0078, 0.5078, 0.2812],
               [0.2539, 0.6992, 0.4453,  ..., 0.1172, 0.9648, 0.1094],
               [0.9883, 0.1914, 0.3633,  ..., 0.2148, 0.2266, 0.2930],
               ...,
               [0.5156, 0.8477, 0.5117,  ..., 0.0039, 0.1094, 0.6328],
               [0.5469, 0.9375, 0.4688,  ..., 0.8750, 0.9492, 0.0586],
               [0.4727, 0.8945, 0.5586,  ..., 0.3125, 0.9688, 0.6250]]]],
            dtype=torch.bfloat16)
torch.allclose: True


ayerofieiev-tt commented 2 days ago

@kevinwuTT , can we attach this to any aten operation in scope of any of the pytorch sprints?

kevinwuTT commented 1 day ago

@ayerofieiev-tt I came across this issue in test_conv2d: when I removed layout changes, the results no longer matched. Our current aten.convolution lowering uses reshapes.
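For context, here is a hedged sketch of the kind of flattening reshape such a lowering typically introduces; the names and exact calls are illustrative rather than the actual lowering code, but the shape pair from this issue matches an (N, H, W, C) -> (1, 1, N*H*W, C) flattening:

import torch

# Illustrative only: channels-last convolution lowerings often flatten the spatial
# dims before the device op, e.g. (N, H, W, C) -> (1, 1, N*H*W, C). If ttnn.reshape
# drops data for TILE_LAYOUT inputs, this step silently corrupts the conv input.
n, h, w, c = 1, 56, 56, 64
x = torch.rand((n, h, w, c), dtype=torch.bfloat16)
flattened = x.reshape(1, 1, n * h * w, c)  # (1, 1, 3136, 64), the shapes in this issue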

jvegaTT commented 10 hours ago

I just tested this, and the issue is resolved by https://github.com/tenstorrent/tt-metal/pull/15572