TritiumR opened this issue 3 hours ago
Here is the relevant part of my code:
```python
import os

import cv2
import numpy as np
import torch

# prepare support images and masks
support_imgs = []
support_annos = [[] for _ in range(args.part_num)]
for idx, image_name in enumerate(ref_images):
    # cv2.imread returns an HxWxC uint8 array in BGR order
    support_img = cv2.imread(os.path.join(image_path, 'ref_images', image_name))
    support_img = cv2.resize(support_img, (518, 518), interpolation=cv2.INTER_NEAREST)
    support_img = torch.tensor(np.array(support_img)).float()
    # HWC -> CHW
    support_img = support_img.transpose(0, 2).transpose(1, 2)

    support_anno = cv2.imread(os.path.join(image_path, 'ref_annotations', image_name.replace('.JPEG', '.png')), cv2.IMREAD_GRAYSCALE)
    support_anno = cv2.resize(support_anno, (518, 518), interpolation=cv2.INTER_NEAREST)
    support_anno = torch.tensor(np.array(support_anno))

    # split the 0..n label map into one binary float mask per part id
    for part_id in range(args.part_num):
        support_mask = (support_anno == part_id).float()
        # print('support_mask', support_mask.min(), support_mask.max(), support_mask.sum())
        support_annos[part_id].append(support_mask)
    support_imgs.append(support_img)
support_imgs = torch.stack(support_imgs, dim=0)

# Testing
for idx, image_name in enumerate(images):
    query_img = cv2.imread(os.path.join(image_path, 'images', image_name))
    query_img = cv2.resize(query_img, (518, 518), interpolation=cv2.INTER_NEAREST)
    query_img = torch.tensor(np.array(query_img)).float()
    # HWC -> CHW, plus a leading batch dimension
    query_img = query_img.transpose(0, 2).transpose(1, 2).unsqueeze(0)

    query_anno = cv2.imread(os.path.join(image_path, 'annotations', image_name.replace('.JPEG', '.png')), cv2.IMREAD_GRAYSCALE)
    query_anno = cv2.resize(query_anno, (518, 518), interpolation=cv2.INTER_NEAREST)
    query_anno = torch.tensor(np.array(query_anno))

    for part_id in range(args.part_num):
        # 1. Matcher prepare references
        support_masks = torch.stack(support_annos[part_id], dim=0)
        print('support_masks', support_masks.size(), support_masks.min(), support_masks.max(), support_masks.sum())
        matcher.set_reference(support_imgs[None].to(args.device), support_masks[None].to(args.device))
        # 2. Matcher prepare target
        matcher.set_target(query_img.to(args.device))
        # 3. Predict mask of target
        pred_mask = matcher.predict()
        matcher.clear()
```
Thanks for the great work!
I am trying to test on my own custom data, but I am not sure my setup is correct because the results look odd.
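For context, this is roughly how I sanity-check each predicted part mask against my ground-truth annotation (I am assuming here that `matcher.predict()` returns a single 518x518 binary mask; please correct me if the output format is different):

```python
# rough sanity check, assuming pred_mask is a (518, 518) binary mask on args.device
pred = pred_mask.detach().float().cpu()
gt = (query_anno == part_id).float()

intersection = (pred * gt).sum()
union = ((pred + gt) > 0).float().sum()
iou = (intersection / union.clamp(min=1.0)).item()
print(f'part {part_id}: IoU = {iou:.3f}')
```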
What is the correct format for `support_masks`? Currently I have segmentation masks whose pixel values run from 0 to n, representing n+1 parts. How should I convert them into the `support_masks` input?
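In other words, I split each label map into one binary float mask per part, roughly like this minimal sketch (the 518x518 size and `part_num = 4` are just placeholders mirroring my code above):

```python
import torch

# hypothetical example: an (H, W) label map with integer part ids 0..n
part_num = 4                                       # i.e. n + 1 parts
label_map = torch.randint(0, part_num, (518, 518))

# one binary float mask per part id -> shape (part_num, H, W)
part_masks = torch.stack(
    [(label_map == part_id).float() for part_id in range(part_num)],
    dim=0,
)
print(part_masks.shape)  # torch.Size([4, 518, 518])
```

Is that per-part binary layout what `set_reference` expects, or should the masks be arranged differently?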
Thank you for any help.