aim-uofa / Matcher

[ICLR'24] Matcher: Segment Anything with One Shot Using All-Purpose Feature Matching
https://arxiv.org/abs/2305.13310
Other
448 stars 26 forks source link

Input mask format #35

Open TritiumR opened 3 hours ago

TritiumR commented 3 hours ago

Thanks for the great work.

I am trying to use custom data for testing, and I am not sure if my setup is right since the results look odd.

What's the correct format for the "support_masks"? Currently I have segmentation masks with value 0-n representing n+1 parts. How to convert them into the input "support_masks"?

Thank you for any help.

TritiumR commented 2 hours ago

part of my code

# Prepare support (reference) images and masks.
# Each support image is resized to 518x518 and converted to a CHW float tensor;
# the grayscale annotation map is split into one binary (0/1) float mask per
# part id — that binary float mask is the "support_masks" format.
for idx, image_name in enumerate(ref_images):
    # OpenCV loads BGR uint8, HWC
    support_img = cv2.imread(os.path.join(image_path, 'ref_images', image_name))
    # NOTE(review): INTER_NEAREST on the RGB image degrades it; INTER_LINEAR is
    # the usual choice for images (keep NEAREST only for the label map below).
    support_img = cv2.resize(support_img, (518, 518), interpolation=cv2.INTER_NEAREST)
    # HWC -> CHW; from_numpy + permute replaces the redundant
    # torch.tensor(np.array(...)) copy and the double transpose
    support_img = torch.from_numpy(support_img).float().permute(2, 0, 1)

    support_anno = cv2.imread(os.path.join(image_path, 'ref_annotations', image_name.replace('.JPEG', '.png')), cv2.IMREAD_GRAYSCALE)
    # NEAREST is correct here: label ids must never be interpolated
    support_anno = cv2.resize(support_anno, (518, 518), interpolation=cv2.INTER_NEAREST)
    support_anno = torch.from_numpy(support_anno)

    for part_id in range(args.part_num):
        # binary 0/1 float mask selecting this part id
        # NOTE(review): if label 0 is background, you may want to skip part_id 0
        # or offset the range — confirm against your annotation convention.
        support_mask = (support_anno == part_id).float()
        support_annos[part_id].append(support_mask)

    support_imgs.append(support_img)

# (num_refs, 3, 518, 518)
support_imgs = torch.stack(support_imgs, dim=0)

# Testing: for every query image, predict a mask for each part id using the
# stacked support references prepared above.
for idx, image_name in enumerate(images):
    query_img = cv2.imread(os.path.join(image_path, 'images', image_name))
    # NOTE(review): INTER_NEAREST on the RGB image degrades it; INTER_LINEAR is
    # the usual choice for images (NEAREST only for the label map below).
    query_img = cv2.resize(query_img, (518, 518), interpolation=cv2.INTER_NEAREST)
    # HWC -> (1, 3, 518, 518) float tensor; from_numpy + permute replaces the
    # redundant torch.tensor(np.array(...)) copy and the double transpose
    query_img = torch.from_numpy(query_img).float().permute(2, 0, 1).unsqueeze(0)

    query_anno = cv2.imread(os.path.join(image_path, 'annotations', image_name.replace('.JPEG', '.png')), cv2.IMREAD_GRAYSCALE)
    # NEAREST is correct here: label ids must never be interpolated
    query_anno = cv2.resize(query_anno, (518, 518), interpolation=cv2.INTER_NEAREST)
    query_anno = torch.from_numpy(query_anno)

    for part_id in range(args.part_num):
        # 1. Matcher prepare references: (1, num_refs, 518, 518) binary masks
        support_masks = torch.stack(support_annos[part_id], dim=0)
        print('support_masks', support_masks.size(), support_masks.min(), support_masks.max(), support_masks.sum())
        matcher.set_reference(support_imgs[None].to(args.device), support_masks[None].to(args.device))

        # 2. Matcher prepare target
        matcher.set_target(query_img.to(args.device))

        # 3. Predict mask of target
        pred_mask = matcher.predict()
        # reset matcher state before the next part id (removed the stray
        # trailing backtick that made this line a syntax error)
        matcher.clear()