Closed 9Cirno9 closed 2 months ago
你好,我大概修改了 SAMUS/segment_anything_samus/modeling/samus.py 中 forward 的代码。具体细节记不太清了,印象中就是把 pti 设为 None(即给 prompt encoder 传入 points=None),然后按照 SAM 无提示的方式得到 sparse_prompt_embeddings 和 dense_prompt_embeddings。仅供参考。
def forward(
    self,
    imgs: torch.Tensor,
    pt: Tuple[torch.Tensor, torch.Tensor] = None,  # [b n 2, b n]
    bbox: torch.Tensor = None,  # b 4
) -> dict:
    """Run the segmentation model, with or without point prompts.

    Three cases are handled:

    * ``pt is None`` -- prompt-free mode: the prompt encoder is called with
      ``points=None, boxes=None, masks=None`` and its embeddings are fed to
      the mask decoder.
    * ``pt[0]`` is 3-D -- a single prompt set per image; ``pt`` is passed to
      the prompt encoder unchanged.
    * otherwise -- several prompt sets per image, indexed along dim 1 of
      ``pt[0]`` / ``pt[1]``; each set is decoded separately and the
      resulting masks are concatenated along the channel dimension.

    Args:
        imgs: Input batch handed to ``self.image_encoder``.
        pt: Optional ``(coords, labels)`` point prompt, or ``None`` for
            prompt-free segmentation.
        bbox: Accepted for interface compatibility; not used here (the
            prompt encoder is always called with ``boxes=None``).

    Returns:
        A dict with ``"low_res_logits"`` (the mask decoder output) and
        ``"masks"`` (the same logits bilinearly resized to 256 x 256).
    """
    imge = self.image_encoder(imgs)

    def decode(points):
        # Shape notes carried over from the original annotations:
        # se b 2 256, de b 256 32 32, low_res b 1 128 128.
        se, de = self.prompt_encoder(points=points, boxes=None, masks=None)
        low_res, _ = self.mask_decoder(
            image_embeddings=imge,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=se,
            dense_prompt_embeddings=de,
            multimask_output=False,
        )
        resized = F.interpolate(
            low_res, (256, 256), mode="bilinear", align_corners=False
        )
        return low_res, resized

    if pt is None or pt[0].dim() == 3:
        # Prompt-free, or one prompt set per image: a single decode pass.
        low_res_masks, masks = decode(pt)
    else:
        # One decode pass per prompt set. Concatenating on dim 1 gives the
        # same layout as stacking on dim 1 and then folding that axis into
        # the channel axis.
        decoded = [
            decode((pt[0][:, i, :, :], pt[1][:, i, :]))
            for i in range(pt[0].shape[1])
        ]
        low_res_masks = torch.cat([lo for lo, _ in decoded], dim=1)
        masks = torch.cat([hi for _, hi in decoded], dim=1)  # b c 256 256

    return {"low_res_logits": low_res_masks, "masks": masks}
前辈您好!论文中提到在 SAMUS 上对息肉数据集进行了测评。我在尝试把本文的数据集放到 SAMUS 上运行时,由于它采用了 prompt 等原因,出现了许多错误。请问您当时是怎么修改 SAMUS 训练代码的呢?如果方便的话,可以发送到 2870349402@qq.com,不胜感激!