Thanks for your attention. We keep the architecture of SAM for the following reasons:
As a result, we didn't try to let go of SAM's backbone. You may try to do so, but several issues would need to be taken care of:
@CoderZhangYx
Nice work! Since Referring Expression Segmentation data is limited, have you considered freezing both the SAM image encoder and the Multimodal Encoder for faster training? If yes, could you give some comparison results?
Yes, we have tried freezing both the SAM image encoder and the multi-modal feature extractor, and found that the model fails to converge. You can see a detailed comparison in the second ablation of our paper. BEiT-3 is pretrained through MLM on huge amounts of image-text pairs, which emphasizes a global, macro-level interpretation of the image. Segmentation, however, requires the model to focus on the specific region of interest in the provided image. Fine-tuning the multi-modal feature extractor is therefore necessary to reorient its feature preferences.
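If your main goal is to cut training cost, one option is to freeze only the SAM image encoder while still fine-tuning the multi-modal extractor. Below is a minimal sketch of that setup; the attribute name visual_model.image_encoder is an assumption for illustration (only mm_extractor appears in evf_sam.py), so adapt it to the actual module names.
import torch

def configure_trainable_params(model):
    # Sketch only: attribute names are assumptions; adapt to the real code.
    # Freeze the SAM image encoder to save memory and compute ...
    for p in model.visual_model.image_encoder.parameters():
        p.requires_grad = False
    # ... but keep the multi-modal feature extractor trainable, since freezing
    # it as well made the model fail to converge in our ablation.
    for p in model.mm_extractor.parameters():
        p.requires_grad = True
    # Optimize only the parameters that still require gradients.
    return [p for p in model.parameters() if p.requires_grad]

# optimizer = torch.optim.AdamW(configure_trainable_params(model), lr=1e-5)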
@CoderZhangYx
The late fusion offers significant improvements in Table 3. So, could you release a code snippet showing how to achieve late fusion when CLIP is used as the Multimodal Encoder?
You can modify evf_sam.py
from transformers import CLIPModel
...
self.mm_extractor = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
...
# pad input_ids / attention_masks to CLIP's fixed text length of 77 tokens
# (49407 is CLIP's end-of-text / pad token id; padded positions get mask 0)
input_ids = torch.cat([input_ids, torch.full((input_ids.shape[0], 77 - input_ids.shape[1]), 49407, dtype=input_ids.dtype, device=input_ids.device)], dim=1)
attention_masks = torch.cat([attention_masks, torch.full((attention_masks.shape[0], 77 - attention_masks.shape[1]), 0, dtype=attention_masks.dtype, device=attention_masks.device)], dim=1)
# pooled, projected embeddings; unsqueeze(1) adds a token dimension
text_feat = self.mm_extractor.get_text_features(input_ids, attention_masks).unsqueeze(1)
img_feat = self.mm_extractor.get_image_features(images_clip).unsqueeze(1)
feat = late_fuse(text_feat, img_feat)
By the way, please open another issue for any other questions.
@CoderZhangYx
Thank you for your prompt reply.
For the line feat = late_fuse(text_feat, img_feat), could you give a code snippet for implementing late_fuse()?
You can simply try
feat = torch.cat([text_feat, img_feat], dim=1)
or
feat = text_feat + img_feat
or any other fusion structure of your own design, for instance a small learnable module along the lines sketched below.
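This is only an illustration (the class name LateFuse and the layer sizes are my own example, not the released code; 768 assumes the projection dim of clip-vit-large-patch14):
import torch
import torch.nn as nn

class LateFuse(nn.Module):
    """Illustrative learnable late fusion: concatenate the pooled text and
    image features along the channel dim and project back to `dim`."""
    def __init__(self, dim=768):  # 768 assumed for clip-vit-large-patch14
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, text_feat, img_feat):
        # text_feat, img_feat: (B, 1, dim), as returned by get_text_features /
        # get_image_features followed by unsqueeze(1)
        fused = torch.cat([text_feat, img_feat], dim=-1)  # (B, 1, 2*dim)
        return self.proj(fused)                           # (B, 1, dim)

# In __init__:  self.late_fuse = LateFuse(dim=768)
# In forward:   feat = self.late_fuse(text_feat, img_feat)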
Good job! Have you tried to let go of the backbone of SAM?