facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Redundant arguments in Hiera? #290

Open hugoWR opened 2 months ago

hugoWR commented 2 months ago

I'm a little confused by the default arguments passed to the Hiera class. https://github.com/facebookresearch/segment-anything-2/blob/7e1596c0b6462eb1d1ba7e1492430fed95023598/sam2/modeling/backbones/hieradet.py#L184

It seems like window_spec is hiding two things at once. The first window attention blocks have windows of 8 and 4 tokens respectively (that's window attention). After those, the windows have 14 and 7 tokens (that's all the tokens left, so that would be global attention).

But then it looks redundant with the global_att_blocks parameter.

Can you help me understand these 3 parameters: window_spec, global_att_blocks and window_pos_embed_bkg_spatial_size?

Thanks!

chayryali commented 2 months ago

Hi @hugoWR, at a resolution of 1024px, stages 1, 2, 3, 4 will have 256², 128², 64², 32² tokens, so 14 and 7 are still window attention and not global, while the blocks listed in global_att_blocks will be global (e.g. a block from stage 3 listed in global_att_blocks will have an attention span of 64² and not 14²).

At a resolution of 224px, you are correct that 14 and 7 for stages 3 and 4 are indeed global already and setting global_att_blocks in these stages would be redundant.
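
For a quick sanity check of those numbers (assuming the usual stride-4 patch embedding and 2x downsampling at each stage transition):

```python
# Tokens per side at each Hiera stage: the patch embedding has stride 4,
# and each of the 3 stage transitions downsamples by 2x.
def tokens_per_side(image_size):
    strides = (4, 8, 16, 32)            # effective stride at stages 1-4
    return [image_size // s for s in strides]

print(tokens_per_side(1024))  # [256, 128, 64, 32] -> 14x14 / 7x7 windows are still local
print(tokens_per_side(224))   # [56, 28, 14, 7]    -> 14x14 / 7x7 windows cover the whole map
```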

hugoWR commented 2 months ago

Thank you for the clarification, I have a further question if you don't mind.

Thank you for your help!

MonolithFoundation commented 2 months ago

I also want to ask a question: HieraDet outputs 4 hierarchical feature maps; which one should be used if we want to use it for understanding?

heyoeyo commented 2 months ago

There is another paper by the authors, called "Window Attention is Bugged", that I think explains some of your questions.

Why is global attention used on those 3 blocks by default?

The short answer seems to be that windowing is faster and having 3 global layers was found to be 'good enough'. Speed-wise, on my machine, the base model takes 21ms when using just the three global blocks vs 27ms with all global blocks on stage 3 (the large model is 47ms vs. 73ms). The paper mentioned above talks about this in more detail and I assume they used the same approach for SAMv2. It's in the appendix under the section: HieraDet Ablations (page 16).
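
To put a rough number on the per-block cost difference, here is a ballpark count of query-key pairs at stage 3 of the 1024px models (ignoring window padding and everything else):

```python
# Stage 3 at a 1024px input: 64x64 = 4096 tokens, 14x14 windows.
n_tokens = 64 * 64
window_tokens = 14 * 14

global_pairs = n_tokens ** 2             # every token attends to every token
window_pairs = n_tokens * window_tokens  # each token attends within its window

print(global_pairs / window_pairs)       # ~20.9x more pairs for a global block
```

Only a handful of blocks are global, which is why the end-to-end difference (21ms vs. 27ms) is much smaller than that per-block ratio.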

What is the window_pos_embed_bkg_spatial_size parameter?

That paper describes a special position encoding that consists of a 'global' learned position encoding combined with a tiled 'window' encoding. The window_pos_embed_bkg_spatial_size is (confusingly, maybe a typo?) the size of the global part of that encoding. The windowed part gets its size from the first-stage windowing, which is 8x8 for all models. The 14x14 size is only used for the base model for some reason; all others (including the large model) use 7x7, no idea why it's done that way...
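
A rough standalone sketch of that combination (shapes and channel count are illustrative and the exact repo code may differ in details): the coarse 'global' embedding of size window_pos_embed_bkg_spatial_size is interpolated up to the full feature map, and the window-sized embedding is tiled across it.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 14x14 "background"/global embedding, 8x8 window
# embedding, 256x256 stage-1 feature map for a 1024px input.
embed_dim = 112                                       # channel count is illustrative
pos_embed_bkg = torch.randn(1, embed_dim, 14, 14)     # learned global part
pos_embed_window = torch.randn(1, embed_dim, 8, 8)    # learned window part

h, w = 256, 256
# Stretch the coarse global embedding over the whole feature map...
pos = F.interpolate(pos_embed_bkg, size=(h, w), mode="bicubic")
# ...then tile the 8x8 window embedding so it repeats inside every window.
pos = pos + pos_embed_window.tile(1, 1, h // 8, w // 8)

print(pos.shape)  # torch.Size([1, 112, 256, 256])
```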

Hiera uses slightly fewer parameters than the one in the paper

Not sure about this one... The SAMv2 model includes an FpnNeck, maybe it replaces part of the original Hiera output layers?

I also want to ask a question: HieraDet outputs 4 hierarchical feature maps; which one should be used if we want to use it for understanding?

The SAMv2 model makes use of all 4 outputs, but the stage 3 one is the most important (the others can be ignored without breaking the model). So if you had to use only one, then the stage 3 (i.e. second-last) output is probably best.
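
As a purely illustrative snippet (placeholder tensors, channel counts roughly matching the large model at a 1024px input), "stage 3" here just means the second-to-last entry of the backbone's output list:

```python
import torch

# Hiera returns one feature map per stage, highest resolution first.
hiera_outputs = [
    torch.zeros(1, 144, 256, 256),   # stage 1 (1/4 resolution)
    torch.zeros(1, 288, 128, 128),   # stage 2 (1/8)
    torch.zeros(1, 576, 64, 64),     # stage 3 (1/16)
    torch.zeros(1, 1152, 32, 32),    # stage 4 (1/32)
]
single_best = hiera_outputs[-2]      # stage 3, i.e. the second-to-last output
```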

MonolithFoundation commented 2 months ago

@heyoeyo Oh, I noticed that outputs 0, 1, 2 are actually the intermediate outputs (each stage feeding the next), so 3 is the last one.

Why is the second-to-last one the most important one? Is it by design or by the training?

torch.Size([1, 144, 224, 224])
torch.Size([1, 288, 112, 112])
torch.Size([1, 576, 56, 56]) <----- You mean this one should be used for understanding?
torch.Size([1, 1152, 28, 28])
heyoeyo commented 2 months ago

... so 3 is the last one

Yes sorry, I forgot about that detail. The Hiera model has 4 outputs, but the SAM image encoder only has 3 outputs because it has extra steps to merge the last 2 outputs. If using just Hiera, the output you highlighted is probably the single best one, otherwise if using the SAMv2 image encoder, then the 'last' (low-res) output is probably best.
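
As a hedged sketch of what such a merge step can look like (an FPN-style upsample-and-add, not necessarily the exact operation in the repo), assuming both maps have already been projected to a common channel count:

```python
import torch
import torch.nn.functional as F

# Hypothetical top-down merge of the two lowest-resolution maps.
stage3 = torch.zeros(1, 256, 64, 64)    # 1/16 resolution
stage4 = torch.zeros(1, 256, 32, 32)    # 1/32 resolution

# Upsample the coarsest map and add it into the stage-3 map,
# leaving effectively 3 output scales.
merged = stage3 + F.interpolate(stage4, scale_factor=2.0, mode="nearest")
print(merged.shape)  # torch.Size([1, 256, 64, 64])
```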

Why is the second-to-last one the most important one? Is it by design or by the training?

I'm not sure why they prioritize the 3rd stage specifically; it might just be to match SAMv1 (the 3rd stage of SAMv2 ends up with a size of 64x64, which is what the v1 model used). Though even in the original Hiera model (separate from SAMv2), the 4th stage always has very few blocks compared to the 3rd stage, which seems to prioritize the 3rd stage output (for example, the large model uses 36 blocks on stage 3, but only 4 on stage 4).

That being said, if you were going to make a classifier model, maybe the 4th stage output is best (it has a few more processing steps and the low resolution doesn't matter so much), or if you needed the higher resolution of the earlier stages, they could be better choices. The 3rd stage just seems like the best option 'generally' because stage 4 has so few blocks and a reduced resolution.

MonolithFoundation commented 2 months ago

@heyoeyo thanks for your advice.

I used Hiera directly in my experiments, and the 4th layer was used there, since only the last layer outputs a reasonable number of tokens for an LLM to understand.

I still want to ask one more question: does SAM2 plan to release the huge model? The large model is not that big for more advanced usage, such as transferring SAM2's image encoder to OCR or image understanding.

MonolithFoundation commented 2 months ago

@heyoeyo Hi, I just noticed that SAM2 has an FPN neck in its ImageEncoder, and it also outputs position embeddings.

Do you think the FPN neck is also necessary for image understanding? If so, does the position embedding also need to be sent to the LLM?

heyoeyo commented 2 months ago

I still want to ask one more question: does SAM2 plan to release the huge model?

I don't know for sure, but I wouldn't expect it, since they don't reference the huge model in the paper.

Do you think the FPN neck is also necessary for image understanding?

It might be helpful if you plan to swap out different SAM models (i.e. large vs. tiny), since the main job of it seems to be to project the image features to have a consistent number of channels (256) regardless of the model size. However, that could be harmful if you're planning to use only one model size and have some follow-up processing, since reducing the channel count could be throwing away a lot of the information.
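
The projection itself is basically a 1x1 convolution per scale to a shared channel count; a minimal sketch (channel counts follow the large model, and the module structure is illustrative, not the repo's exact FpnNeck):

```python
import torch
import torch.nn as nn

# Hypothetical projection: map each backbone stage output to 256 channels so
# the rest of the model sees the same feature width regardless of backbone size.
backbone_channels = [144, 288, 576, 1152]   # large model; tiny/base+ differ
projections = nn.ModuleList(
    [nn.Conv2d(c, 256, kernel_size=1) for c in backbone_channels]
)

feats = [torch.zeros(1, c, 1024 // s, 1024 // s)
         for c, s in zip(backbone_channels, (4, 8, 16, 32))]
projected = [proj(f) for proj, f in zip(projections, feats)]
print([p.shape[1] for p in projected])      # [256, 256, 256, 256]
```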

If so, does the position embedding also need to be sent to the LLM?

That position embedding can be computed independently of the FpnNeck and isn't learned (it comes from the PositionEmbeddingSine module); however, the only use of it seems to be adding a small fraction back into the image features for the memory attention step. I have little experience with LLMs, so I'm not sure how helpful the extra encodings would be, though they don't do much for SAMv2 and only get used when working with video data (so maybe not important if you're working with single images).
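
For reference, a generic (non-learned) 2D sine/cosine embedding looks roughly like this; the exact temperature and normalization in PositionEmbeddingSine may differ:

```python
import torch

# Generic 2D sine/cosine position embedding: half the channels encode the
# y-coordinate, half encode the x-coordinate, each at several frequencies.
def sine_pos_embed(h, w, dim=256, temperature=10000.0):
    half = dim // 2
    freqs = temperature ** (torch.arange(half // 2) / (half // 2))
    y = torch.arange(h).float()[:, None] / freqs       # (h, half//2)
    x = torch.arange(w).float()[:, None] / freqs       # (w, half//2)
    y_embed = torch.cat([y.sin(), y.cos()], dim=-1)    # (h, half)
    x_embed = torch.cat([x.sin(), x.cos()], dim=-1)    # (w, half)
    pos = torch.cat([
        y_embed[:, None, :].expand(h, w, half),
        x_embed[None, :, :].expand(h, w, half),
    ], dim=-1)                                          # (h, w, dim)
    return pos.permute(2, 0, 1).unsqueeze(0)            # (1, dim, h, w)

print(sine_pos_embed(64, 64).shape)  # torch.Size([1, 256, 64, 64])
```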

MonolithFoundation commented 2 months ago

@heyoeyo thank you so much for the insight!