clovaai / CutMix-PyTorch

Official Pytorch implementation of CutMix regularizer
MIT License

cutmix for segmentation #25

Closed fengweie closed 4 years ago

fengweie commented 4 years ago

I'm very interested in your work, and recently I've been trying to apply it to a segmentation problem. I have a couple of questions. First, how does this formula in the paper, loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam), extend to the segmentation problem? I use 3D image data, so I apply CutMix to a volume. Second, I am not clear about how to set some of the parameters in this case. Could you give me some suggestions? Here is my implementation:

        volume_batch_label = volume_batch[:labeled_bs].clone()
        label_batch_label = label_batch[:labeled_bs].clone()
        # r = np.random.rand(1)
        # if r < 0.5:
        # generate mixed sample(labeled_data)
        lam_label = np.random.beta(1,1)
        rand_index_label = torch.randperm(volume_batch_label.size()[0]).cuda()
        # target_a_label = label_batch_label
        # target_b_label = label_batch_label[rand_index_label]
        bbx1, bby1, bbz1, bbx2, bby2, bbz2 = rand_bbox(volume_batch_label.size(), lam_label)

        volume_batch_label[:, :, bbx1:bbx2, bby1:bby2, bbz1:bbz2] = volume_batch_label[rand_index_label, :, bbx1:bbx2, bby1:bby2, bbz1:bbz2]

        label_batch_label[:, bbx1:bbx2, bby1:bby2, bbz1:bbz2] = label_batch_label[rand_index_label, bbx1:bbx2, bby1:bby2, bbz1:bbz2]
        # adjust lambda to exactly match the voxel ratio of the pasted box
        size_x, size_y, size_z = volume_batch_label.size()[-3:]
        lam_label = 1 - ((bbx2 - bbx1) * (bby2 - bby1) * (bbz2 - bbz1) / (size_x * size_y * size_z))
        # compute output
        outputs_label = model(volume_batch_label)
        # supervised loss on the labeled data
        loss_seg = F.cross_entropy(outputs_label, label_batch_label)
        outputs_soft_label = F.softmax(outputs_label, dim=1)
        loss_seg_dice = losses.dice_loss(outputs_soft_label[:, 1, :, :, :], label_batch_label == 1)
hellbell commented 4 years ago

@Karles-ai

how does this formula: loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam) in the paper expand into the segmentation problem?

This criterion, which mixes two one-hot labels, is designed for the classification task. For the segmentation task, your label_batch_label tensor is already mixed by this operation:

label_batch_label[:, bbx1:bbx2, bby1:bby2, bbz1:bbz2] = \
label_batch_label[rand_index_label, bbx1:bbx2, bby1:bby2, bbz1:bbz2]

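So the original criterion cannot be applied directly in this case: the label map is mixed per voxel, so the usual segmentation loss on the mixed labels already covers both source samples. I guess this paper is somewhat related to CutMix for segmentation tasks.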
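As a rough illustration of why no lam-weighted criterion is needed here, the sketch below checks that the plain per-voxel cross-entropy on a mixed label map equals a box-masked combination of the losses against the two original label maps. The toy shapes, the fixed box, and the use of `flip(0)` in place of `rand_index_label` are assumptions made for the example, not taken from the code above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes: batch of 2, 3 classes, 4x4x4 volumes.
logits = torch.randn(2, 3, 4, 4, 4)
target_a = torch.randint(0, 3, (2, 4, 4, 4))
target_b = target_a.flip(0)  # stands in for label_batch_label[rand_index_label]

# Boolean mask of the pasted box (True inside the box, False outside).
box = torch.zeros(4, 4, 4, dtype=torch.bool)
box[1:3, 0:2, 2:4] = True

# Mixed label map: labels from sample b inside the box, from sample a outside.
mixed = torch.where(box, target_b, target_a)

# (1) Ordinary per-voxel cross-entropy against the mixed label map.
loss_mixed = F.cross_entropy(logits, mixed)

# (2) The same quantity written as a mask-selected mix of two per-voxel losses.
ce_a = F.cross_entropy(logits, target_a, reduction="none")
ce_b = F.cross_entropy(logits, target_b, reduction="none")
loss_masked = torch.where(box, ce_b, ce_a).mean()

print(torch.allclose(loss_mixed, loss_masked))
```

Because the mean over voxels already weights each source by the number of voxels it contributes, the recomputed `lam_label` does not have to enter the cross-entropy term in this setup.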

I use 3D image data, so I apply CutMix to a volume. In addition, I am not clear about how to set some parameters in this case. Could you give me some suggestions?

You might have to try various types of CutMix in 3D space. Some axis (e.g., the time axis) may be critical for performance, in which case it would be important to preserve it, i.e., never cut along it.
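One way to experiment with such variants is a box sampler that can leave chosen axes uncut. The following is only a sketch: `rand_bbox_3d`, `keep_axes`, and `rng` are hypothetical names, not part of the CutMix repository. It assumes the cut fraction (1 - lam) is spread evenly over the axes that are cut, mirroring the square-root rule of the 2D `rand_bbox`.

```python
import numpy as np

def rand_bbox_3d(spatial_size, lam, keep_axes=(), rng=None):
    """Sample a box covering at most a (1 - lam) fraction of a 3D volume.

    spatial_size: the three spatial sizes, e.g. tensor.shape[-3:]
    keep_axes:    spatial axes (0, 1, 2) that are never cut; the box spans them fully
    Returns a tuple of three slices, one per spatial axis.
    """
    rng = np.random.default_rng() if rng is None else rng
    cut_axes = [ax for ax in range(3) if ax not in keep_axes]
    if not cut_axes:
        raise ValueError("at least one axis must be cut")

    # Spread the (1 - lam) volume fraction evenly over the axes being cut.
    cut_ratio = (1.0 - lam) ** (1.0 / len(cut_axes))

    slices = []
    for ax, size in enumerate(spatial_size):
        if ax in keep_axes:
            # Preserved axis: the box covers its full extent.
            slices.append(slice(0, size))
            continue
        cut = int(size * cut_ratio)
        center = int(rng.integers(size))
        lo = int(np.clip(center - cut // 2, 0, size))
        hi = int(np.clip(center + cut // 2, 0, size))
        slices.append(slice(lo, hi))
    return tuple(slices)
```

The returned slices can index the last three dimensions of a volume or label tensor, for example `x[..., box[0], box[1], box[2]]`. Passing `keep_axes=(0,)` keeps the first spatial axis intact, so each pasted region spans that axis completely.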