swimdi opened this issue 2 weeks ago
The problem seems to appear in `ttnn.from_torch` & `ttnn.to_torch`. Here is the original graph of `test_add_shape_mismatch`:
```python
# x = torch.ones([5])
def forward(self, x):
    add = torch.ops.aten.add.Tensor(x, 0.0)
    return add
```
and here is the compiled graph:
```python
# arg0_1.shape: torch.Size([5])
def forward(self, arg0_1):
    # ttnn_from_torch.shape: ttnn.Shape([1[32], 5[32]])
    ttnn_from_torch = ttnn.from_torch(arg0_1, layout = ttnn.TILE_LAYOUT, device = device, dtype = ttnn.bfloat16);  arg0_1 = None
    # ttnn_add.shape: ttnn.Shape([1[32], 5[32]])
    ttnn_add = ttnn.add(ttnn_from_torch, 0.0);  ttnn_from_torch = None
    # ttnn_to_torch.shape: torch.Size([1, 5])
    ttnn_to_torch = ttnn.to_torch(ttnn_add);  ttnn_add = None
    # _to_copy_default.shape: torch.Size([1, 5])
    _to_copy_default = torch.ops.aten._to_copy.default(ttnn_to_torch, dtype = torch.float32);  ttnn_to_torch = None
    return (_to_copy_default,)
```
`ttnn_from_torch.shape` becomes `ttnn.Shape([1[32], 5[32]])`, while `ttnn_to_torch.shape` is `torch.Size([1, 5])`, so the result cannot be converted back to the original `torch.Size([5])`.
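As a point of reference, here is a minimal workaround sketch (not existing compiler code; `add_with_shape_fix` and the reshape step are my own assumption) that records the original torch shape before lowering and reshapes the `ttnn.to_torch` result back to it:

```python
import torch
import ttnn

def add_with_shape_fix(x: torch.Tensor, device):
    original_shape = x.shape  # torch.Size([5])
    t = ttnn.from_torch(x, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
    t = ttnn.add(t, 0.0)
    out = ttnn.to_torch(t)             # comes back as torch.Size([1, 5])
    out = out.reshape(original_shape)  # restore torch.Size([5])
    return out.to(torch.float32)
```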
RetinaNet has the same issue; its error message is `IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [1, 1, 800], [1, 1066]`.
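For context, this is not RetinaNet's actual code, just an assumed minimal illustration of how that IndexError arises: advanced indexing with index tensors whose shapes cannot be broadcast together.

```python
import torch

image = torch.zeros(3, 800, 1066)
ys = torch.zeros(1, 1, 800, dtype=torch.long)  # should have been [800, 1]
xs = torch.zeros(1, 1066, dtype=torch.long)
image[:, ys, xs]
# IndexError: shape mismatch: indexing tensors could not be broadcast together
# with shapes [1, 1, 800], [1, 1066]
```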
One of its root causes can be reproduced with this pattern:
```python
# result_before.shape: torch.Size([800, 1])   (values 0~479)
# result_after.shape:  torch.Size([1, 1, 800]) (values 0~480)
def forward(self):
    arange = torch.ops.aten.arange.default(800, dtype=torch.float32, device=torch.device(type="cpu"), pin_memory=False)
    add = torch.ops.aten.add.Tensor(arange, 0.5)
    mul = torch.ops.aten.mul.Tensor(add, 0.6)
    sub_1 = torch.ops.aten.sub.Tensor(mul, 0.5)
    clamp = torch.ops.aten.clamp.default(sub_1, 0.0)
    _to_copy = torch.ops.aten._to_copy.default(clamp, dtype=torch.int64)
    unsqueeze_6 = torch.ops.aten.unsqueeze.default(_to_copy, 1)
    return unsqueeze_6
```
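For comparison, a plain eager-PyTorch run of the same pattern (my own reference sketch, not code from the repo) keeps the expected `[800, 1]` shape:

```python
import torch

def retinanet_pattern_reference():
    arange = torch.arange(800, dtype=torch.float32)
    idx = ((arange + 0.5) * 0.6 - 0.5).clamp(min=0.0).to(torch.int64)  # values 0~479
    out = idx.unsqueeze(1)
    assert out.shape == torch.Size([800, 1])  # not [1, 1, 800]
    return out
```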
You can reproduce it with `pytest tests/pattern/test_retinanet_pattern.py`.
Brief description

add shape mismatch for this input variation. Running `pytest tests/pattern/test_add_shape_mismatch.py` can reproduce it.

Debug log
In ViLT, there is code that computes `x_h` and `x_w`. Their original values are correct, but after conversion to ttnn the values become wrong, which causes `nn.functional.interpolate` to fail. After debugging, I think the root cause is that the add operator's output shape is different: the correct shape should be `[12]`, but it became `[1, 12]`.

Currently ViLT works because the guard function guards against this. If this issue is solved, the blocklist in the guard function can also be removed.
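A possible direction (just a hypothetical sketch; `restore_rank` does not exist in the repo) would be to squeeze back the leading singleton dims that the conversion adds, so a `[1, 12]` result matches the expected `[12]`:

```python
import torch

def restore_rank(result: torch.Tensor, original_rank: int) -> torch.Tensor:
    # Drop leading singleton dims added by the conversion, e.g. [1, 12] -> [12]
    while result.dim() > original_rank and result.shape[0] == 1:
        result = result.squeeze(0)
    return result

assert restore_rank(torch.ones(1, 12), original_rank=1).shape == torch.Size([12])
```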