raoyongming / DynamicViT

[NeurIPS 2021] [T-PAMI] DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification
https://dynamicvit.ivg-research.xyz/
MIT License

[Questions] Dyswin feature output shape #25

Closed: zafirshi closed this issue 2 years ago

zafirshi commented 2 years ago

Hi~ Thx for your great work and excellent code!

I have some questions about the dyswin code and hope you can help me out.

  1. In code: https://github.com/raoyongming/DynamicViT/blob/84b4e2a9b1f11199bd1e2ff506969b0d64e6f55b/models/dyswin.py#L683

    When I input a tensor with shape (2, 3, 224, 224), the if condition is triggered (len(x) equals 2), and the following operation obviously goes wrong.

    However, when I set the batch size to any other number, the error disappears. Could you please explain the code here?

  2. When I use the pre-trained lvvit-s model as the inference backbone, I find that its output token sequence is shorter than that of the standard lvvit. That is expected, right?

    But when I change the backbone from lvvit-s to swin-s, the output token length is the same as that of the standard swin-s.

    For example, for an input tensor with shape (4, 3, 224, 224), the output shape of dyswin (ignoring the avgpool and later layers for now) is (4, 49, 768), which is exactly the shape the standard swin-s outputs. It seems that token reduction has not been achieved.

    Could you please explain this? Any advice would be greatly appreciated! :)
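The (4, 49, 768) shape in Q2 is what a full, unreduced Swin grid gives. The arithmetic below is a minimal sketch assuming the standard Swin-S settings (patch size 4, embedding dim 96, 4 stages with a patch-merging layer between consecutive stages); these values are assumptions, not read from dyswin.py. `swin_output_shape` is a hypothetical helper.

```python
def swin_output_shape(batch_size, img_size=224, patch_size=4, embed_dim=96, num_stages=4):
    """Shape (B, L, C) of the last Swin stage, before avgpool and the head.

    Defaults are assumed Swin-S settings (patch size 4, embed dim 96, 4 stages).
    """
    side = img_size // patch_size        # 224 // 4 = 56 tokens per side after patch embedding
    merges = num_stages - 1              # one patch-merging layer between consecutive stages
    side //= 2 ** merges                 # each merge halves H and W: 56 -> 28 -> 14 -> 7
    channels = embed_dim * 2 ** merges   # ... and doubles C: 96 -> 192 -> 384 -> 768
    return batch_size, side * side, channels


print(swin_output_shape(4))  # (4, 49, 768)
```

Since 49 = 7 × 7 is the complete final grid, no token has been dropped, which matches what the question observes.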

liuzuyan commented 2 years ago

Hi, thanks for your interest in our work. For Q1, in our code, we split the original x into [x1, x2] to represent the important and unimportant tokens, so we use if len(x) == 2 to check whether we need to reassemble the tensor for the follow-up operations (e.g. the downsample after this line). However, when x is a tensor, len(x) returns the size of its first dimension, so with a batch size of 2 the check len(x) == 2 evaluates to True. For any other batch size it evaluates to False. To fix this bug, we can check whether x is a list instead.
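The pitfall can be reproduced without PyTorch. The sketch below uses a hypothetical `FakeTensor` class whose `len()` returns the size of dimension 0, as `len()` does on a real `torch.Tensor`. `is_split_buggy` mirrors the original `len(x) == 2` check, and `is_split_fixed` is one way to apply the suggested list check; both names and the token counts in `pair` are illustrative assumptions.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: like a real tensor, len() is the size of dim 0."""

    def __init__(self, *shape):
        self.shape = tuple(shape)

    def __len__(self):
        return self.shape[0]


def is_split_buggy(x):
    # Original check: meant to detect the [x1, x2] pair of important/unimportant tokens,
    # but it is also True for a single tensor whose batch size happens to be 2.
    return len(x) == 2


def is_split_fixed(x):
    # Suggested fix: only a list can be the [x1, x2] pair.
    return isinstance(x, list) and len(x) == 2


for batch_size in (1, 2, 4):
    x = FakeTensor(batch_size, 3, 224, 224)
    print(batch_size, is_split_buggy(x), is_split_fixed(x))

pair = [FakeTensor(2, 40, 96), FakeTensor(2, 9, 96)]  # hypothetical token counts
print("pair", is_split_buggy(pair), is_split_fixed(pair))
```

Only at batch size 2 do the two checks disagree on a plain tensor, while both still recognize a genuine [x1, x2] list.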

For Q2, in lvvit-s, tokens are removed, so the output length is shorter. However, DynamicViT is designed for vanilla vision Transformers. If we remove tokens in a hierarchical architecture (swin-s or convnext-s), the shifted-window or convolution operations can no longer be computed, because they need the full spatial token grid. Therefore, we design an asymmetric computation with fast and slow paths that accelerates the model without reducing the number of tokens. You can refer to our newly released paper (arXiv) for details.
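The contrast between the two cases can be illustrated on a toy flattened token grid. This is a minimal, hypothetical sketch of the general idea, not the paper's implementation: `prune_tokens` drops unimportant tokens (so the sequence shrinks, as in the lvvit-s case), while `asymmetric_update` sends important tokens through an expensive "slow" function and the rest through a cheap "fast" one, then writes every result back to its original position. All function names, the importance mask, and the stand-in slow/fast computations are assumptions.

```python
def prune_tokens(tokens, keep_mask):
    """Token removal: dropped tokens are gone, so the sequence shrinks."""
    return [t for t, keep in zip(tokens, keep_mask) if keep]


def asymmetric_update(tokens, keep_mask, slow_fn, fast_fn):
    """Hypothetical fast/slow-path sketch: every token stays at its grid position.

    tokens:    flattened H*W grid of token features (plain numbers here)
    keep_mask: True = important token -> slow path, False -> fast path
    slow_fn / fast_fn: callables mapping a list of tokens to a same-length list
    """
    if len(tokens) != len(keep_mask):
        raise ValueError("tokens and keep_mask must have the same length")
    slow_idx = [i for i, keep in enumerate(keep_mask) if keep]
    fast_idx = [i for i, keep in enumerate(keep_mask) if not keep]
    out = list(tokens)
    # Run each path on its own subset, then scatter results back to the original
    # positions, so the full grid is still available to window or conv operations.
    for i, y in zip(slow_idx, slow_fn([tokens[i] for i in slow_idx])):
        out[i] = y
    for i, y in zip(fast_idx, fast_fn([tokens[i] for i in fast_idx])):
        out[i] = y
    return out


def slow(xs):
    # Stand-in for an expensive block (hypothetical).
    return [x * 10 for x in xs]


def fast(xs):
    # Stand-in for a cheap block (hypothetical).
    return [x + 0.5 for x in xs]


grid = [1.0, 2.0, 3.0, 4.0]        # a 2x2 token grid, flattened
mask = [True, False, True, False]  # hypothetical importance decisions

print(prune_tokens(grid, mask))                   # [1.0, 3.0]: length 4 -> 2
print(asymmetric_update(grid, mask, slow, fast))  # [10.0, 2.5, 30.0, 4.5]: length stays 4
```

The point of the second function is that the output has the same length and ordering as the input, which is why a Swin-style backbone still reports its full (B, 49, C) shape even when computation is reduced.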

zafirshi commented 2 years ago

Thank you so much for your detailed explanation! :)