Yanfeng-Zhou / XNet

[ICCV2023] XNet: Wavelet-Based Low and High Frequency Merging Networks for Semi- and Supervised Semantic Segmentation of Biomedical Images
MIT License

About loss_function #10

Closed Seconight closed 8 months ago

Seconight commented 10 months ago

Hi! Thank you for your great work! 🤗 I have a question about training the 3D model on the Atrial dataset with the script train_sup_XNet3d.py. Training raised an exception: [image]. I then checked the corresponding tensors in loss_function.py. This is forward(self, output, target) in class DiceLoss(nn.Module):

    def forward(self, output, target):
        # preds, target = tuple(inputs)
        # inputs = tuple(list(preds) + [target])
        if self.aux:
            return self._aux_forward(output, target)
        else:
            valid_mask = (target != self.ignore_index).long()
            target_one_hot = F.one_hot(torch.clamp_min(target, 0))

            return self._base_forward(output, target_one_hot, valid_mask)

The shape of output is torch.Size([1, 2, 96, 96, 80]), but the shape of target_one_hot is torch.Size([1, 96, 96, 80, 256]). The target contains only the values 0 and 255. Should the foreground value be 1 instead of 255?
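The 256-wide last dimension can be reproduced in isolation. The sketch below uses a small hypothetical mask (not the actual Atrial volume) holding only 0 and 255; when num_classes is not passed, F.one_hot infers the class count as the largest value plus one.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for a label volume: background 0, foreground 255.
target = torch.zeros(1, 4, 4, 4, dtype=torch.long)
target[0, 1:3, 1:3, 1:3] = 255

# Same call as in DiceLoss.forward: num_classes is left to be inferred,
# so one_hot allocates max(target) + 1 = 256 classes.
target_one_hot = F.one_hot(torch.clamp_min(target, 0))
one_hot_shape = tuple(target_one_hot.shape)
print(one_hot_shape)  # (1, 4, 4, 4, 256)
```

With a two-channel network output, a 256-class one-hot target cannot line up with the prediction, which is consistent with the shapes reported above.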

Yanfeng-Zhou commented 10 months ago

To make segmentation masks easier to visualize, the foreground in the dataset is generally marked as 255. During training, 255 should be converted to 1, that is, foreground = 1 and background = 0.
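One possible way to do this conversion is sketched below. It is not taken from the repository: binarize_mask is a hypothetical helper, and where it should be called (dataset loading or just before the loss) is an assumption left to the reader.

```python
import torch
import torch.nn.functional as F


def binarize_mask(mask: torch.Tensor, foreground_value: int = 255) -> torch.Tensor:
    """Hypothetical helper: map foreground_value to 1 and every other value to 0."""
    return (mask == foreground_value).long()


# Hypothetical small mask stored with foreground = 255 for visualization.
mask = torch.zeros(1, 4, 4, 4, dtype=torch.long)
mask[0, 1:3, 1:3, 1:3] = 255

target = binarize_mask(mask)  # values are now 0 (background) or 1 (foreground)

# Passing num_classes explicitly keeps the one-hot width at 2 even if a
# particular patch happens to contain no foreground voxels.
target_one_hot = F.one_hot(target, num_classes=2)
one_hot_shape = tuple(target_one_hot.shape)
```

After the conversion the one-hot target has two classes, matching the two output channels. Note that F.one_hot places the class dimension last, while the network output has it in position 1; how the loss reconciles the two layouts is not shown in this thread.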