blackfeather-wang / ISDA-for-Deep-Networks

An efficient implicit semantic augmentation method, complementary to existing non-semantic techniques.

isda loss segmentation code #10

sumenpuyuan closed this issue 3 years ago

sumenpuyuan commented 3 years ago

https://github.com/blackfeather-wang/ISDA-for-Deep-Networks/blob/master/Semantic%20segmentation%20on%20Cityscapes/train_isda.py#L180

labels = ((1 - label_mask).mul(labels) + label_mask * 19).long()

Is 19 the num_class?

blackfeather-wang commented 3 years ago

Yes. This code maps the label "255" to "19".
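To make the remapping concrete, below is a minimal sketch of the masking arithmetic from that line, assuming the ignore value is 255 and the dummy index is 19 (one past the 19 valid Cityscapes classes). The helper name `remap_ignore` is hypothetical and not taken from the repository.

```python
import torch

IGNORE_LABEL = 255  # assumption: the Cityscapes "ignore" value discussed in the thread
DUMMY_CLASS = 19    # assumption: index right after the 19 valid classes (0..18)


def remap_ignore(labels: torch.Tensor,
                 ignore_label: int = IGNORE_LABEL,
                 dummy: int = DUMMY_CLASS) -> torch.Tensor:
    """Hypothetical helper: replace every ignore_label entry with `dummy`."""
    label_mask = (labels == ignore_label).long()  # 1 where the pixel is ignored, else 0
    # Keep valid labels where the mask is 0, write the dummy index where it is 1.
    return ((1 - label_mask).mul(labels) + label_mask * dummy).long()


labels = torch.tensor([0, 5, 255, 18, 255])
print(remap_ignore(labels).tolist())  # [0, 5, 19, 18, 19]
```

Valid labels pass through unchanged; only the 255 entries become 19.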

sumenpuyuan commented 3 years ago

label_mask = (labels == 255).long()
labels = ((1 - label_mask).mul(labels) + label_mask * C).long()
onehot = torch.zeros(N, C).cuda()
NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A)

Mapping 255 to 19? Won't this cause an error? It makes the label range [0, 19], which would require num_class to be 20, so building the one-hot tensor will fail.

blackfeather-wang commented 3 years ago

Thank you for the comment, but we should point out that this does not cause an error!

Since the label 255 should be ignored in segmentation (which is handled automatically by CrossEntropyLoss()), here we simply treat it as a 20th class so that the off-the-shelf code (e.g. EstimateCV) can be used directly. Otherwise, the indexing function would report an error if the labels contained both [0, 18] and 255.
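Why a dummy class lets existing one-hot code run unchanged can be sketched with a hypothetical per-class mean helper. This is not the repository's EstimateCV, only an illustration of the same one-hot indexing pattern, assuming C = 20 (19 valid classes plus the dummy index 19) and CPU tensors.

```python
import torch


def class_means(features: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: per-class mean of N x A features via an N x C one-hot mask.

    labels must already lie in [0, num_classes - 1]; a raw 255 would be out of
    range for scatter_, which is why ignored pixels are first mapped to a dummy class.
    """
    n, a = features.shape
    onehot = torch.zeros(n, num_classes)
    onehot.scatter_(1, labels.view(-1, 1), 1)  # N x C one-hot
    NxCxA_onehot = onehot.view(n, num_classes, 1).expand(n, num_classes, a)
    NxCxA = features.view(n, 1, a).expand(n, num_classes, a)
    counts = NxCxA_onehot.sum(0)  # C x A: number of pixels per class
    counts[counts == 0] = 1       # avoid dividing by zero for empty classes
    return (NxCxA * NxCxA_onehot).sum(0) / counts  # C x A class means


features = torch.tensor([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]])
labels = torch.tensor([0, 0, 19])  # third pixel was 255, remapped to dummy class 19
means = class_means(features, labels, num_classes=20)
print(means[0].tolist())  # [2.0, 3.0]
```

Pixels that were originally 255 only contribute to row 19, so the statistics for classes 0 to 18 match what they would be if those pixels were dropped.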

I agree that there may still be room to improve the implementation. However, the current code is correct, and we currently observe only a small additional time cost (~5%).

blackfeather-wang commented 3 years ago

Another important point is that this mapping is used only in ISDAloss; it does not apply when computing the cross-entropy loss.
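On the cross-entropy side, the original labels (still containing 255) can be passed straight to the loss. A minimal sketch, assuming the segmentation criterion is constructed with `ignore_index=255`; PyTorch's `nn.CrossEntropyLoss` defaults to `ignore_index=-100`, so the value has to be set explicitly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 19)              # 4 pixels, 19 valid classes
labels = torch.tensor([2, 255, 7, 255])  # original labels, 255 = ignore

# Assumption: the criterion is built with ignore_index=255.
ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(logits, labels)

# With the default "mean" reduction this equals the loss over non-ignored pixels only.
keep = labels != 255
manual = F.cross_entropy(logits[keep], labels[keep])
print(torch.allclose(loss, manual))  # True
```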

sumenpuyuan commented 3 years ago

Thank you for your reply, I understand what you mean. But I don't think your current code is correct: as I said above, the index is out of range.

label_mask = (labels == 255).long()
labels = ((1 - label_mask).mul(labels) + label_mask * C).long() # now in range [0, 19]
onehot = torch.zeros(N, C).cuda() # C is 19
onehot.scatter_(1, labels.view(-1, 1), 1) # this should raise an error
NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A)

blackfeather-wang commented 3 years ago

Thank you for the discussion. C is 20, actually. Please see this code: https://github.com/blackfeather-wang/ISDA-for-Deep-Networks/blob/318c30976d0c412a7dd10250b0164beac6d4fbeb/Semantic%20segmentation%20on%20Cityscapes/train_isda.py#L221.
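Whether the one-hot step fails depends only on the value of C. A small CPU check with a hypothetical `build_onehot` helper (the snippet above without `.cuda()`), assuming labels have already been remapped so that ignored pixels carry index 19:

```python
import torch


def build_onehot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: N x C one-hot matrix built with scatter_."""
    onehot = torch.zeros(labels.numel(), num_classes)
    onehot.scatter_(1, labels.view(-1, 1), 1)  # CPU scatter_ checks index bounds
    return onehot


labels = torch.tensor([0, 18, 19])  # 19 is the remapped ignore label

print(build_onehot(labels, 20).shape)  # torch.Size([3, 20]): C = 20 works

try:
    build_onehot(labels, 19)  # C = 19: index 19 is out of range
except RuntimeError as err:
    print("C = 19 fails:", err)
```

On GPU, an out-of-range index typically surfaces as a device-side assert rather than a Python `RuntimeError`, so checking this kind of indexing on CPU tends to give a clearer message.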

I believe the best way to check for a bug is to run the code in practice. Maybe you can try it.

sumenpuyuan commented 3 years ago

Thanks. I tried to run your code, but setting up the pytorch-segmentation-toolbox environment failed. Thank you again for your reply :smile: