bodokaiser / piwise

Pixel-wise segmentation on VOC2012 dataset using pytorch.
BSD 3-Clause "New" or "Revised" License

Semantic segmentation for binary masks (2 class: background/foreground object) #10

Closed: andrewssobral closed this issue 7 years ago

andrewssobral commented 7 years ago

Hi @bodokaiser, do you know how to adapt this code for binary masks (background = black, foreground = white)? I changed the number of classes from

NUM_CLASSES = 22

to

NUM_CLASSES = 2

but some errors still occur:

Traceback (most recent call last):
  File "main2.py", line 164, in <module>
    main(parser.parse_args())
  File "main2.py", line 139, in main
    train(args, model)
  File "main2.py", line 74, in train
    loss = criterion(outputs, targets[:, 0])
  File "/home/ubuntu/src/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/pytorch/piwise/piwise/criterion.py", line 13, in forward
    return self.loss(F.log_softmax(outputs), targets)
  File "/home/ubuntu/src/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/src/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 132, in forward
    self.ignore_index)
  File "/home/ubuntu/src/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/functional.py", line 674, in nll_loss
    return _functions.thnn.NLLLoss2d.apply(input, target, weight, size_average, ignore_index)
  File "/home/ubuntu/src/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/_functions/thnn/auto.py", line 47, in forward
    output, *ctx.additional_args)
RuntimeError: weight tensor should be defined either for all or no classes at /pytorch/torch/lib/THCUNN/generic/SpatialClassNLLCriterion.cu:28
bodokaiser commented 7 years ago

Hi Andrew,

Can you change the following in main.py

if args.cuda:
    criterion = CrossEntropyLoss2d(weight.cuda())
else:
    criterion = CrossEntropyLoss2d(weight)

to

criterion = CrossEntropyLoss2d()

Also, you might want to read this for binary segmentation.
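The error message in the traceback says the weight tensor must cover all classes or be left out. The following is a minimal sketch of that rule, assuming a recent PyTorch release and using torch.nn.NLLLoss directly rather than the repository's CrossEntropyLoss2d wrapper. The tensor shapes and weight values are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_CLASSES = 2  # background / foreground

# Stand-in for the network output: batch of 1, NUM_CLASSES channels, 4x4 pixels.
outputs = torch.randn(1, NUM_CLASSES, 4, 4)
# Stand-in for the target: one class index (0 or 1) per pixel.
targets = torch.randint(0, NUM_CLASSES, (1, 4, 4))

log_probs = F.log_softmax(outputs, dim=1)

# No weight: valid for any number of classes.
unweighted = nn.NLLLoss()
loss = unweighted(log_probs, targets)

# A weight vector also works if it has exactly NUM_CLASSES entries.
# The values are placeholders, not tuned class weights.
weighted = nn.NLLLoss(weight=torch.ones(NUM_CLASSES))
weighted_loss = weighted(log_probs, targets)

# A 22-entry weight (left over from a 22-class setup) no longer matches.
stale = nn.NLLLoss(weight=torch.ones(22))
try:
    stale(log_probs, targets)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
```

Dropping the weight argument, as suggested above, is the simplest fix. Passing a two-element weight is the alternative if the classes need to be weighted differently.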

andrewssobral commented 7 years ago

Thank you, @bodokaiser!

andrewssobral commented 7 years ago

Hi again @bodokaiser, just a small question: why is the number of classes defined as 22 if Pascal VOC has 20 classes?

bodokaiser commented 7 years ago

I think VOC before 2012 had fewer classes; however, according to the segmentation examples:

pixel indices correspond to classes in alphabetical order (1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle, 6=bus, 7=car, 8=cat, 9=chair, 10=cow, 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person, 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor). For both types of segmentation image, index 0 corresponds to background and index 255 corresponds to 'void' or unlabelled.
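The index scheme quoted above can be written out as a lookup table. The sketch below also shows one way to arrive at 22, assuming the 'void' label is counted as its own class alongside background. The thread does not confirm that this is the reason the repository uses 22, and the constant names are illustrative.

```python
# The 20 VOC object classes, in the alphabetical order quoted above.
VOC_OBJECT_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "potted plant", "sheep", "sofa", "train", "tv/monitor",
]

BACKGROUND_INDEX = 0   # index 0 is background
VOID_INDEX = 255       # index 255 is 'void' / unlabelled

# Pixel index -> class name, with object classes numbered from 1.
index_to_name = {BACKGROUND_INDEX: "background", VOID_INDEX: "void"}
for i, name in enumerate(VOC_OBJECT_CLASSES, start=1):
    index_to_name[i] = name

num_object_classes = len(VOC_OBJECT_CLASSES)
num_with_background = num_object_classes + 1
# Assumption: void is treated as one extra class.
num_with_void = num_with_background + 1
```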

andrewssobral commented 7 years ago

Thank you, @bodokaiser! So, if I have only 2 classes in my dataset (0 = background, 255 = foreground), do I need to set NUM_CLASSES = 2?

bodokaiser commented 7 years ago

NUM_CLASSES = 2 basically just says how many output channels to use in the last layer(s) of the chosen network architecture.

There is also a VOC-specific transform which converts the color codes of the VOC images to class labels numbered from 1 to 22, so you might want to change this according to your dataset.
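For a mask that stores 0 for background and 255 for foreground, adapting the label transform amounts to mapping pixel values to class indices 0 and 1. Below is a minimal pure-Python sketch of that idea. The function name and threshold are hypothetical and not part of the repository.

```python
def binary_mask_to_labels(mask, threshold=127):
    """Map a grayscale mask (rows of 0..255 ints) to class indices 0/1.

    Pixels above `threshold` become class 1 (foreground),
    everything else becomes class 0 (background).
    """
    return [[1 if pixel > threshold else 0 for pixel in row] for row in mask]


# Tiny 2x3 example mask: 0 = background, 255 = foreground.
mask = [
    [0, 0, 255],
    [0, 255, 255],
]
labels = binary_mask_to_labels(mask)
```

With NumPy arrays or tensors, the same thresholding is typically a single vectorized comparison. The nested lists here only keep the example free of dependencies.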

andrewssobral commented 7 years ago

Solved!

numancelik34 commented 4 years ago

I am also trying to solve binary mask segmentation for my dataset in this VOC format. However, I am getting a NaN value for the segmentation loss. Could you please help me here? Thanks!

andrewssobral commented 4 years ago

Hello @numancelik34, I am sorry for the late reply, and thank you for reaching out! My solution for this issue can be found here: https://github.com/andrewssobral/deep-learning-pytorch/tree/master/segmentation I created a git repository with some code showing how to do binary segmentation with PyTorch. Please let me know if it helps, and feel free to contact me if you have any questions. Best regards, Andrews
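The thread does not record what caused the NaN loss. One thing that is cheap to check, given the 0/255 mask encoding mentioned earlier, is whether every label value in a mask falls inside the range the loss expects. The helper below is a hypothetical sketch of such a check, not code from either repository.

```python
def invalid_label_values(mask, num_classes, ignore_index=None):
    """Return the sorted label values in `mask` that the loss cannot use.

    Valid values are 0 .. num_classes - 1, plus `ignore_index` if given.
    """
    allowed = set(range(num_classes))
    if ignore_index is not None:
        allowed.add(ignore_index)
    found = {pixel for row in mask for pixel in row}
    return sorted(found - allowed)


raw_mask = [[0, 255], [255, 0]]    # mask as stored on disk (0 / 255)
relabelled = [[0, 1], [1, 0]]      # same mask after mapping 255 -> 1

raw_problems = invalid_label_values(raw_mask, num_classes=2)
relabelled_problems = invalid_label_values(relabelled, num_classes=2)
```

An empty result means the labels are at least in range. A non-empty one points at values that still need relabelling before training.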