HuiqianLi / ASPS

[MICCAI 2024] Repository for "ASPS: Augmented Segment Anything Model for Polyp Segmentation"

Comparison experiments #3

9Cirno9 closed this issue 2 months ago

9Cirno9 commented 2 months ago

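Hello! The paper mentions that SAMUS was evaluated on the polyp datasets. When I tried to run this paper's datasets on SAMUS, I ran into many errors, partly because SAMUS relies on prompts. Could you tell me how you modified the SAMUS training code at the time? If it is convenient, could you send it to 2870349402@qq.com? I would be very grateful!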

HuiqianLi commented 2 months ago

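Hi, I modified roughly this part of the code: SAMUS/segment_anything_samus/modeling/samus.py. I don't remember the details exactly, but as I recall, I set pti to None and then returned sparse_prompt_embeddings and dense_prompt_embeddings the same way SAM does when no prompt is given. For reference only.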
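To make "the same way SAM does when no prompt is given" concrete: in the original SAM `PromptEncoder`, when points, boxes, and masks are all `None`, the batch size falls back to 1, the sparse embeddings are an empty tensor with zero prompt tokens, and the dense embeddings are the learned `no_mask_embed` vector broadcast over the image-embedding grid. The sketch below reproduces only that fallback. The class name `NoPromptEncoderSketch` is hypothetical, and the sizes (256 channels, a 32×32 grid) are assumptions taken from the shape comments in the snippet; SAMUS's own prompt encoder may differ in detail.

```python
import torch
import torch.nn as nn


class NoPromptEncoderSketch(nn.Module):
    """Hypothetical stand-in mirroring SAM's prompt-free fallback (not SAMUS code)."""

    def __init__(self, embed_dim: int = 256, grid_hw: tuple = (32, 32)):
        super().__init__()
        self.embed_dim = embed_dim
        self.grid_hw = grid_hw
        # Learned embedding used when no mask prompt is supplied, as in SAM.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def forward(self, points=None, boxes=None, masks=None):
        if points is not None or boxes is not None or masks is not None:
            raise NotImplementedError("this sketch only covers the no-prompt case")
        bs = 1  # SAM falls back to batch size 1 when every prompt is None
        # Sparse embeddings: zero prompt tokens.
        sparse = torch.empty((bs, 0, self.embed_dim))
        # Dense embeddings: no_mask_embed broadcast over the whole grid.
        dense = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.grid_hw[0], self.grid_hw[1]
        )
        return sparse, dense


encoder = NoPromptEncoderSketch()
sparse, dense = encoder()
print(tuple(sparse.shape), tuple(dense.shape))  # (1, 0, 256) (1, 256, 32, 32)
```

Because these embeddings have batch dimension 1, the mask decoder must broadcast them against a batch of `b` image embeddings; it is worth checking how SAMUS's decoder handles that before relying on it.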

    # Module-level imports this method relies on (samus.py should already have them):
    # import torch
    # import torch.nn.functional as F
    # from typing import Dict, Optional, Tuple

    def forward(
        self,
        imgs: torch.Tensor,
        pt: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # [b n 2, b n]
        bbox: Optional[torch.Tensor] = None,  # b 4
    ) -> Dict[str, torch.Tensor]:
        imge = self.image_encoder(imgs)
        # Original SAMUS prompted branch, commented out for the prompt-free experiments:
        # if len(pt[0].shape) == 3:
        #     se, de = self.prompt_encoder(  # se b 2 256, de b 256 32 32
        #         points=pt,
        #         boxes=None,
        #         masks=None,
        #     )
        #     low_res_masks, _ = self.mask_decoder(  # low_res_mask b 1 128 128
        #         image_embeddings=imge,
        #         image_pe=self.prompt_encoder.get_dense_pe(),
        #         sparse_prompt_embeddings=se,
        #         dense_prompt_embeddings=de,
        #         multimask_output=False,
        #     )
        #     masks = F.interpolate(low_res_masks, (256, 256), mode="bilinear", align_corners=False)
        #     outputs = {"low_res_logits": low_res_masks, "masks": masks}
        #     return outputs
        # else:
        if pt is None:
            low_res_masks, masks = [], []
            # for i in range(pt[0].shape[1]):  (original: one pass per prompt)
            for i in range(1):  # no prompts, so a single pass
                # pti = (pt[0][:, i, :, :], pt[1][:, i, :])
                # With no points, boxes, or masks, the prompt encoder returns its
                # default "no prompt" sparse/dense embeddings, as in SAM.
                sei, dei = self.prompt_encoder(
                    # points=pti,
                    points=None,
                    boxes=None,
                    masks=None,
                )
                low_res_masksi, _ = self.mask_decoder(  # low_res_mask b 1 128 128
                    image_embeddings=imge,
                    image_pe=self.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sei,
                    dense_prompt_embeddings=dei,
                    multimask_output=False,
                )
                masksi = F.interpolate(low_res_masksi, (256, 256), mode="bilinear", align_corners=False)
                low_res_masks.append(low_res_masksi)
                masks.append(masksi)
            low_res_masks = torch.stack(low_res_masks, dim=1)  # b n 1 128 128
            masks = torch.stack(masks, dim=1)  # b n 1 256 256
            masks = masks.reshape(masks.shape[0], -1, masks.shape[3], masks.shape[4])
            low_res_masks = low_res_masks.reshape(low_res_masks.shape[0], -1, low_res_masks.shape[3], low_res_masks.shape[4])
            outputs = {"low_res_logits": low_res_masks, "masks": masks}
            return outputs
        # The prompted branch above is disabled, so fail loudly instead of returning None.
        raise ValueError("this modified forward() only supports pt=None")
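The stack-then-reshape step at the end of the loop concatenates the per-pass masks along the channel axis. With the single pass used in the prompt-free setup, it leaves the shape at `(b, 1, H, W)`. The sketch below checks this with random tensors standing in for the decoder output; the batch size of 4 is an arbitrary assumption, and the 128×128 low-resolution size and 256×256 output size come from the snippet's shape comments and `F.interpolate` call.

```python
import torch
import torch.nn.functional as F

b, n_passes = 4, 1  # batch size is an arbitrary assumption; n_passes mirrors range(1)
low_res_list, mask_list = [], []
for _ in range(n_passes):
    low = torch.randn(b, 1, 128, 128)  # stand-in for the decoder's low_res_masksi
    up = F.interpolate(low, (256, 256), mode="bilinear", align_corners=False)
    low_res_list.append(low)
    mask_list.append(up)

# Stack on dim=1 -> (b, n, 1, H, W), then merge n and the singleton channel axis.
low_res = torch.stack(low_res_list, dim=1)
masks = torch.stack(mask_list, dim=1)
masks = masks.reshape(masks.shape[0], -1, masks.shape[3], masks.shape[4])
low_res = low_res.reshape(low_res.shape[0], -1, low_res.shape[3], low_res.shape[4])
print(tuple(low_res.shape), tuple(masks.shape))  # (4, 1, 128, 128) (4, 1, 256, 256)
```

With one pass, the reshaped tensor holds exactly the same values as the single decoder output, so the loop is kept only to stay close to SAMUS's original multi-prompt structure.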