Yang-Bob / PMMs

Prototype Mixture Models for Few-shot Semantic Segmentation

How to get predicted mask #21

Open itsss opened 3 years ago

itsss commented 3 years ago

Could you please tell me how to get the predicted mask during validation (i.e., the output after it passes through the ASPP module)?

I used the following code (in test_frame.py) to get the predicted mask, but it always shows me the ground truth (GT) instead.


for data in val_dataloader:
    begin_time = time.time()
    it = it + 1
    query_img, query_mask, support_img, support_mask, idx, size = data

    query_img, query_mask, support_img, support_mask, idx \
        = query_img.cuda(), query_mask.cuda(), support_img.cuda(), support_mask.cuda(), idx.cuda()

    with torch.no_grad():
        logits = model(query_img, support_img, support_mask)

        # resize the query image and its mask back to the original size
        query_img = F.upsample(query_img, size=(size[0], size[1]), mode='bilinear')
        query_mask = F.upsample(query_mask, size=(size[0], size[1]), mode='nearest')
        print(query_mask.size())

        values, pred = model.get_pred(logits, query_img)
        evaluations.update_evl(idx, query_mask, pred, 0)

        plt.figure()
        # panel 1: query_mask (the ground-truth mask)
        plt.subplot(2, 2, 1)
        plt.imshow(np.array(query_mask.squeeze().cpu()), cmap=cm.tab10_r)
        # panel 2: the query image
        plt.subplot(2, 2, 2)
        plt.imshow(np.array(query_img.squeeze().permute(1, 2, 0).cpu()), cmap=cm.tab10_r)
        plt.axis('off')
        # plt.show()
        print(cnt)
        cnt = cnt + 1
        plt.savefig("result/" + str(cnt) + ".png")
        time.sleep(0.1)
Yang-Bob commented 3 years ago

Hi,

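The predicted mask is the variable pred, which is computed by:

    values, pred = model.get_pred(logits, query_img)

In your script, the first subplot shows query_mask, which is the ground truth, and pred is never plotted.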
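Assuming logits holds per-class scores, a get_pred-style helper turns them into a label map by taking a softmax over the class dimension and then the argmax at each pixel. The sketch below is a hypothetical NumPy stand-in, not the code in this repository. It assumes scores of shape (N, C, H, W) that are already resized to the query image size (the real helper may do that resizing itself), and the name get_pred_sketch is made up for illustration.

```python
import numpy as np

def get_pred_sketch(scores: np.ndarray):
    """Hypothetical stand-in for a get_pred-style helper.

    scores: per-class scores of shape (N, C, H, W), assumed to be
    already resized to the query image size.
    Returns (probs, pred): softmax probabilities of shape (N, C, H, W)
    and the predicted label map of shape (N, H, W).
    """
    # numerically stable softmax over the class dimension (axis 1)
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    # the predicted mask is the most likely class index at each pixel
    pred = probs.argmax(axis=1)
    return probs, pred

# example: one image, 2 classes (0 = background, 1 = foreground), 2x2 pixels
scores = np.array([[[[2.0, 0.0], [0.0, 1.0]],
                    [[0.0, 3.0], [1.0, 0.0]]]])
probs, pred = get_pred_sketch(scores)
# pred.tolist() == [[[0, 1], [1, 0]]]
```

Because pred holds class indices rather than probabilities, it can be drawn directly with imshow, just like query_mask.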

You can visualize pred and give it a try.
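Since query_mask is the ground truth, seeing the prediction requires drawing pred in its own panel. The following is a minimal NumPy sketch that assumes binary masks (0 = background, 1 = foreground) and an image already scaled to [0, 1]. The helper names mask_to_rgb, comparison_panel and foreground_iou are hypothetical. foreground_iou gives a quick numeric check of whether pred really matches the GT.

```python
import numpy as np

def mask_to_rgb(mask: np.ndarray) -> np.ndarray:
    """Map a label map (H, W) to RGB: 0 -> black, 1 -> white,
    any other value (e.g. an ignore label such as 255) -> red."""
    rgb = np.zeros(mask.shape + (3,), dtype=np.float32)
    rgb[mask == 1] = (1.0, 1.0, 1.0)
    rgb[(mask != 0) & (mask != 1)] = (1.0, 0.0, 0.0)
    return rgb

def comparison_panel(image, gt_mask, pred_mask):
    """Concatenate [query image | ground truth | prediction] left to right.
    image: (H, W, 3) floats in [0, 1]; masks: (H, W) labels of the same size."""
    return np.concatenate(
        [image.astype(np.float32), mask_to_rgb(gt_mask), mask_to_rgb(pred_mask)],
        axis=1,
    )

def foreground_iou(gt_mask, pred_mask) -> float:
    """IoU of class 1; pixels with other labels (including 255) count as
    not-foreground. Returns 1.0 when both masks have no foreground."""
    gt, pred = gt_mask == 1, pred_mask == 1
    union = np.logical_or(gt, pred).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(gt, pred).sum() / union)

# example with a blank 2x2 image and masks that differ in one pixel
image = np.zeros((2, 2, 3), dtype=np.float32)
gt = np.array([[0, 1], [1, 1]])
pred = np.array([[0, 1], [0, 1]])
panel = comparison_panel(image, gt, pred)  # shape (2, 6, 3)
iou = foreground_iou(gt, pred)             # 2 / 3
```

Inside the validation loop, you could convert the tensors with query_img.squeeze().permute(1, 2, 0).cpu().numpy(), query_mask.squeeze().cpu().numpy() and pred.squeeze().cpu().numpy(), then save the result with plt.imsave(path, np.clip(panel, 0, 1)). If the dataloader normalizes images with a mean and std, undo that first. Otherwise the clipped query image will not look like the original photo.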