tenstorrent / pytorch2.0_ttnn

⭐️ TTNN Compiler for PyTorch 2.0 ⭐️ It enables running PyTorch 2.0 models on Tenstorrent hardware
https://tenstorrent.github.io/tt-metal/latest/ttnn/

ViLT failed because result shapes differ between `aten.add` and `ttnn.add` #390

Open swimdi opened 2 weeks ago

swimdi commented 2 weeks ago

Brief description

`aten.add` has a shape mismatch for this input variation:

"Tensor<[5]> self = ?", "Tensor other = 0.0"

Running `pytest tests/pattern/test_add_shape_mismatch.py` reproduces it:

FAILED tests/pattern/test_add_shape_mismatch.py::test_add_vilt - assert torch.Size([5]) == torch.Size([1, 5])
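
For reference, this is the eager behavior the test asserts against (a minimal sketch of the check; the real test lives in tests/pattern/test_add_shape_mismatch.py and additionally runs the compiled ttnn graph):

import torch

# Eager aten.add keeps the rank-1 shape; the ttnn-compiled graph returns [1, 5].
x = torch.ones([5])
out = torch.ops.aten.add.Tensor(x, 0.0)
assert out.shape == torch.Size([5])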

Debug log

ViLT contains the following code:

# /home/ubuntu/venv_pt/lib/python3.8/site-packages/transformers/models/vilt/modeling_vilt.py(122)<listcomp>()
    def visual_embed(self, pixel_values, pixel_mask, max_image_length=200):
        _, _, ph, pw = self.patch_embeddings.projection.weight.shape

        x = self.patch_embeddings(pixel_values)
        x_mask = pixel_mask[:, None, :, :].float()
        x_mask = nn.functional.interpolate(x_mask, size=(x.shape[2], x.shape[3])).long()
        x_h = x_mask[:, 0].sum(dim=1)[:, 0]
        x_w = x_mask[:, 0].sum(dim=2)[:, 0]

        batch_size, num_channels, height, width = x.shape
        patch_dim = self.config.image_size // self.config.patch_size
        spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view(1, num_channels, patch_dim, patch_dim)
        pos_embed = torch.cat(
            [
                nn.functional.pad(
                    nn.functional.interpolate(
                        spatial_pos,
                        size=(h, w),
                        mode="bilinear",
                        align_corners=True,
                    ),
                    (0, width - w, 0, height - h),
                )
                for h, w in zip(x_h, x_w)
            ],
            dim=0,
        )

The original values of x_h and x_w are

x_h: TorchTensor([12])
x_w: TorchTensor([16])

But after conversion to ttnn, the values wrongly become

x_h: TorchTensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
x_w: TorchTensor([[12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]])

This causes nn.functional.interpolate to fail:

>           return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors)
E           TypeError: upsample_bilinear2d() received an invalid combination of arguments - got (Tensor, tuple, bool, NoneType), but expected one of:
E            * (Tensor input, tuple of ints output_size, bool align_corners, tuple of floats scale_factors)
E                 didn't match because some of the arguments have invalid types: (Tensor, !tuple of (TorchTensor, TorchTensor)!, bool, !NoneType!)
E            * (Tensor input, tuple of ints output_size, bool align_corners, float scales_h, float scales_w, *, Tensor out)

../venv_pt/lib/python3.8/site-packages/torch/nn/functional.py:4038: TypeError
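
The failure comes down to whether each element of zip(x_h, x_w) can still be converted to a Python int for interpolate's size argument. A minimal illustration (not from the original report):

import torch

# A single-element tensor converts to an int, so size=(h, w) works while
# x_h is TorchTensor([12]) and x_w is TorchTensor([16]).
print(int(torch.tensor(12)))        # 12

# A multi-element tensor cannot be converted to a Python scalar, which is
# what zip(x_h, x_w) yields once x_h/x_w wrongly become [1, 16]-shaped rows.
try:
    int(torch.tensor([12] * 16))
except (RuntimeError, ValueError) as e:
    print(e)                        # exact message depends on the PyTorch version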

After debugging, I think the root cause is that the add operator's output shape is different (the correct shape should be [12], but we got [1, 12]).

Currently ViLT works because the guard function blocks this op.

If this issue is solved, the blocklist entry in the guard function can also be removed.

swimdi commented 2 weeks ago

The problem may come from ttnn.from_torch & ttnn.to_torch. Here is the original graph of test_add_shape_mismatch:

# x = torch.ones([5])
def forward(self, x):
    add = torch.ops.aten.add.Tensor(x, 0.0)
    return add

and here is the compiled graph

# arg0_1.shape: torch.Size([5])
def forward(self, arg0_1):
    # ttnn_from_torch.shape: ttnn.Shape([1[32], 5[32]])
    ttnn_from_torch = ttnn.from_torch(arg0_1, layout = ttnn.TILE_LAYOUT, device = device, dtype = ttnn.bfloat16);  arg0_1 = None
    # ttnn_add.shape: ttnn.Shape([1[32], 5[32]])
    ttnn_add = ttnn.add(ttnn_from_torch, 0.0);  ttnn_from_torch = None
    # ttnn_to_torch.shape: torch.Size([1, 5])
    ttnn_to_torch = ttnn.to_torch(ttnn_add);  ttnn_add = None
    # _to_copy_default.shape: torch.Size([1, 5])
    _to_copy_default = torch.ops.aten._to_copy.default(ttnn_to_torch, dtype = torch.float32);  ttnn_to_torch = None
    return (_to_copy_default,)

ttnn_from_torch.shape becomes ttnn.Shape([1[32], 5[32]]) and ttnn_to_torch.shape is torch.Size([1, 5]); the rank-1 shape [5] is not converted back successfully.
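
One possible direction, sketched outside the compiler (a hedged sketch, not the repo's actual lowering; it assumes a device obtained via ttnn.open_device and the hypothetical helper name is mine): record the original torch shape before ttnn.from_torch and reshape the ttnn.to_torch result back to it, since TILE_LAYOUT promotes a rank-1 tensor to rank 2.

import torch
import ttnn

def add_scalar_keep_shape(t: torch.Tensor, scalar: float, device) -> torch.Tensor:
    # [5] becomes ttnn.Shape([1[32], 5[32]]) under TILE_LAYOUT, so remember
    # the original shape before entering ttnn.
    orig_shape = t.shape
    tt = ttnn.from_torch(t, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
    tt = ttnn.add(tt, scalar)
    out = ttnn.to_torch(tt).to(t.dtype)
    # Reshape back so the result matches aten.add's output shape ([5], not [1, 5]).
    return out.reshape(orig_shape)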

swimdi commented 2 weeks ago

ghostnetv2_100 has the same issue.

swimdi commented 2 weeks ago

RetinaNet has the same issue; its error message is

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [1, 1, 800], [1, 1066]

One of its root causes can be reproduced by this pattern:

# result_before.shape: torch.Size([800, 1]) (0~479)
# result_after.shape: torch.Size([1, 1, 800]) (0~480)
arange = torch.ops.aten.arange.default(800, dtype=torch.float32, device=torch.device(type="cpu"), pin_memory=False)
add = torch.ops.aten.add.Tensor(arange, 0.5)
mul = torch.ops.aten.mul.Tensor(add, 0.6)
sub_1 = torch.ops.aten.sub.Tensor(mul, 0.5)
clamp = torch.ops.aten.clamp.default(sub_1, 0.0)
_to_copy = torch.ops.aten._to_copy.default(clamp, dtype=torch.int64)
unsqueeze_6 = torch.ops.aten.unsqueeze.default(_to_copy, 1)
return unsqueeze_6

You can reproduce it with `pytest tests/pattern/test_retinanet_pattern.py`.
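
For reference, these are the eager shapes the pattern should produce, and how a rank promotion of the intermediate would lead to the reported [1, 1, 800] (a minimal sketch, plain PyTorch only):

import torch

# Eager: arange -> [800]; unsqueeze(1) -> [800, 1], matching result_before.
arange = torch.arange(800, dtype=torch.float32)
idx = ((arange + 0.5) * 0.6 - 0.5).clamp(min=0.0).to(torch.int64)
print(idx.unsqueeze(1).shape)                   # torch.Size([800, 1])

# If the intermediate is promoted to rank 2 ([1, 800]) by the ttnn round-trip,
# unsqueeze(1) yields [1, 1, 800] instead, matching result_after.
print(idx.reshape(1, 800).unsqueeze(1).shape)   # torch.Size([1, 1, 800])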