xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0

Hi #22

Closed ross-Hr closed 1 year ago

ross-Hr commented 1 year ago

When I run the forward function of the LocalAttention class, some errors occur.

x.shape = [1, 128, 84, 64] and self.window_size = 8. The rearrange call cannot work correctly because 84 is not divisible by 8, so n1 = 84 // 8 does not split the height evenly.

If I change window_size to 7, 6, or 5, other images still have a height or width that is not divisible.

I also tried setting window_size dynamically, but it didn't succeed.

The images come from the COCO dataset.

Do you have any good suggestions?

The relevant code is:

    b, c, h, w = x.shape
    p = self.window_size
    n1 = h // p
    n2 = w // p
    mask = torch.zeros(p ** 2, p ** 2, device=x.device) if mask is None else mask
    mask = mask + self.pos_embedding[self.rel_index[:, :, 0], self.rel_index[:, :, 1]]
    # split the feature map into non-overlapping p x p windows;
    # this assumes h and w are both divisible by p
    x = rearrange(x, "b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2", p1=p, p2=p)
    x, attn = self.attn(x, mask)
    x = rearrange(x, "(b n1 n2) c p1 p2 -> b c (n1 p1) (n2 p2)", n1=n1, n2=n2, p1=p, p2=p)
xxxnell commented 1 year ago

Hi @ross-Hr!

Thank you for raising an interesting question. As you know, the ImageNet dataset doesn't suffer from the problem you pointed out, so I didn't take this issue too seriously. However, one straightforward solution is to calculate self-attention only among the boundary tokens. The figure below illustrates this strategy.

[figure: self-attention computed only among boundary tokens]

One approach to implementing this strategy efficiently is (1) to add (zero) padding tokens to the image tokens, and then (2) to ignore those extra padding tokens by using self-attention masks.
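A rough sketch of that padding-plus-mask idea, not the repository's actual implementation: pad_and_window below is a hypothetical helper, and it assumes the attention layer can take a per-window additive mask of shape (b*n1*n2, p*p, p*p) instead of the single shared (p*p, p*p) mask in the snippet above.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def pad_and_window(x, p):
        # Pad H and W up to multiples of the window size p, split the result into
        # non-overlapping p x p windows, and build an additive attention mask
        # that hides the padding tokens.
        b, c, h, w = x.shape
        pad_h = (p - h % p) % p
        pad_w = (p - w % p) % p

        # zero-pad on the bottom / right so both sides become divisible by p
        x = F.pad(x, (0, pad_w, 0, pad_h))

        # 1 for real tokens, 0 for padding tokens, windowed the same way as x
        valid = F.pad(x.new_ones(b, 1, h, w), (0, pad_w, 0, pad_h))
        valid = rearrange(valid, "b c (n1 p1) (n2 p2) -> (b n1 n2) (p1 p2) c",
                          p1=p, p2=p).squeeze(-1)                  # (b*n1*n2, p*p)

        # additive mask: a large negative value on padded keys so softmax gives
        # them ~zero weight; shape (b*n1*n2, p*p, p*p)
        attn_mask = ((1.0 - valid) * -1e9)[:, None, :].expand(-1, p * p, -1)

        x = rearrange(x, "b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2", p1=p, p2=p)
        return x, attn_mask, (h + pad_h, w + pad_w)

After attention, the reverse rearrange would use the padded height and width, and the extra rows and columns can simply be cropped off to recover the original h x w map.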

ross-Hr commented 1 year ago

You're right! I am implementing this idea. What I am doing in my code now is roughly:

if n1 or n2 is less than 2,
run x, attn = self.attn(x, mask) directly, without splitting into windows.
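A minimal sketch of that fallback (this assumes self.attn can be called on the full feature map with no mask, since the relative-position mask only covers p*p tokens):

    b, c, h, w = x.shape
    p = self.window_size
    n1, n2 = h // p, w // p

    if n1 < 2 or n2 < 2:
        # feature map too small to window: attend over the whole map and skip
        # the relative-position mask, which is sized for p*p tokens
        x, attn = self.attn(x, None)
    else:
        mask = torch.zeros(p ** 2, p ** 2, device=x.device) if mask is None else mask
        mask = mask + self.pos_embedding[self.rel_index[:, :, 0], self.rel_index[:, :, 1]]
        x = rearrange(x, "b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2", p1=p, p2=p)
        x, attn = self.attn(x, mask)
        x = rearrange(x, "(b n1 n2) c p1 p2 -> b c (n1 p1) (n2 p2)", n1=n1, n2=n2, p1=p, p2=p)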