itsss opened this issue 3 years ago
Could you please tell me how to get the predicted mask during validation (i.e., the output after the ASPP module)?
I used the following code (in test_frame.py) to get the predicted mask, but it always shows me the ground truth (GT):
```python
import time

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

# model, val_dataloader, evaluations, it and cnt are defined elsewhere in test_frame.py
for data in val_dataloader:
    begin_time = time.time()
    it = it + 1
    query_img, query_mask, support_img, support_mask, idx, size = data
    query_img, query_mask, support_img, support_mask, idx = \
        query_img.cuda(), query_mask.cuda(), support_img.cuda(), support_mask.cuda(), idx.cuda()

    with torch.no_grad():
        logits = model(query_img, support_img, support_mask)
        # Resize the query image and its mask back to the original image size
        query_img = F.interpolate(query_img, size=(size[0], size[1]), mode='bilinear')
        query_mask = F.interpolate(query_mask, size=(size[0], size[1]), mode='nearest')
        print(query_mask.size())
        values, pred = model.get_pred(logits, query_img)
        evaluations.update_evl(idx, query_mask, pred, 0)

        plt.figure()
        plt.subplot(2, 2, 1)
        plt.imshow(np.array(query_mask.squeeze().cpu()), cmap=cm.tab10_r)
        plt.subplot(2, 2, 2)
        plt.imshow(np.array(query_img.squeeze().permute(1, 2, 0).cpu()), cmap=cm.tab10_r)
        plt.axis('off')
        # plt.show()
        print(cnt)
        cnt = cnt + 1
        plt.savefig("result/" + str(cnt) + ".png")
        plt.close()  # release the figure so figures do not accumulate across iterations
        time.sleep(0.1)
```
Hi,
The predicted mask is the variable `pred`, which is computed by `values, pred = model.get_pred(logits, query_img)`.
You can visualize `pred` (instead of `query_mask`, which is the ground truth) and give it a try.
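To compare the prediction with the ground truth directly, a minimal sketch like the one below may help. It assumes the logits have shape (C, H, W) or (1, C, H, W) with one channel per class, and that masks are binary (0 = background, 1 = foreground). `logits_to_mask` and `save_side_by_side` are hypothetical helpers, not part of this repository. The repository's own `model.get_pred` may do more than a plain argmax (for example, resizing to the query image), so its `pred` is the safer choice when it is available.

```python
import numpy as np


def logits_to_mask(logits: np.ndarray) -> np.ndarray:
    """Turn per-class logits into a label mask by taking the argmax over classes.

    Accepts (C, H, W) or (1, C, H, W). Softmax is monotonic, so the argmax of the
    logits equals the argmax of the class probabilities.
    """
    if logits.ndim == 4 and logits.shape[0] == 1:
        logits = logits[0]  # drop the batch dimension of a single query
    if logits.ndim != 3:
        raise ValueError(f"expected (C, H, W) logits, got shape {logits.shape}")
    return logits.argmax(axis=0).astype(np.uint8)


def save_side_by_side(gt_mask: np.ndarray, pred_mask: np.ndarray, path: str) -> None:
    """Save the ground-truth and predicted masks next to each other in one image."""
    import matplotlib.pyplot as plt  # imported lazily so logits_to_mask works without matplotlib

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for ax, mask, title in zip(axes, (gt_mask, pred_mask), ("ground truth", "prediction")):
        ax.imshow(mask, cmap="gray", vmin=0, vmax=1)  # assumes binary 0/1 masks
        ax.set_title(title)
        ax.axis("off")
    fig.savefig(path)
    plt.close(fig)  # free the figure; otherwise figures pile up inside the loop
```

Inside the validation loop, one could then call `save_side_by_side(query_mask.squeeze().cpu().numpy(), pred.squeeze().cpu().numpy(), f"result/{cnt}.png")`, assuming `pred` is a tensor with the same spatial size as `query_mask`.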