tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low-level kernel programming model.
Apache License 2.0

[Bug Report] Strange behavior of ttnn.repeat #12680

Status: Open (opened by rfurko-tt 1 month ago)

rfurko-tt commented 1 month ago

Describe the bug: ttnn.repeat appears to use the legacy (tile-padded) shape instead of the real (logical) shape, which produces an unexpected output shape.
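As background, the "legacy" shape here is the tile-padded one. In TILE_LAYOUT the last two dimensions are stored in 32x32 tiles, so a logical size of 1 is padded to 32 (printed as 1[32] in the output below the repro). A minimal pure-Python sketch of that rounding, assuming a 32x32 tile; `pad_to_tile` and `round_up` are hypothetical helpers for illustration, not ttnn APIs:

```python
TILE = 32  # assumed tile height/width for TILE_LAYOUT


def round_up(n: int, multiple: int = TILE) -> int:
    """Round n up to the next multiple of `multiple`."""
    return -(-n // multiple) * multiple


def pad_to_tile(shape: list[int]) -> list[int]:
    """Return the padded shape: last two dims rounded up to a multiple of TILE."""
    *batch, h, w = shape
    return [*batch, round_up(h), round_up(w)]


print(pad_to_tile([1, 1, 1, 1]))    # [1, 1, 32, 32]
print(pad_to_tile([1, 1, 1, 128]))  # [1, 1, 32, 128]
```

Only the last two dimensions are padded, which matches the output's 1[32] for dim 2 while the leading dims stay unpadded.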

To Reproduce

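import ttnn
import torch

with ttnn.manage_device(device_id=0) as device:
    # x holds a single element; y's shape is used only as the repeat vector [1, 1, 1, 128]
    x = torch.ones((1, 1, 1, 1), dtype=torch.float32)
    y = torch.ones((1, 1, 1, 128), dtype=torch.float32)

    # Move both tensors to the device in TILE_LAYOUT (last two dims padded to 32x32 tiles)
    x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    # Repeat x by y's shape; the result should have logical shape [1, 1, 1, 128]
    z_tt = ttnn.repeat(x_tt, y_tt.shape)
    z = ttnn.to_torch(z_tt)
    print(z.shape)     # shape of the result copied back to host
    print(z_tt.shape)  # ttnn shape, printed as logical[padded]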

Output:

torch.Size([1, 1, 1, 4065])
ttnn.Shape([1, 1, 1[32], 4065[4096]])

Expected behavior: ttnn.Shape([1, 1, 1[32], 128]), i.e. a logical shape of [1, 1, 1, 128] (x repeated 128 times along the last dimension).
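The expected shape follows from repeat semantics, assuming ttnn.repeat behaves like torch.Tensor.repeat: each output dimension is the input dimension times its repeat count. The second half of the sketch is only a hypothesis that happens to reproduce the reported 4065[4096]: repeating x's padded last dimension (32) 128 times gives 4096, and subtracting x's padding once (31) gives 4065. It has not been confirmed against the ttnn source; `expected_repeat_shape` is a hypothetical helper.

```python
def expected_repeat_shape(logical_shape: list[int], repeats: list[int]) -> list[int]:
    """Repeat semantics: each output dim is the input dim times its repeat count."""
    return [d * r for d, r in zip(logical_shape, repeats)]


x_logical = [1, 1, 1, 1]
x_padded = [1, 1, 32, 32]  # x after tile padding (see 1[32] in the output)
repeats = [1, 1, 1, 128]   # taken from y_tt.shape in the repro

print(expected_repeat_shape(x_logical, repeats))  # [1, 1, 1, 128]

# Hypothesis only: repeat the padded last dim, then subtract x's padding once.
padded_w = x_padded[-1] * repeats[-1]
observed_w = padded_w - (x_padded[-1] - x_logical[-1])
print(padded_w, observed_w)  # 4096 4065
```

If this reading is right, the bug is that the repeat is computed on the padded width rather than on the logical width of 1.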


rfurko-tt commented 1 month ago

Hi @tarafdarTT, could you please take a look? Thanks in advance!

dmakoviichuk-tt commented 1 month ago

@jaykru-tt