Alibaba-MIIL / ASL

Official PyTorch implementation of the paper "Asymmetric Loss For Multi-Label Classification" (ICCV, 2021)
MIT License

ASL for fine-grained classification #7

Closed zhangjia3 closed 4 years ago

zhangjia3 commented 4 years ago

Hello, and thank you very much to you and your team for your contribution in this area. I plan to apply this loss function to fine-grained classification. How should I apply it to that task? Could you release a train.py example?

mrT23 commented 4 years ago

Hi zhangjia3

ASL is basically the same for multi-label and single-label classification; the single-label variant just uses softmax instead of sigmoid.

If focal loss works better than plain cross-entropy on your data, you can expect a real score improvement from ASL, since ASL works much better than focal loss.

Anyway, this was basically our single-label ASL variant. "Background" in the single-label case is still the case where the object doesn't appear:

import torch
import torch.nn as nn
from torch.nn import Module


class ASLSingleLabel(Module):

    def __init__(self, args, eps: float = 0.1, reduction='mean'):
        super(ASLSingleLabel, self).__init__()
        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []  # prevent repeated GPU memory allocation
        self.args = args
        self.reduction = reduction

    def forward(self, inputs, target):
        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

        # ASL weights
        if self.args.gamma_neg is not None:
            targets = self.targets_classes.clone()
            anti_targets = 1 - targets
            xs_pos = torch.exp(log_preds)
            xs_neg = 1 - xs_pos
            xs_pos = xs_pos * targets
            xs_neg = xs_neg * anti_targets
            asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                     self.args.gamma_pos * targets + self.args.gamma_neg * anti_targets)
            log_preds = log_preds * asymmetric_w

        # label smoothing
        if self.eps > 0:
            self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes)

        # loss calculation
        loss = - self.targets_classes.mul(log_preds)

        loss = loss.sum(dim=-1)
        if self.reduction == 'mean':
            loss = loss.mean()

        return loss
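As a quick sanity check of the class above, here is a functional paraphrase of the same computation (my own sketch, not code from the repo) together with a toy call on a single-label fine-grained batch:

```python
import torch
import torch.nn.functional as F

def asl_single_label(logits, target, gamma_pos=0.0, gamma_neg=4.0, eps=0.1):
    """Functional sketch mirroring ASLSingleLabel above (assumed equivalent)."""
    num_classes = logits.size(-1)
    log_preds = F.log_softmax(logits, dim=-1)
    # one-hot targets
    targets = torch.zeros_like(logits).scatter_(1, target.long().unsqueeze(1), 1)

    # asymmetric focusing: down-weight easy classes, with separate exponents
    # for the positive (target) class and the negative (background) classes
    anti_targets = 1 - targets
    xs_pos = torch.exp(log_preds) * targets
    xs_neg = (1 - torch.exp(log_preds)) * anti_targets
    w = torch.pow(1 - xs_pos - xs_neg,
                  gamma_pos * targets + gamma_neg * anti_targets)
    log_preds = log_preds * w

    # label smoothing, then a cross-entropy-style reduction
    targets = targets * (1 - eps) + eps / num_classes
    return -(targets * log_preds).sum(dim=-1).mean()

# toy usage: 4 samples, 10 fine-grained classes
logits = torch.randn(4, 10)
target = torch.tensor([0, 3, 7, 9])
loss = asl_single_label(logits, target)
print(loss.item())  # a positive scalar, differentiable w.r.t. logits
```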
zhangjia3 commented 4 years ago

First, thank you very much for your reply. I still have some questions about the parameter settings. The paper's appendix states that for fine-grained tasks, γ- = 4 and γ+ = 0. In the code this would correspond to self.args.gamma_neg=4, self.args.gamma_pos=0, self.args.gamma_focal_loss=None. Is this setting correct? If not, how should the parameters be set? Looking forward to your reply.

mrT23 commented 4 years ago

For our runs we used, as you said: self.args.gamma_pos=0, self.args.gamma_neg=4, self.args.gamma_focal_loss=None

However, you can try other parameter regimes; the parameters we used are not necessarily optimal for every scenario or dataset. Other "reasonable" values: self.args.gamma_pos=1, self.args.gamma_neg=4

self.args.gamma_pos=1, self.args.gamma_neg=2
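To build intuition for these regimes, note that in the code above a negative class with predicted probability p gets the focusing weight p**gamma_neg, while the positive class gets (1 - p)**gamma_pos. This small sketch (my own illustration, not from the thread) prints the weights under the three suggested settings, showing how larger gamma_neg suppresses easy negatives more aggressively:

```python
import torch

# predicted probabilities to inspect
p = torch.tensor([0.05, 0.25, 0.50, 0.90])

for gamma_pos, gamma_neg in [(0, 4), (1, 4), (1, 2)]:
    w_pos = (1 - p) ** gamma_pos  # weight when the class is the target
    w_neg = p ** gamma_neg        # weight when the class is background
    print(f"gamma_pos={gamma_pos}, gamma_neg={gamma_neg}")
    print("  positive-class weights:", [round(v, 6) for v in w_pos.tolist()])
    print("  negative-class weights:", [round(v, 6) for v in w_neg.tolist()])
```

With gamma_neg=4, a confidently rejected negative (p=0.05) is weighted by 0.05**4 ≈ 6e-6, i.e. it contributes almost nothing to the gradient, while with gamma_neg=2 the same negative keeps a weight of 0.0025.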

All the best, Tal