Alibaba-MIIL / PartialLabelingCSL

Official implementation for the paper: "Multi-label Classification with Partial Annotations using Class-aware Selective Loss"

Soft labels #2

Closed · wenh06 closed this issue 2 years ago

wenh06 commented 2 years ago

The correctness of one_side_w (and likewise asymmetric_w below), i.e. its concordance with the equations in your paper, relies on the labels (the targets y) being hard labels (0 or 1). However, when one uses soft labels (e.g. produced by label smoothing), this concordance breaks down.
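
To make the mismatch concrete, here is a small numeric sketch (my own illustration; I am assuming the focusing weight is formed roughly as (1 - xs_pos*y - xs_neg*(1-y)) ** (gamma_pos*y + gamma_neg*(1-y)), which is how I read the code). With a hard label the base collapses to 1-xs_pos or 1-xs_neg and the exponent to gamma_pos or gamma_neg, matching the paper; with a smoothed label the two branches are blended before exponentiation:

    import torch

    gamma_pos, gamma_neg = 1.0, 4.0
    p = torch.tensor(0.7)                      # predicted probability for one class
    xs_pos, xs_neg = p, 1 - p                  # no clipping, to keep the arithmetic simple

    for y in (torch.tensor(1.0), torch.tensor(0.9)):        # hard vs. smoothed target
        blended = torch.pow(1 - xs_pos * y - xs_neg * (1 - y),
                            gamma_pos * y + gamma_neg * (1 - y))
        per_term = torch.pow(1 - xs_pos, gamma_pos)          # factor the paper applies to the positive term
        print(f"y={y.item():.1f}  blended={blended.item():.4f}  per-term={per_term.item():.4f}")
    # y=1.0: both give (1 - 0.7)^1.0 = 0.3000, so the two forms agree for hard labels
    # y=0.9: blended = (1 - 0.63 - 0.03)^1.3 ≈ 0.2460, which is no longer a per-term focusing factor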

I think the following code works for both hard labels (where it gives identical results) and soft labels (efficiency aside):

    def forward(self, x, y):
        """"
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y*torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1-y)*torch.log(xs_neg.clamp(min=self.eps))
        # loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Optionally compute the focusing weights with autograd disabled,
            # so they are treated as constants in the backward pass
            if self.disable_torch_grad_focal_loss:
                prev = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
            w_pos = torch.pow(1 - xs_pos, self.gamma_pos)
            w_neg = torch.pow(xs_pos, self.gamma_neg)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(prev)
            # Apply the weights after re-enabling autograd; an in-place multiply
            # inside the no-grad block would drop the weights from the graph
            # and hence from the gradients
            los_pos = los_pos * w_pos
            los_neg = los_neg * w_neg
        loss = los_pos + los_neg

        return -loss.sum()
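
For context, this is one common way smoothed multi-label targets could be produced and passed as y to a forward like the one above (a minimal sketch; the smoothing factor and tensor shapes are illustrative, not taken from this repository):

    import torch

    eps_smooth = 0.1                                      # illustrative smoothing factor
    hard = torch.randint(0, 2, (8, 80)).float()           # 8 samples, 80 classes, hard {0, 1} targets
    soft = hard * (1 - eps_smooth) + 0.5 * eps_smooth     # 1 -> 0.95, 0 -> 0.05: no longer hard labels
    # soft would then be used as the y argument of the forward() above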
ebenbaruch commented 2 years ago

@wenh06 Thanks for the comment. Indeed, in this work, we basically assumed hard labels.

wenh06 commented 2 years ago

> @wenh06 Thanks for the comment. Indeed, in this work, we basically assumed hard labels.

OK, thank you!