hustvl / EVF-SAM

Official code of "EVF-SAM: Early Vision-Language Fusion for Text-Prompted Segment Anything Model"
Apache License 2.0

About finetune SAM backbone #2

Closed wysnzzzz closed 2 weeks ago

wysnzzzz commented 3 weeks ago

Good job! Have you tried dropping the SAM backbone?

CoderZhangYx commented 3 weeks ago

Thanks for your attention. We keep the SAM architecture for the following reasons:

  1. Since Referring Expression Segmentation data is limited, we want to make the most of the SAM pretraining, in which the SAM image encoder holds most of the knowledge.
  2. The multi-modal feature extractor processes images resized to 224×224. We believe a separate image encoder that processes 1024×1024 images (e.g. the SAM image encoder) is better suited to segmentation tasks, which need a fine-grained feature map.
  3. We hope our design stays compatible with any SAM-like model (e.g. EfficientSAM, SAM-HQ, or future works).

As a result, we did not try to drop the SAM backbone. You may try to do so, but several issues need attention:

  1. The downsampled image feature from the multi-modal feature extractor has a sequence length of 196, which differs from that of the SAM image encoder (4096). The SAM mask decoder may need some modification, and the pretrained weights may no longer fit (see the sketch after this list).
  2. The multi-modal feature extractor may need further training to extract fine-grained image features.
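
If you do decide to drop the SAM image encoder, one way to bridge the length mismatch in point 1 is to reshape and upsample the extractor's patch tokens to SAM's 64×64 embedding grid. The sketch below is only an illustration under assumed dimensions (a hypothetical PatchFeatureAdapter with in_dim=1024 patch tokens and a 256-channel output matching the SAM mask decoder); even with this, the pretrained mask-decoder weights may still degrade.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchFeatureAdapter(nn.Module):
    """Hypothetical adapter: maps the (B, 196, in_dim) patch sequence from the
    multi-modal extractor to a (B, 256, 64, 64) map, the shape the SAM mask
    decoder expects for image embeddings."""

    def __init__(self, in_dim: int = 1024, out_dim: int = 256):
        super().__init__()
        self.proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)  # channel projection

    def forward(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        b, n, c = patch_tokens.shape            # n = 196 -> 14 x 14 grid
        h = w = int(n ** 0.5)
        x = patch_tokens.transpose(1, 2).reshape(b, c, h, w)
        # upsample 14x14 -> 64x64 so the token count matches SAM's 4096
        x = F.interpolate(x, size=(64, 64), mode="bilinear", align_corners=False)
        return self.proj(x)
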
qiulesun commented 2 weeks ago

@CoderZhangYx

Nice work! Since Referring Expression Segmentation data is limited, have you considered freezing both the SAM image encoder and the Multimodal Encoder for faster training? If so, can you give some comparison results?

CoderZhangYx commented 2 weeks ago

Yes, we tried freezing both the SAM image encoder and the multi-modal feature extractor and found that the model fails to converge. You can see a detailed comparison in the second ablation of our paper. BEiT-3 is pretrained with masked language modeling on huge amounts of image-text pairs, which makes it focus on macro-level interpretation. Segmentation, however, requires the model to focus on the specific region of interest in the provided image. Fine-tuning the multi-modal feature extractor is necessary to reallocate its feature preference.
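
For reference, a minimal sketch of the kind of freezing switch involved (not the exact training script; the attribute names visual_model and mm_extractor are assumptions based on this repo's naming and may differ in your copy of the code):

def configure_trainable(model, train_mm_extractor: bool = True):
    # keep the SAM image encoder frozen to preserve its pretraining
    for p in model.visual_model.image_encoder.parameters():
        p.requires_grad = False
    # fine-tune (or freeze) the multi-modal feature extractor; freezing it as
    # well is the setting reported above not to converge
    for p in model.mm_extractor.parameters():
        p.requires_grad = train_mm_extractor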

qiulesun commented 6 days ago

@CoderZhangYx

Late fusion offers significant improvements in Table 3. Could you release a code snippet that achieves late fusion when CLIP is used as the Multimodal Encoder?

CoderZhangYx commented 6 days ago

You can modify evf_sam.py as follows:

from transformers import CLIPModel

...

self.mm_extractor = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")

...

# pad input_ids to CLIP's maximum text length (77) with the EOS/pad token id 49407,
# and pad the attention mask with zeros so the padded positions are ignored
input_ids = torch.cat([input_ids,
                       torch.full((input_ids.shape[0], 77 - input_ids.shape[1]), 49407,
                                  dtype=input_ids.dtype, device=input_ids.device)], dim=1)
attention_masks = torch.cat([attention_masks,
                             torch.full((attention_masks.shape[0], 77 - attention_masks.shape[1]), 0,
                                        dtype=attention_masks.dtype, device=attention_masks.device)], dim=1)

# pooled CLIP text and image features, each of shape (B, 1, D)
text_feat = self.mm_extractor.get_text_features(input_ids, attention_masks).unsqueeze(1)
img_feat = self.mm_extractor.get_image_features(images_clip).unsqueeze(1)

feat = late_fuse(text_feat, img_feat)

By the way, please open another issue for any other questions.

qiulesun commented 6 days ago

@CoderZhangYx

Thank you for your prompt reply.

For the line feat = late_fuse(text_feat, img_feat), could you give a code snippet for implementing late_fuse()?

CoderZhangYx commented 6 days ago

You can simply try

feat = torch.cat([text_feat, img_feat], dim=1)

or

feat = text_feat + img_feat

or any other self-designed fusion structure.
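
If you want something learnable rather than plain concatenation or addition, a small fusion head is one option. This is only a sketch, not the implementation used in the paper; the default dim of 768 assumes the projection size of openai/clip-vit-large-patch14.

import torch
import torch.nn as nn

class LateFuse(nn.Module):
    # sketch of a learned late-fusion head (hypothetical, not from the repo)
    def __init__(self, dim: int = 768):  # 768 = projection dim of CLIP ViT-L/14
        super().__init__()
        self.proj = nn.Linear(2 * dim, dim)

    def forward(self, text_feat: torch.Tensor, img_feat: torch.Tensor) -> torch.Tensor:
        # text_feat, img_feat: (B, 1, dim) pooled CLIP features
        fused = torch.cat([text_feat, img_feat], dim=-1)  # (B, 1, 2 * dim)
        return self.proj(fused)                           # (B, 1, dim)

Usage would then be late_fuse = LateFuse() at init time and feat = late_fuse(text_feat, img_feat) in the forward pass.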