Hzzone / PseCo

(CVPR 2024) Point, Segment and Count: A Generalized Framework for Object Counting

Mask Decoder Output #9

Closed. jwahnn closed this issue 3 months ago.

jwahnn commented 3 months ago

Based on the diagram in the paper, it seems the mask decoder outputs mask proposals. If so, I want to save these masks, but in the demo_in_the_wild.ipynb file I can't tell where the mask outputs are. Can you point me to them?

Hzzone commented 3 months ago

    sam.forward_sam_with_embeddings(features, points=pred_points[indices])

jwahnn commented 3 months ago

I looked into the forward_sam_with_embeddings function and managed to output the masks, but it seems I am saving more masks than I need.

Where in the code are these masks classified so that only the relevant ones remain?

Hzzone commented 3 months ago

The masks are classified with cls_outs_ = cls_head(features, [pred_boxes, ], [example_features, ] * len(indices)), and we can set a threshold on the resulting scores to filter the masks.
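A minimal sketch of that thresholding step, written with NumPy. It assumes the scores have already been passed through a sigmoid and share the leading shape of the mask array (for example, scores of shape (N, K) for masks of shape (N, K, H, W)). The helper name filter_masks_by_score and the 0.5 default threshold are hypothetical choices for illustration, not part of PseCo:

```python
import numpy as np


def filter_masks_by_score(masks, scores, threshold=0.5):
    """Keep only the masks whose score exceeds `threshold`.

    Hypothetical helper (not part of PseCo). Assumes `scores` has the same
    leading shape as `masks`, e.g. scores (N, K) for masks (N, K, H, W),
    with scores already in [0, 1] (after a sigmoid).
    """
    masks = np.asarray(masks)
    scores = np.asarray(scores)
    if masks.shape[:scores.ndim] != scores.shape:
        raise ValueError(f"score shape {scores.shape} does not match masks {masks.shape}")
    keep = scores > threshold       # boolean array with the same shape as `scores`
    kept_masks = masks[keep]        # shape (num_kept, H, W)
    kept_index = np.argwhere(keep)  # one row of indices per kept mask, e.g. (i, j)
    return kept_masks, scores[keep], kept_index


# Usage with dummy data: 3 proposals x 2 candidate masks of 4x4 pixels.
dummy_masks = np.zeros((3, 2, 4, 4))
dummy_scores = np.array([[0.9, 0.1],
                         [0.3, 0.2],
                         [0.4, 0.7]])
kept, kept_scores, kept_index = filter_masks_by_score(dummy_masks, dummy_scores)
print(len(dummy_scores), len(kept))  # 3 2
print(kept_index.tolist())           # [[0, 0], [2, 1]]
```

Computing or rescaling scores (for example, multiplying two score arrays together) keeps one entry per mask; only the boolean indexing step reduces the count. The returned indices can be reused to name saved files such as mask_{i}_{j}.png.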

jwahnn commented 3 months ago

I was trying to get the masks using the following lines of code:

    import os

    import matplotlib.pyplot as plt
    import numpy as np

    # Decode masks for the current batch of predicted points
    outputs_points = sam.forward_sam_with_embeddings(features, points=pred_points[indices])

    low_res_masks = outputs_points['pred_logits']
    print("One: ", type(low_res_masks), len(low_res_masks))
    masks_np = low_res_masks.detach().cpu().numpy()
    output_dir = 'saved_masks'
    os.makedirs(output_dir, exist_ok=True)

    # Save every masks_np[i, j] slice as a binary grayscale PNG
    for i in range(masks_np.shape[0]):
        for j in range(masks_np.shape[1]):
            mask = masks_np[i, j]
            mask = (mask > 0.5).astype(np.uint8)
            plt.imsave(f"{output_dir}/mask_{i}_{j}.png", mask, cmap='gray')

First, I am not sure whether this is the intended usage, because I find it confusing how the shape of masks_np relates to the actual masks.
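One way to read a 4-D array like masks_np is sketched below, under assumptions the thread does not confirm: the array has shape (num_prompts, masks_per_prompt, H, W), and it holds raw pre-sigmoid mask logits, as SAM-style mask decoders typically return. Because sigmoid(x) > 0.5 exactly when x > 0, raw logits would be binarized at 0 rather than 0.5. The function name binarize_mask_logits is hypothetical:

```python
import numpy as np


def binarize_mask_logits(logits, is_probability=False):
    """Turn a stack of mask outputs into binary uint8 masks.

    Hypothetical helper. Assumes `logits` has shape
    (num_prompts, masks_per_prompt, H, W), so logits[i, j] is the j-th
    candidate mask predicted for the i-th point prompt.
    """
    logits = np.asarray(logits)
    if logits.ndim != 4:
        raise ValueError(f"expected a 4-D array, got shape {logits.shape}")
    # sigmoid(x) > 0.5 is equivalent to x > 0, so raw logits use threshold 0;
    # values that are already probabilities use threshold 0.5.
    threshold = 0.5 if is_probability else 0.0
    return (logits > threshold).astype(np.uint8)


# Usage with dummy logits: 2 prompts x 3 candidate masks x 4x4 pixels.
rng = np.random.default_rng(0)
dummy_logits = rng.normal(size=(2, 3, 4, 4))
binary = binarize_mask_logits(dummy_logits)
num_prompts, masks_per_prompt = binary.shape[:2]
print(num_prompts, masks_per_prompt)  # 2 3
```

Under this reading, len() of the array counts prompts, not individual masks; the total number of masks is num_prompts * masks_per_prompt.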

Second, the length of outputs_points['pred_logits'] and the length of pred_logits after the line you pointed me to are the same (46), so I am confused about what I am doing wrong.

    cls_outs_ = cls_head(features, [pred_boxes, ], [example_features, ] * len(indices))
    # Average the sigmoid scores over the exemplar dimension (dim 1)
    cls_outs_ = cls_outs_.sigmoid().view(-1, len(example_features), 5).mean(1)
    # Combine the classification scores with pred_logits elementwise
    pred_logits = cls_outs_ * pred_logits
    print("One: ", type(pred_logits), len(pred_logits))

Hzzone commented 3 months ago

Sorry, I do not understand what you are trying to achieve.