xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0

How would the MSA build-up rules differ for upsampling stages? #29

Closed · waitingcheung closed this issue 1 year ago

waitingcheung commented 1 year ago
  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA.
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.

I suppose the above rules apply to high-level computer vision tasks such as classification, which involve only downsampling. I wonder how these rules differ for tasks involving upsampling stages, such as image generation from a latent or segmentation with a U-Net. In particular, I am interested in (1) the ordering of Conv and MSA blocks and (2) the number of heads and hidden dimensions in upsampling stages.
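To make the downsampling case concrete, here is a minimal PyTorch sketch of how I currently read the rules for a classification backbone. The block implementations, stage widths, and head counts are placeholders of my own rather than your actual code, so please correct me if I am misreading anything.

```python
import torch
import torch.nn as nn

class MSABlock(nn.Module):
    """Toy MSA block: multi-head self-attention over the spatial positions of a feature map."""
    def __init__(self, dim, heads):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):                         # x: (B, C, H, W)
        b, c, h, w = x.shape
        t = x.flatten(2).transpose(1, 2)          # (B, H*W, C) tokens
        n = self.norm(t)
        t = t + self.attn(n, n, n, need_weights=False)[0]   # residual attention
        return t.transpose(1, 2).reshape(b, c, h, w)

def conv_block(dim):
    return nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1), nn.BatchNorm2d(dim), nn.ReLU())

def downsample(cin, cout):
    return nn.Conv2d(cin, cout, 3, stride=2, padding=1)

# Rules 1 and 2: within each stage, keep Conv blocks first and place the MSA block at the end.
# Rule 3: later (lower-resolution) stages get wider channels and more heads.
backbone = nn.Sequential(
    nn.Sequential(conv_block(64), conv_block(64)),                      # early stage: Convs only
    downsample(64, 128),
    nn.Sequential(conv_block(128), conv_block(128), MSABlock(128, 4)),  # MSA at the end of the stage
    downsample(128, 256),
    nn.Sequential(conv_block(256), conv_block(256), MSABlock(256, 8)),  # more heads, wider dim
)

x = torch.randn(2, 64, 56, 56)
print(backbone(x).shape)  # torch.Size([2, 256, 14, 14])
```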

Based on your finding that Convs act as high-pass filters and MSAs as low-pass filters, I suppose the Conv-MSA ordering should hold for both downsampling and upsampling stages, rather than flipping to MSA-Conv in upsampling stages.

Since downsampling stages usually reduce the spatial resolution and increase the channel dimension, the third rule makes sense. However, upsampling stages usually increase the spatial resolution and reduce the channel dimension. Does the third rule still hold for upsampling, or should it be flipped to fewer heads and lower hidden dimensions in late stages?

I would appreciate your insights on applying these build-up rules to upsampling stages.

xxxnell commented 1 year ago

Hi @waitingcheung, thank you for the insightful question.

As you correctly pointed out, the build-up rules are for encoders or backbone architectures that include only downsampling steps. Finding the appropriate build-up rules for decoders is an open problem; decoders may not behave like encoders. For example, decoders tend to capture low-frequency information compared to encoders in some tasks. This may suggest that, in such cases, decoders need more self-attention layers, more heads, and higher embedding dimensions than encoders.

Anyway, I would vote for using fewer heads and lower embedding dimensions in the later stages of decoders. In contrast to encoders, I intuitively expect the later stages of decoders to exploit low-level features and capture high-frequency information.
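To illustrate, here is a rough sketch of that intuition, not something we have validated; the toy blocks, stage widths, and head counts are placeholders. A decoder could keep the Conv-MSA ordering within each stage while reducing the number of heads and the embedding dimension as the resolution grows, and drop MSA entirely at the highest resolution, where high-frequency detail matters most.

```python
import torch
import torch.nn as nn

# Toy blocks redefined here so this snippet runs standalone.
class MSABlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map."""
    def __init__(self, dim, heads):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):                         # x: (B, C, H, W)
        b, c, h, w = x.shape
        t = x.flatten(2).transpose(1, 2)          # (B, H*W, C) tokens
        n = self.norm(t)
        t = t + self.attn(n, n, n, need_weights=False)[0]
        return t.transpose(1, 2).reshape(b, c, h, w)

def conv_block(dim):
    return nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1), nn.BatchNorm2d(dim), nn.ReLU())

def upsample(cin, cout):
    return nn.ConvTranspose2d(cin, cout, kernel_size=2, stride=2)

# Conv -> MSA ordering within each stage; fewer heads and lower dims as resolution grows.
decoder = nn.Sequential(
    nn.Sequential(conv_block(256), MSABlock(256, 8)),   # lowest resolution: most heads, widest dim
    upsample(256, 128),
    nn.Sequential(conv_block(128), MSABlock(128, 4)),   # fewer heads, lower dim
    upsample(128, 64),
    nn.Sequential(conv_block(64), conv_block(64)),      # highest resolution: Convs only
)

z = torch.randn(2, 256, 14, 14)
print(decoder(z).shape)  # torch.Size([2, 64, 56, 56])
```

Please treat this as a starting point for experimentation rather than a validated recipe.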