Closed — sarathsrk closed this issue 4 years ago.
Here is my inference code. It would be helpful if someone could correct me.
` unet_path = os.path.join(self.model_path, 'U_Net-100-0.0003-14-0.4129.pkl')
# ======================= Load trained weights for inference =======================
# `unet_path` is defined on the preceding line.
if not os.path.isfile(unet_path):
    # torch.load would fail on a missing file anyway; fail early with a clear message.
    raise FileNotFoundError('No trained model found at %s' % unet_path)
self.build_model()
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
self.unet.load_state_dict(torch.load(unet_path, map_location=self.device))
print('%s is Successfully Loaded from %s' % (self.model_type, unet_path))

# ===================================== Test =====================================
# eval() is equivalent to train(False), so a single call is enough.
self.unet.eval()

acc = 0.  # Accuracy
SE = 0.   # Sensitivity (Recall)
SP = 0.   # Specificity
PC = 0.   # Precision
F1 = 0.   # F1 Score
JS = 0.   # Jaccard Similarity
DC = 0.   # Dice Coefficient
length = 0  # number of test images seen

print("test running")
# Inference needs no autograd graph; without no_grad every batch's
# activations are retained, which wastes (GPU) memory.
with torch.no_grad():
    for i, (images, GT) in enumerate(self.test_loader):
        images = images.to(self.device)
        GT = GT.to(self.device)
        # The network output is squashed to [0, 1] before scoring and saving.
        # save_image clamps to [0, 1], so un-squashed negative outputs were
        # written as an all-black image.
        SR = torch.sigmoid(self.unet(images))

        # NOTE(review): assumes each get_* helper returns the mean over the
        # batch. Weighting by batch size makes the final division by `length`
        # a per-image average for any batch size. TODO confirm against the
        # metric helpers.
        batch_size = images.size(0)
        acc += get_accuracy(SR, GT) * batch_size
        SE += get_sensitivity(SR, GT) * batch_size
        SP += get_specificity(SR, GT) * batch_size
        PC += get_precision(SR, GT) * batch_size
        F1 += get_F1(SR, GT) * batch_size
        JS += get_JS(SR, GT) * batch_size
        DC += get_DC(SR, GT) * batch_size
        length += batch_size

        # Give every batch its own files. The former '%s_..._%d_...' names
        # were never %-formatted, so each batch overwrote the previous one.
        prefix = os.path.join(self.result_path, '%s_test_%d' % (self.model_type, i))
        # normalize=True min-max rescales the input tensor to [0, 1] so it is
        # viewable even if the loader normalised it (e.g. to [-1, 1]).
        torchvision.utils.save_image(images.cpu(), prefix + '_image.png', normalize=True)
        torchvision.utils.save_image(SR.cpu(), prefix + '_SR.png')
        torchvision.utils.save_image(GT.cpu(), prefix + '_GT.png')

if length == 0:
    raise RuntimeError('test_loader yielded no samples; nothing to evaluate')

acc = acc / length
SE = SE / length
SP = SP / length
PC = PC / length
F1 = F1 / length
JS = JS / length
DC = DC / length
# Report the averaged scores (the old in-loop print showed running sums).
print('ACC:{}, SE:{}, SP:{},PC:{},F1:{},JS:{},DC:{}'.format(acc, SE, SP, PC, F1, JS, DC))

# No training happens in this run, so there is no best epoch. csv.writer
# writes None as an empty field, which keeps result.csv's column layout.
best_epoch = None
with open(os.path.join(self.result_path, 'result.csv'), 'a', encoding='utf-8', newline='') as f:
    wr = csv.writer(f)
    wr.writerow([self.model_type, acc, SE, SP, PC, F1, JS, DC, self.lr, best_epoch,
                 self.num_epochs, self.num_epochs_decay, self.augmentation_prob])
`
@sarath0993, did you end up fixing this problem? I am currently dealing with the same issue, but only with R2U-Net.
I have changed the lines as suggested in the solutions to previous issues, but I still can't get the resulting SR image. Instead, I get a strange-looking input image and an empty (completely black) image in the results directory, even though during training I got a series of good SR images. Could you please provide complete inference code that produces the SR image as output?