LeeJunHyun / Image_Segmentation

Pytorch implementation of U-Net, R2U-Net, Attention U-Net, and Attention R2U-Net.
2.66k stars 594 forks source link

Inference code #55

Closed sarathsrk closed 4 years ago

sarathsrk commented 4 years ago

I have changed lines as mentioned in based on the solution given to the previous issues, But still I couldn't get result SR image. I got some strange input image and empty(full black) image in the results directory. But during training, I got series of good SR image. Can you please provide complete inference code to get SR image as an output?

sarathsrk commented 4 years ago

Here is my inference code, It would be helpful if someone corrects me

` unet_path = os.path.join(self.model_path, 'U_Net-100-0.0003-14-0.4129.pkl')

    # U-Net Train
    if os.path.isfile(unet_path):
        # Load the pretrained Encoder
        self.unet.load_state_dict(torch.load(unet_path))
        print('%s is Successfully Loaded from %s'%(self.model_type,unet_path))
    #else:
        # Train for Encoder
        #lr = self.lr
        #best_unet_score = 0.

        #===================================== Test ====================================#
        #del self.unet
        #del best_unet
        print("test running")
        #self.unet.test()
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path))

        self.unet.train(False)
        self.unet.eval()

        acc = 0.    # Accuracy
        SE = 0.        # Sensitivity (Recall)
        SP = 0.        # Specificity
        PC = 0.     # Precision
        F1 = 0.        # F1 Score
        JS = 0.        # Jaccard Similarity
        DC = 0.        # Dice Coefficient
        length=0
        for i, (images, GT) in enumerate(self.test_loader):

            images = images.to(self.device)
            GT = GT.to(self.device)
            SR = self.unet(images)
            acc += get_accuracy(SR,GT)
            SE += get_sensitivity(SR,GT)
            SP += get_specificity(SR,GT)
            PC += get_precision(SR,GT)
            F1 += get_F1(SR,GT)
            JS += get_JS(SR,GT)
            DC += get_DC(SR,GT)

            length += images.size(0)
            print('ACC:{}, SE:{}, SP:{},PC:{},F1:{},JS:{},DC:{}'.format(acc,SE,SP,PC,F1,JS,DC))
            torchvision.utils.save_image(images.data.cpu(), os.path.join(self.result_path,'%s_valid_%d_image.png'))

            torchvision.utils.save_image(SR.data.cpu(),os.path.join(self.result_path,'%s_test_%d_SR.png'))

            torchvision.utils.save_image(GT.data.cpu(), os.path.join(self.result_path,'%s_test_%d_GT.png'))
            #print("result path = ", self.result_path)

        acc = acc/length
        SE = SE/length
        SP = SP/length
        PC = PC/length
        F1 = F1/length
        JS = JS/length
        DC = DC/length
        unet_score = JS + DC

        f = open(os.path.join(self.result_path,'result.csv'), 'a', encoding='utf-8', newline='')
        wr = csv.writer(f)
        wr.writerow([self.model_type,acc,SE,SP,PC,F1,JS,DC,self.lr,best_epoch,self.num_epochs,self.num_epochs_decay,self.augmentation_prob])
        f.close()

`

DeVriesMatt commented 4 years ago

@sarath0993 did you end up fixing this problem? I am currently dealing with the same issue however only for the R2U-Net