bigmb / Unet-Segmentation-Pytorch-Nest-of-Unets

Implementation of different kinds of Unet Models for Image Segmentation - Unet , RCNN-Unet, Attention Unet, RCNN-Attention Unet, Nested Unet
MIT License
1.87k stars 345 forks

about the iou of results #33

Closed elk-april closed 4 years ago

elk-april commented 4 years ago

Hello, I am training the model on my own data, but tuning the learning rate does little to improve the IoU (the highest accuracy so far is 60%, with lr = 0.1/0.05/0.01 and weight_decay = 1e-8). Is there any way to improve the performance? Looking forward to your reply!

elk-april commented 4 years ago

I modified the structure of your code (split it into train/test) and deleted some code, as shown below. train.py:

    for x, y in tqdm(train_loader):
        x, y = x.to(device), y.to(device)

        input_images(x, y, i, n_iter, k)

        opt.zero_grad()

        y_pred = model_test(x)
        lossT = calc_loss(y_pred, y)     # Dice_loss Used
        train_loss += lossT.item() * x.size(0)
        lossT.backward()
        opt.step()
        x_size = lossT.item() * x.size(0)
        k = 2

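    model_test.eval()

    # torch.no_grad() only takes effect as a context manager (or decorator);
    # called as a bare statement it does not disable gradient tracking.
    with torch.no_grad():
        for x1, y1 in tqdm(valid_loader):
            x1, y1 = x1.to(device), y1.to(device)

            y_pred1 = model_test(x1)
            lossL = calc_loss(y_pred1, y1)     # Dice_loss Used

            valid_loss += lossL.item() * x1.size(0)
            x_size1 = lossL.item() * x1.size(0)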

    # print loss
    train_loss = train_loss / len(train_idx)
    valid_loss = valid_loss / len(valid_idx)

    if (i + 1) % 1 == 0:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model '.format(valid_loss_min, valid_loss))
    torch.save(model_test.state_dict(),read_model_path+str(i+1)+".pth")
    loss_num.append(str(valid_loss)[0:5])

    print(read_model_path+str(i+1)+".pth  saved")

test.py:

for _ in tqdm(read_test_folder):

    # print(x_sort_test[i])
    im = Image.open(x_sort_test[i])
    im1 = im
    im_n = np.array(im1)
    im_n_flat = im_n.reshape(-1, 1)

    for j in range(im_n_flat.shape[0]):
        if im_n_flat[j] != 0:
            im_n_flat[j] = 255

    s = data_transform(im)
    pred = model_test(s.unsqueeze(0).cuda()).cpu()

    pred = F.sigmoid(pred)
    pred = pred.detach().numpy()
    pred = pred[0][0]

    img4 = (pred - pred.min()) * ((255 - 0) / (pred.max() - pred.min())) + 0
    img = img4.astype(np.uint8)
    threshold_predictions_v(img, args.thr)
    cv2.imwrite(res + x_sort_test[i].split('/')[-1], img)

    i = i + 1

imou = func.compute_mIoU(res, test_folderL)
print("checkpoints {} : miou {}".format(args.num, imou))

bigmb commented 4 years ago

Hello. What is the size of your dataset, and what is the segmentation for? If the result is far below the SOTA, then there must be a problem somewhere.

elk-april commented 4 years ago

Thanks for your reply! The images in my dataset have size (1, 384, 416). I convert them to (3, 384, 416) with this code:

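import os

import cv2
import numpy as np

def expand(i_dir, o_dir):
    # Copy each single-channel image into three identical channels and save it as PNG.
    for name in os.listdir(i_dir):
        print(name)
        image_path = os.path.join(i_dir, name)
        img = cv2.imread(image_path, -1)  # -1 (cv2.IMREAD_UNCHANGED) keeps the original bit depth
        image = np.expand_dims(img, axis=2)
        image = np.concatenate((image, image, image), axis=-1)
        base_name = name.split('.')[0]
        #print(base_name)
        cv2.imwrite(o_dir+base_name+'.png', image)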
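
Note that the snippet above stacks the copies along the last axis, so the saved array is channel-last, (384, 416, 3), rather than (3, 384, 416). A minimal sketch of the same replication on a synthetic array; `to_three_channels` is a hypothetical helper name, and the uint16 test image is an assumption made for illustration:

```python
import numpy as np

def to_three_channels(img):
    """Replicate a 2-D grayscale array into an (H, W, 3) array."""
    if img.ndim != 2:
        raise ValueError("expected a single-channel 2-D image")
    # Same result as expand_dims followed by concatenate on the last axis.
    return np.repeat(img[:, :, np.newaxis], 3, axis=2)

# Synthetic stand-in for one 384 x 416 single-channel image (assumed uint16).
gray = np.zeros((384, 416), dtype=np.uint16)
gray[100:120, 200:230] = 1000
rgb = to_three_channels(gray)
print(rgb.shape, rgb.dtype)  # (384, 416, 3) uint16
```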

The dataset is for brain lesion segmentation. Here are some examples. Original image: 00001

Converted image: 00001

Label: 00001

My result: 00001

bigmb commented 4 years ago

All the images are black? Screenshot from 2020-03-13 10-21-27

elk-april commented 4 years ago

Yes. Original image: dtype=uint16, as shown below: image

Label: dtype=uint8, with background=0 and lesion=1, as shown below: image
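
Masks stored this way look black in an ordinary viewer because a label of 0/1 is almost indistinguishable from 0 on a 0-255 scale, and a uint16 image may likewise use only a small part of its range. A minimal sketch for inspection only, assuming NumPy arrays such as those returned by `cv2.imread(path, -1)`; `stretch_for_display` is a hypothetical helper, not part of the repository:

```python
import numpy as np

def stretch_for_display(arr):
    """Linearly rescale an array to the full 0-255 uint8 range for viewing."""
    arr = arr.astype(np.float64)
    lo, hi = arr.min(), arr.max()
    if hi == lo:
        # Constant image: nothing to stretch, return all zeros.
        return np.zeros(arr.shape, dtype=np.uint8)
    return ((arr - lo) / (hi - lo) * 255.0).round().astype(np.uint8)

label = np.zeros((4, 4), dtype=np.uint8)
label[1:3, 1:3] = 1                      # lesion pixels stored as 1
visible = stretch_for_display(label)
print(visible.max())                     # 255
```

The stretched copy is meant only for viewing; the 0/1 labels used for training stay unchanged.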

So I decided to rewrite the data loader like this:

    def __getitem__(self, i):

        i1 = cv2.imread(self.images_dir + self.images[i],-1)
        l1 = cv2.imread(self.labels_dir + self.labels[i],-1)

        seed = np.random.randint(0, 2 ** 32)  # make a seed with numpy generator

        # apply this seed to img transforms
        random.seed(seed)
        torch.manual_seed(seed)

        # apply this seed to target/label transforms
        random.seed(seed)
        torch.manual_seed(seed)

        img = torch.Tensor(i1/1.0)
        label = torch.Tensor(l1/1.0)
        img = img.unsqueeze(0)
        label = label.unsqueeze(0)
        #print(img.max())
        #print()
        return img, label

elk-april commented 4 years ago

If you need the data, I can send you part of it by email.

bigmb commented 4 years ago

That's fine. Did you check how the images are being trained? Is the attention given to the right place, and are the activations in the correct places?

elk-april commented 4 years ago

Sorry, I don't understand. Could you explain what you mean? I have tried removing torchvision.transforms.CenterCrop(96), but it did not help.

And when I checked the code, I found this (in pytorch_run.py, line 482):

        for j in range(im_n_flat.shape[0]):
            if im_n_flat[j] != 0:
                im_n_flat[j] = 255

The labels in my dataset are between 0 and 1, so maybe I can set "im_n_flat[j] = 1"?
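
The loop quoted above maps every nonzero label pixel to 255. Whichever foreground value is chosen, it presumably has to match the value used in the thresholded predictions, otherwise the two masks would be compared as different classes. A vectorized sketch of the same mapping, assuming the labels are NumPy arrays; `binarize` and its `foreground` parameter are hypothetical names:

```python
import numpy as np

def binarize(mask, foreground=255):
    """Set every nonzero pixel to `foreground` and keep zeros as background."""
    return np.where(mask != 0, foreground, 0).astype(np.uint8)

label = np.array([[0, 1], [1, 0]], dtype=np.uint8)
print(binarize(label).tolist())                 # [[0, 255], [255, 0]]
print(binarize(label, foreground=1).tolist())   # [[0, 1], [1, 0]]
```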

elk-april commented 4 years ago

Thank you very much! I think I have found my mistake. I am using the new data loader, and the "lesion" IoU now reaches 0.70 (UNet). Now I can go on to try the other models. Thanks for your reply!
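
For reference, the IoU of one class is the number of pixels where prediction and label both contain the class, divided by the number of pixels where either does. A minimal sketch for binary masks, assuming NumPy arrays in which any nonzero value marks the lesion; `binary_iou` is a hypothetical helper and not the `compute_mIoU` function called in test.py:

```python
import numpy as np

def binary_iou(pred, target):
    """Intersection over union of the foreground in two binary masks."""
    pred_fg = np.asarray(pred) != 0
    target_fg = np.asarray(target) != 0
    union = np.logical_or(pred_fg, target_fg).sum()
    if union == 0:
        # Neither mask contains foreground; treat this as a perfect match.
        return 1.0
    intersection = np.logical_and(pred_fg, target_fg).sum()
    return float(intersection) / float(union)

pred = np.array([[0, 255, 255], [0, 255, 0]])
target = np.array([[0, 1, 1], [0, 0, 1]])
print(binary_iou(pred, target))  # 0.5
```

Because both masks are reduced to nonzero/zero before comparison, this sketch does not depend on whether the foreground is stored as 1 or 255.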