bigmb / Unet-Segmentation-Pytorch-Nest-of-Unets

Implementation of different kinds of Unet Models for Image Segmentation - Unet, RCNN-Unet, Attention Unet, RCNN-Attention Unet, Nested Unet
MIT License

How to ensure the same random transform on both images and labels #29

Closed · ppjerry closed this issue 4 years ago

ppjerry commented 4 years ago

Hi, thanks for this great project. I have a question about the data augmentation. In the Image_Dataset_folder class, the transforms for the image and the label are separate. How can I make sure the same random transform applied to an image is also applied to its label?

bigmb commented 4 years ago

Hello,

The way they are saved and loaded in the project is how I make sure that matching labels and images are called together. I used a '0000' naming format to make sure that still holds after transformation.

ppjerry commented 4 years ago

'0000', what is that?

bigmb commented 4 years ago

Sorry, I forgot to include the code showing how I save the files. Instead of directly increasing the value of 'i', I saved it in a zero-padded format of '0001'-'0002', so the order of the images didn't change. I made that mistake initially and had to switch to this format to avoid all this mixing of labels and images. If you want that code, let me know; I will have to dig up some old files for it.
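
Roughly, the idea is this (an illustrative sketch, not the exact saving code I used):

    # Illustrative sketch: plain integer names sort lexicographically,
    # which scrambles the image/label pairing; zero-padded names don't.
    plain  = sorted(str(i) for i in range(12))
    padded = sorted('%04d' % i for i in range(12))

    print(plain)   # ['0', '1', '10', '11', '2', ...]  -> pairing breaks
    print(padded)  # ['0000', '0001', ..., '0011']     -> order preserved

    # Saving an image/mask pair with matching zero-padded names:
    # im.save('images/%04d.png' % i); mask.save('labels/%04d.png' % i)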

ppjerry commented 4 years ago

Hi, in the following code, self.tx is applied to the image and self.lx is applied to the label. There is a RandomRotation transform in self.tx, but there is no matching RandomRotation in self.lx. Does that mean the image will get a random rotation while its label will not be rotated?

    if self.transformI:
        self.tx = self.transformI
    else:
        self.tx = torchvision.transforms.Compose([
          #  torchvision.transforms.Resize((128,128)),
            torchvision.transforms.CenterCrop(96),
            torchvision.transforms.RandomRotation((-10,10)),
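            # <- this rotation is applied to the image only; self.lx below has no matching rotation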
           # torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    if self.transformM:
        self.lx = self.transformM
    else:
        self.lx = torchvision.transforms.Compose([
          #  torchvision.transforms.Resize((128,128)),
            torchvision.transforms.CenterCrop(96),
            torchvision.transforms.Grayscale(),
            torchvision.transforms.ToTensor(),
            #torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
        ])

def __len__(self):
    return len(self.images)

def __getitem__(self, i):
    i1 = Image.open(self.images_dir + self.images[i])
    l1 = Image.open(self.labels_dir + self.labels[i])

    return self.tx(i1), self.lx(l1)

bigmb commented 4 years ago

I just checked that out, and yes, that's a mistake on my end.

I will correct it in a few days, but if you fix it first, send me a pull request. Check out this issue for the solution: https://github.com/pytorch/vision/issues/9
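
The gist of the fix there is to sample the random parameters once and apply them to both the image and the mask through torchvision.transforms.functional. A rough sketch along those lines (the class and parameter names here are illustrative, not this repo's code):

    import random
    import torchvision.transforms.functional as TF
    from torchvision import transforms

    class PairedTransform:
        """Apply the same random geometry to an image and its mask (illustrative)."""
        def __init__(self, degrees=10, crop_size=96):
            self.degrees = degrees
            self.crop_size = crop_size

        def __call__(self, image, mask):
            # Sample the random parameters once...
            angle = random.uniform(-self.degrees, self.degrees)
            # ...then apply the same geometric ops to both inputs.
            image = TF.center_crop(image, self.crop_size)
            mask = TF.center_crop(mask, self.crop_size)
            image = TF.rotate(image, angle)
            mask = TF.rotate(mask, angle)
            # Photometric jitter changes intensities only, so the image alone is fine.
            image = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)(image)
            image = TF.normalize(TF.to_tensor(image), mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            mask = TF.to_tensor(TF.to_grayscale(mask))
            return image, mask

__getitem__ would then return one call to such a paired transform on (i1, l1) instead of applying self.tx and self.lx independently.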

ppjerry commented 4 years ago

I just submitted a pull request, give it a try.

bigmb commented 4 years ago

Merged.