232525 / PureT

Implementation of 'End-to-End Transformer Based Model for Image Captioning' [AAAI 2022]

xe_criterion #23

Closed · concrettojava closed this issue 10 months ago

concrettojava commented 10 months ago

If label smoothing is used, the loss is actually computed with KL divergence. Is that also a form of cross-entropy?

232525 commented 10 months ago

This part of the code was copied directly from another paper's repo. If you want to compute the loss strictly by its definition, you can change it to the following:

import torch
import torch.nn as nn
# `cfg` is the repo's global config object; cfg.LOSSES.LABELSMOOTHING is the smoothing factor.

class LabelSmoothing(nn.Module):
    def __init__(self):
        super(LabelSmoothing, self).__init__()
        self.smoothing = cfg.LOSSES.LABELSMOOTHING
        # per-token negative log-likelihood; ignore_index=-1 skips padding positions
        self.criterion = nn.NLLLoss(ignore_index=-1, reduction="none")

    def forward(self, logit, target_seq, reduction="mean"):
        # `logit` is expected to hold log-probabilities (log_softmax output), as NLLLoss requires
        logit = logit.view(-1, logit.shape[-1])
        target_seq = target_seq.view(-1)
        mask = target_seq >= 0
        # standard XE loss: -log p(target), per token
        xe_loss = self.criterion(logit, target_seq)
        # negative sum of log-probabilities over the vocabulary, per token: the uniform-smoothing term, shape [B]
        smooth_loss = -logit.sum(-1)
        if reduction == "mean":
            xe_loss = torch.masked_select(xe_loss, mask).mean()
            smooth_loss = torch.masked_select(smooth_loss, mask).mean()
        elif reduction == "sum":
            xe_loss = torch.masked_select(xe_loss, mask).sum()
            smooth_loss = torch.masked_select(smooth_loss, mask).sum()
        elif reduction == "none":
            # falls back to an unmasked mean so the scalar interpolation below still applies
            xe_loss = xe_loss.mean()
            smooth_loss = smooth_loss.mean()
        else:
            raise Exception("unsupported reduction {}".format(reduction))
        # label smoothing by definition: cross-entropy against
        # q = (1 - eps) * one_hot + eps / V, with V = vocab size
        loss = (1 - self.smoothing) * xe_loss + self.smoothing * smooth_loss / logit.size()[1]
        return loss, {'LabelSmoothing Loss': loss.item()}
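
To connect this back to the question: minimizing KL divergence against the smoothed target distribution q = (1 - eps) * one_hot + eps / V and minimizing the smoothed cross-entropy H(q, p) are the same objective, since KL(q || p) = H(q, p) - H(q) and the entropy H(q) is constant with respect to the model, so the gradients are identical. A minimal self-contained sketch that checks this numerically (dummy tensors; the vocab size, smoothing factor, and targets are made up for illustration):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, eps = 10, 0.1                                   # hypothetical vocab size and smoothing factor
log_p = F.log_softmax(torch.randn(4, V), dim=-1)   # model log-probabilities for 4 tokens
target = torch.tensor([1, 3, 5, 7])                # dummy target indices

# By-definition smoothed cross-entropy, as in the class above:
xe = F.nll_loss(log_p, target, reduction="none")   # -log p(target)
smooth = -log_p.sum(-1)                            # -sum_i log p(i)
ce_loss = ((1 - eps) * xe + eps * smooth / V).mean()

# KL-divergence formulation: build q explicitly and compute KL(q || p)
q = torch.full((4, V), eps / V)
q.scatter_(1, target.unsqueeze(1), 1 - eps + eps / V)
kl_loss = F.kl_div(log_p, q, reduction="none").sum(-1).mean()
entropy_q = -(q * q.log()).sum(-1).mean()          # H(q), constant w.r.t. the model

# The two losses differ exactly by the constant H(q)
print(torch.allclose(ce_loss, kl_loss + entropy_q, atol=1e-6))  # True

So the KL-based loss that was copied over optimizes the same thing as the by-definition cross-entropy above, just shifted by a constant.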
concrettojava commented 10 months ago

Thank you so much, author! 😍