Status: Open. rfurko-tt opened this issue 1 month ago.
Describe the bug
ttnn.repeat uses the legacy (padded) shape instead of the actual logical shape, which produces unexpected results.
To Reproduce
import ttnn
import torch
import numpy as np

with ttnn.manage_device(device_id=0) as device:
    x = torch.ones((1, 1, 1, 1), dtype=torch.float32)
    y = torch.ones((1, 1, 1, 128), dtype=torch.float32)
    x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    z_tt = ttnn.repeat(x_tt, y_tt.shape)
    z = ttnn.to_torch(z_tt)
    print(z.shape)
    print(z_tt.shape)
Output:
torch.Size([1, 1, 1, 4065])
ttnn.Shape([1, 1, 1[32], 4065[4096]])
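The printed widths are consistent with one possible reading. This is an assumption and is not confirmed anywhere in the issue: the op tiled x's padded 32-wide tile 128 times and then trimmed only the trailing padding of the last copy. The arithmetic below is just a sketch of that hypothesis. TILE = 32 is the standard tile width, and the variable names are illustrative.

```python
TILE = 32  # tile width assumed for TILE_LAYOUT

logical_w, padded_w = 1, TILE  # x's last dim: 1 logical element padded to one tile
repeats = 128                  # repeat count taken from y_tt.shape[-1]

# Padded width of the output if whole padded tiles were concatenated.
padded_out = padded_w * repeats  # 4096, matches "4065[4096]"

# Hypothesis: only the trailing padding of the last copy was trimmed,
# so the padding inside every earlier copy counts as "logical" data.
logical_out = padded_out - (padded_w - logical_w)  # 4096 - 31 = 4065

# What a repeat over the logical shape should give instead.
expected = logical_w * repeats  # 128

print(padded_out, logical_out, expected)  # 4096 4065 128
```

Under this reading, positions 1 to 31 of each repeated copy would hold padding values rather than ones.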
Expected behavior
ttnn.Shape([1, 1, 1[32], 128])
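Plain torch can serve as a host-side reference for the expected result. The sketch below assumes a tile size of 32 and uses a hypothetical helper, tile_padded, which rounds the last two dims up to whole tiles. It is not a ttnn function. Every dim of x is 1, so passing y's shape as per-dimension repeat counts yields y's logical shape. Padding that shape to tiles gives the [32] in the expected ttnn.Shape.

```python
import torch

TILE = 32  # tile height/width assumed for TILE_LAYOUT


def round_up(n: int) -> int:
    """Round n up to the next multiple of TILE."""
    return -(-n // TILE) * TILE


def tile_padded(shape):
    """Hypothetical helper: logical shape -> tile-padded shape (last two dims)."""
    *batch, h, w = shape
    return (*batch, round_up(h), round_up(w))


x = torch.ones((1, 1, 1, 1))
y = torch.ones((1, 1, 1, 128))

# torch.Tensor.repeat takes per-dimension repeat counts; with an all-ones
# input shape, repeating by y.shape produces a tensor of y's shape.
z = x.repeat(*y.shape)

print(z.shape)             # torch.Size([1, 1, 1, 128])  (logical shape)
print(tile_padded(z.shape))  # (1, 1, 32, 128)            (padded shape)
```

The logical shape (1, 1, 1, 128) together with the padded shape (1, 1, 32, 128) is what ttnn.Shape([1, 1, 1[32], 128]) denotes.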
Please complete the following environment information:
Hi @tarafdarTT, could you please take a look? Thanks in advance!
@jaykru-tt