bigmb / Unet-Segmentation-Pytorch-Nest-of-Unets

Implementation of different kinds of Unet models for image segmentation - Unet, RCNN-Unet, Attention Unet, RCNN-Attention Unet, Nested Unet

question with 'data_transform' #32

Closed iWeisskohl closed 4 years ago

iWeisskohl commented 4 years ago

Hi, I have some doubts about the 'data_transform' function. As you suggested, the input image should be a 3-channel image and the input label should be a 1-channel image, but I find you use the same data_transform function

data_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 128)),
        # torchvision.transforms.CenterCrop(96),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

for both the input image and the input label during training, and another function

data_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 128)),
        # torchvision.transforms.CenterCrop(96),
        torchvision.transforms.Grayscale(),
        # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
for the input image and the input label when calculating the Dice score.

The code raises a shape error with those functions when I run them, so I am wondering whether there is a mistake in the definition and usage of the data_transform function?
Thanks in advance! Have a nice day!

bigmb commented 4 years ago

What's the error? It's different for the input image and the label. You can find the data transform in the 3d_to_2d.py code.

iWeisskohl commented 4 years ago

Yes, I use the 3d_to_2d.py code to get the data and run pytorch_run.py. The error is: RuntimeError: output with shape [1, 128, 128] doesn't match the broadcast shape [3, 128, 128], reported at line 311 of pytorch_run.py (line 310: s_tb = data_transform(im_tb), line 311: s_label = data_transform(im_label)). I find you use the same transform function for inputs with different numbers of channels, which is defined as:

data_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 128)),
        # torchvision.transforms.CenterCrop(96),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
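For reference, the mismatch can be reproduced in isolation: ToTensor() turns a single-channel PIL label into a [1, H, W] tensor, and Normalize with a three-element mean/std then fails to broadcast. A minimal sketch, using a blank grayscale image as a stand-in for a label (not from the repo):

    import torchvision
    from PIL import Image

    # stand-in for a 1-channel label/mask image
    im_label = Image.new("L", (128, 128))

    data_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 128)),
        torchvision.transforms.ToTensor(),  # -> tensor of shape [1, 128, 128]
        torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # RuntimeError: output with shape [1, 128, 128] doesn't match the
    # broadcast shape [3, 128, 128]
    s_label = data_transform(im_label)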

So I changed the data_transform definition for the input label (because it is single-channel and can't be normalized with three-channel settings), and the problem was solved. Actually, there is no data transform in 2d_from_3d.py, but there is a transforms.Compose in data_loader.py. So I am wondering: if the data transform has already been done when loading the data, is it necessary to transform it again in pytorch_run.py (such as at lines 310 and 311)?

Sorry to disturb you again; I really appreciate your nice work. Waiting for your reply.

bigmb commented 4 years ago

I did the data transformation in a different Jupyter file because I had to test different configurations. But yes, the data_transform for input images and labels should be different, as the images are 3-channel and the labels are 1-channel.

Are you facing any errors now? And if there is any change, just send me a pull request.

c-arthurs commented 4 years ago

Hi, I am having the same problem, namely that the data transform for the mask at line 295 uses the same transform as the 3-channel image data.

s_label = data_transform(im_label) gives the error below: RuntimeError: output with shape [1, 1020, 1020] doesn't match the broadcast shape [3, 1020, 1020]

Could you let me know what was changed to get it to run in your case?

c-arthurs commented 4 years ago

Just to answer my own question -

I added a new transform for the binary masks:

data_transform_mask = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.5], std=[0.5])])

and applied this to line 295:

s_label = data_transform_mask(im_label)
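For completeness, a minimal sketch of how the two transforms are then applied side by side around lines 310-311, assuming im_tb is the 3-channel input image and im_label is the 1-channel mask:

    s_tb = data_transform(im_tb)             # image: 3-channel, shape [3, H, W]
    s_label = data_transform_mask(im_label)  # mask: 1-channel, shape [1, H, W]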

bigmb commented 4 years ago

Done. Let me know if you need some help.

iWeisskohl commented 4 years ago

Hi, I just changed the data_transform function for the label image to a one-channel definition (similar to what @c-arthurs did), but I am not sure whether the change is reasonable, because I am not sure whether applying a data transform to the ground-truth image will influence the segmentation results.
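One point worth keeping in mind (an assumption on my side, not something from the repo): Resize and ToTensor keep the mask usable, but Normalize with mean=[0.5], std=[0.5] maps the mask values into roughly [-1, 1], so if the Dice computation expects a binary {0, 1} mask it has to be mapped back, e.g. by thresholding:

    s_label = data_transform_mask(im_label)  # values roughly in [-1, 1] after Normalize

    # re-binarize before computing the Dice score
    # (the 0.0 threshold is an assumption; pick it to match your label encoding)
    s_label_bin = (s_label > 0.0).float()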

bigmb commented 4 years ago

Hey @iWeisskohl ,

Do you have a link supporting this? I found these links supporting @c-arthurs's approach: https://github.com/pytorch/vision/issues/9 and https://discuss.pytorch.org/t/torchvision-transfors-how-to-perform-identical-transform-on-both-image-and-target/10606/17
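For context, both links mostly discuss applying the same (especially random) geometric transform to image and target; a common pattern is to draw the random parameters once and apply them to both via the functional API, while keeping tensor conversion and normalization separate per input. A rough sketch of that idea (not code from this repo):

    import random
    import torchvision.transforms.functional as TF

    def paired_transform(image, mask, size=(128, 128)):
        # identical geometric transforms for image and mask
        # (masks are usually resized with nearest-neighbour interpolation)
        image = TF.resize(image, size)
        mask = TF.resize(mask, size)
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        # tensor conversion / normalization can differ per input
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return image, mask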