hustvl / EVF-SAM

Official code of "EVF-SAM: Early Vision-Language Fusion for Text-Prompted Segment Anything Model"
Apache License 2.0
320 stars 14 forks

variable clarification in forward() function #31

Open GinnyXiao opened 1 month ago

GinnyXiao commented 1 month ago

Dear authors,

I wanted to express my gratitude to you again! Your work immensely inspired me.

Could you kindly explain the relationship between the variables offset, batch_size, and len(feat)? What does offset do, and why does batch_size == len(offset) - 1 hold? Is len(feat) equal to batch_size?

Also, from your code I understand that BEiT3 can process a batch of image-text inputs. Does SAM 2 not support batch processing? (You used a for-loop.) For example, can SAM support parallel processing of:

  1. one image input and a batch of N prompts that correspond to N different objects, or
  2. N image-prompt pairs?

    def forward(
        self,
        images: torch.FloatTensor,
        images_evf: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_masks: torch.LongTensor,
        offset: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        label_list: List[torch.Tensor],
        resize_list: List[tuple],
        inference: bool = False,
        **kwargs,
    ):
        # image_embeddings = self.get_visual_embs(images)     
        backbone_out = self.visual_model.forward_image(images)
        # dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
        _, image_embeddings, _, _ = self.visual_model._prepare_backbone_features(backbone_out)
        image_embeddings = [_.to(images.dtype) for _ in image_embeddings]
        batch_size = images.shape[0]
        if self.visual_model.directly_add_no_mem_embed:
            image_embeddings[-1] = image_embeddings[-1] + self.visual_model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(image_embeddings[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        _features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}

        assert batch_size == len(offset) - 1

        images_evf_list = []
        for i in range(len(offset) - 1):
            start_i, end_i = offset[i], offset[i + 1]
            images_evf_i = (
                images_evf[i]
                .unsqueeze(0)
                .expand(end_i - start_i, -1, -1, -1)
                .contiguous()
            )
            images_evf_list.append(images_evf_i)
        images_evf = torch.cat(images_evf_list, dim=0)

        output = self.mm_extractor.beit3(
            visual_tokens=images_evf, 
            textual_tokens=input_ids, 
            text_padding_position=~attention_masks
            )

        # retrieve the [CLS] token as the output multimodal embeddings.
        feat = output["encoder_out"][:, :1, ...]

        pred_masks = []

        for i in range(len(feat)):
            ...
CoderZhangYx commented 1 month ago

For an explanation of offset, take a look at #23. Given that explanation, you can see why we use a for-loop for SAM inference: the items of feat have different shapes and cannot be concatenated into a single tensor for batch inference.
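
To make the role of offset concrete, here is a minimal sketch based only on the snippet quoted in this issue. It is an interpretation, not a statement from the authors. It assumes offset holds cumulative prompt boundaries, so image i owns the prompts in the range offset[i] to offset[i + 1]. The helper names (build_offset, image_index_per_prompt, split_by_offset) and the example prompt counts are hypothetical, and plain Python lists stand in for tensors.

```python
from itertools import accumulate


def build_offset(prompts_per_image):
    """Hypothetical helper: cumulative prompt boundaries [0, n0, n0 + n1, ...]."""
    return [0] + list(accumulate(prompts_per_image))


def image_index_per_prompt(offset):
    """Mirror the expand-and-concatenate loop over images_evf:
    image i is repeated once for each of its prompts."""
    indices = []
    for i in range(len(offset) - 1):
        start_i, end_i = offset[i], offset[i + 1]
        indices.extend([i] * (end_i - start_i))
    return indices


def split_by_offset(items, offset):
    """Hypothetical helper: regroup a flat per-prompt list into one group per image."""
    return [items[offset[i]:offset[i + 1]] for i in range(len(offset) - 1)]


# Assumed example: 3 images with 2, 1 and 3 text prompts respectively.
prompts_per_image = [2, 1, 3]
offset = build_offset(prompts_per_image)
batch_size = len(prompts_per_image)

# One boundary per image plus the leading 0, hence the assert in forward().
assert batch_size == len(offset) - 1

# Stand-in for the per-prompt [CLS] embeddings: one entry per text prompt.
flat_feat = [f"cls_{k}" for k in range(offset[-1])]

# Regrouped per image, the groups have different sizes (2, 1, 3),
# so they cannot be stacked into one rectangular batch.
groups = split_by_offset(flat_feat, offset)
group_sizes = [len(g) for g in groups]
```

Under these assumptions, the flat per-prompt list has offset[-1] entries (one per text prompt), while the regrouped list has batch_size entries of unequal size. That is consistent with the reply above: a per-image loop handles each group separately because the groups cannot be concatenated into a single tensor.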