facebookresearch / hiera

Hiera: A fast, powerful, and simple hierarchical vision transformer.
Apache License 2.0

Differences in implementation between SAM2 version and this implementation #37

Open hugoWR opened 2 months ago

hugoWR commented 2 months ago

Hello,

Thank you for the interesting code. I noticed the model looks quite different when compared with what is in SAM2 repo: https://github.com/facebookresearch/segment-anything-2/blob/7e1596c0b6462eb1d1ba7e1492430fed95023598/sam2/modeling/backbones/hieradet.py#L167

Would you be able to provide an explanation for the differences, and why they are needed?

Thank you in advance!

dbolya commented 2 months ago

Hi, could you point out some things you see are different? Both repos should contain the same model, just implemented differently. For instance, this repo unrolls the tokens at the start to be able to do fast window attention without reshaping. The SAM2 repo on the other hand opts to reshape windows at every layer instead, but the result is identical.
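To make the unroll-vs-reshape contrast concrete, here is a toy sketch (my illustration, not either repo's actual code; the shapes and the window-major layout are assumed for the example): partitioning windows by reshaping at every layer versus laying tokens out window-major once so that window attention afterwards only needs a view.

```python
import torch

B, H, W, C, w = 1, 8, 8, 3, 4  # toy sizes; window size w divides H and W

x = torch.randn(B, H, W, C)

# Reshape-per-layer approach (what the SAM2 repo does at each windowed layer):
# partition the feature map into w x w windows.
wins = (x.view(B, H // w, w, W // w, w, C)
         .permute(0, 1, 3, 2, 4, 5)       # bring the two window axes together
         .reshape(-1, w * w, C))          # (B * num_windows, tokens_per_window, C)

# Unrolled approach (as described for this repo): store tokens window-major once,
# so each window's tokens are contiguous and window attention is just a view.
unrolled = wins.reshape(B, -1, C)         # window-major token order, done once
wins_again = unrolled.view(-1, w * w, C)  # no permute needed per layer

assert torch.allclose(wins, wins_again)   # both routes yield identical windows
```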

The main actual differences should be the position embedding (as SAM2 implements absolute win embeddings from "Window Attention Is Bugged"), and support for larger windows (from the ViTDet paper).

hugoWR commented 2 months ago

Hi,

Thank you for your quick response!

For example, SAM2's variant of Hiera uses global attention on blocks 12, 16, and 20 by default.

This is different from this implementation, which uses global attention in all of the later blocks of the model.

Do you know why this change was made?

Thanks!

dbolya commented 2 months ago

SAM 2 implements HieraDet from "Window Attention is Bugged". This builds upon ViTDet, which takes a ViT and applies it to detection.

But also note: https://github.com/facebookresearch/segment-anything-2/blob/7e1596c0b6462eb1d1ba7e1492430fed95023598/sam2/modeling/backbones/hieradet.py#L184-L188

In the SAM 2 implementation, the window sizes for stages 3 and 4 are 14x14 and 7x7 respectively. In the original Hiera (this repo), those are the total numbers of tokens in those stages when using the default image size of 224 (196 and 49 tokens). So when run on 224px images, all layers are already global in the SAM 2 implementation, the same as in this repo.

But of course for much bigger images (e.g., 1024x1024, which is the default for ViT/HieraDet), this allows you to effectively train the model downstream for detection without destroying the attention / position embedding learned during pretraining (by passing it many more tokens than it was pretrained on).
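For concreteness, here is a small sketch of the token grid per stage implied by the numbers above (4px initial downsample, 2x2 pooling between stages; the helper name is mine):

```python
# Token grid per Hiera stage for a given square input resolution.
def stage_token_grids(image_size: int, num_stages: int = 4):
    side = image_size // 4                        # stage 1 after the 4px patch stride
    return [(side // 2**i,) * 2 for i in range(num_stages)]

print(stage_token_grids(224))   # [(56, 56), (28, 28), (14, 14), (7, 7)]
print(stage_token_grids(1024))  # [(256, 256), (128, 128), (64, 64), (32, 32)]
```

At 224px the stage-3/4 grids (14x14, 7x7) match the SAM 2 window sizes exactly, while at 1024px those grids are much larger, so the 14x14 / 7x7 windows become genuinely local.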

hugoWR commented 2 months ago

Okay, this makes sense, but why do global attention on only those 3 specific blocks (defaults 12, 16, 20)? It does not seem to be related to the end of the stages.

dbolya commented 2 months ago

See the ablations in Window Attention is Bugged, Table 19d (in the appendix). They're all in stage 3, equally dispersed. Note that the ones in the appendix are for Hiera-L (48 layers, 36 in stage 3), while the 12, 16, 20 is for Hiera-B, which only has 24 layers (16 in stage 3); we just scaled the positions down proportionately within stage 3: (x - 8) / 36 * 16 + 5 for x = 23, 33, 43 -> approximately 12, 16, 20.
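In code, that proportional rescaling looks roughly like this (my paraphrase of the formula above, not anything shipped in either repo):

```python
# Hiera-L's global-attention layers from the ablation (stage 3 spans layers 8-43),
# rescaled into Hiera-B's stage 3 (16 layers, starting at layer 5).
hiera_l_global = [23, 33, 43]
hiera_b_global = [round((x - 8) / 36 * 16 + 5) for x in hiera_l_global]
print(hiera_b_global)  # [12, 16, 21] -- the last lands at ~20.6; the default uses 20
```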

hugoWR commented 2 months ago

Thanks, I missed this.

I greatly appreciate your efforts in answering those questions :)

I will resolve this.

hugoWR commented 2 months ago

I'll re-open this briefly for a couple more questions.

Going back to the SAM2 implementation: we saw that window_spec is set to (8, 4, 14, 7) by default. But that setting seems like it would be incompatible with MAE.

My understanding was that you need to treat each mask unit as an individual entity, so you need a decreasing window size because of the pooling layers.

For the MAE pre-training step, you probably can't have an increase in the window size, so you should follow the recipe of this repo (first two stages with mask unit attention, last two with global attention).

For the fine-tuning step, you should follow the SAM2 implementation to save on memory usage.

Is my understanding correct?

Thanks!

dbolya commented 2 months ago

Hi, your understanding is correct, but just to be clear, I want to reiterate what the difference is. Like I said before--MAE pretrains with 224x224 px images. The initial downsample for Hiera is 4px, so stage 1 has 224/4 = 56x56 tokens. Then stage 2 has 28x28, stage 3 has 14x14 tokens, and finally stage 4 has 7x7 tokens. Notice how the window spec in SAM2 is 14x14 for stage 3 and 7x7 for stage 4--equal to the total number of tokens in those stages. So presuming you don't drop any tokens, for 224px images it is performing global attention.

The distinction between "full image window attention" and "global attention" does matter when you're dropping tokens, as you point out--and thus you have to make stages 3 and 4 global (but I just wanted to be clear that without dropping, they're the same thing). Also, if you'd like to pretrain with a higher resolution, you should modify the downstream SAM2 window sizes such that stages 3 and 4 have a window size equal to the number of tokens in those stages during pretraining (e.g., if you PT with 384, the window spec would be [8, 4, 24, 12]).
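Here is a hedged sketch of that last point (the helper and its name are mine, not a SAM2 API): deriving a window_spec whose stage-3/4 windows match the pretraining token grid, keeping the default (8, 4) for stages 1 and 2.

```python
# window_spec such that stages 3 and 4 are "global" at the pretraining resolution.
def window_spec_for_pretrain(pt_image_size: int):
    side = pt_image_size // 4            # stage-1 token grid side (4px patch stride)
    return (8, 4, side // 4, side // 8)  # stage 3 = side / 4, stage 4 = side / 8

print(window_spec_for_pretrain(224))  # (8, 4, 14, 7)
print(window_spec_for_pretrain(384))  # (8, 4, 24, 12)
```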

But yes, in general we set stages 3 and 4 to be global during PT and to have windows of the same resolution during high-res downstream applications.