zhangjia3 closed this issue 4 years ago.
Hi zhangjia3
ASL works essentially the same way for multi-label and single-label classification; the single-label version just uses softmax instead of sigmoid.
If focal loss works better than cross entropy on your data, you can get a real improvement in score from ASL. ASL works much better than focal loss.
Anyway, this was basically our single-label ASL variant. In the single-label case, "background" still refers to the case where the object doesn't appear:
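Reading the ASLSingleLabel code in this thread, the single-label loss for one sample with logits $z$, target class $t$, $C$ classes and label-smoothing factor $\varepsilon$ can be written as follows. This is our transcription of the posted code, not a formula quoted from the paper:

$$
p_k = \frac{e^{z_k}}{\sum_{j=1}^{C} e^{z_j}}, \qquad
w_k = \begin{cases} (1 - p_k)^{\gamma_+} & k = t \\ p_k^{\gamma_-} & k \neq t \end{cases}, \qquad
\tilde{y}_k = (1-\varepsilon)\,\mathbb{1}[k = t] + \frac{\varepsilon}{C}
$$

$$
L = -\sum_{k=1}^{C} \tilde{y}_k \, w_k \log p_k
$$

With $\gamma_+ = \gamma_- = 0$ and $\varepsilon = 0$ every $w_k = 1$, and $L$ reduces to ordinary softmax cross entropy.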
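import torch
import torch.nn as nn
from torch.nn import Module

class ASLSingleLabel(Module):
    # Single-label (softmax) variant of the Asymmetric Loss.
    def __init__(self, args, eps: float = 0.1, reduction='mean'):
        super().__init__()  # required before assigning submodules to an nn.Module
        self.eps = eps  # label-smoothing factor
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []  # prevent gpu repeated memory allocation
        self.args = args  # expects args.gamma_pos and args.gamma_neg
        self.reduction = reduction

    def forward(self, inputs, target, reduction=None):
        # inputs: (N, C) logits; target: (N,) integer class indices
        reduction = self.reduction if reduction is None else reduction
        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        # one-hot encoding of the target class
        self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

        # ASL weights
        if self.args.gamma_neg is not None:
            targets = self.targets_classes.clone()  # clone: targets_classes is modified in place below
            anti_targets = 1 - targets
            xs_pos = torch.exp(log_preds)
            xs_neg = 1 - xs_pos
            xs_pos = xs_pos * targets
            xs_neg = xs_neg * anti_targets
            # (1 - p_t)^gamma_pos for the target class, p_k^gamma_neg for all other classes
            asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                     self.args.gamma_pos * targets + self.args.gamma_neg * anti_targets)
            log_preds = log_preds * asymmetric_w

        # label smoothing
        if self.eps > 0:
            self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes)

        # loss calculation
        loss = - self.targets_classes.mul(log_preds)
        loss = loss.sum(dim=-1)
        if reduction == 'mean':
            loss = loss.mean()
        return loss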
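To check the behaviour of the class above without a full training setup, here is a functional sketch that mirrors ASLSingleLabel with `gamma_pos`, `gamma_neg` and `eps` passed directly instead of through `args`. The name `asl_single_label_loss` is hypothetical (not part of the ASL repository), and the sketch assumes `(N, C)` logits and integer class targets with mean reduction.

```python
import torch
import torch.nn.functional as F


def asl_single_label_loss(logits, target, gamma_pos=0.0, gamma_neg=4.0, eps=0.1):
    # Hypothetical functional mirror of ASLSingleLabel (mean reduction).
    # logits: (N, C) raw scores; target: (N,) integer class indices.
    num_classes = logits.size(-1)
    log_preds = F.log_softmax(logits, dim=-1)
    probs = log_preds.exp()
    # One-hot targets, built with the same scatter_ call as ASLSingleLabel.
    targets = torch.zeros_like(logits).scatter_(1, target.long().unsqueeze(1), 1.0)
    anti_targets = 1 - targets
    # Weight base: 1 - p_t for the target class, p_k for every other class.
    base = 1 - probs * targets - (1 - probs) * anti_targets
    exponent = gamma_pos * targets + gamma_neg * anti_targets
    weighted_log_preds = log_preds * torch.pow(base, exponent)
    # Label smoothing is applied after the weights are computed, as in the class.
    if eps > 0:
        targets = targets * (1 - eps) + eps / num_classes
    return -(targets * weighted_log_preds).sum(dim=-1).mean()


torch.manual_seed(0)
logits = torch.randn(8, 5)
target = torch.randint(0, 5, (8,))

# Both gammas at 0 and no smoothing: every weight is 1, i.e. plain cross entropy.
plain = asl_single_label_loss(logits, target, gamma_pos=0.0, gamma_neg=0.0, eps=0.0)
print(torch.allclose(plain, F.cross_entropy(logits, target)))  # True

# Fine-grained setting from the paper's appendix (gamma_neg=4, gamma_pos=0), with eps=0.1.
print(asl_single_label_loss(logits, target, gamma_pos=0.0, gamma_neg=4.0, eps=0.1).item())
```

As written, the non-target terms are multiplied by a one-hot target of 0 when `eps=0`, so with `gamma_pos=0` the loss equals cross entropy regardless of `gamma_neg`; `gamma_neg` only acts through the `eps / C` smoothing mass placed on the non-target classes.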
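First of all, thank you very much for your reply. I also have some questions about the parameter settings. The paper's appendix states that for fine-grained tasks, γ- = 4 and γ+ = 0. In the code this corresponds to self.args.gamma_neg=4, self.args.gamma_pos=0, self.args.gamma_focal_loss=None. Is this parameter setting correct? If not, how should it be set? Looking forward to your reply.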
For our runs we used, as you said: self.args.gamma_pos=0, self.args.gamma_neg=4, self.args.gamma_focal_loss=None
However, you can try other parameter regimes; the parameters we used are not necessarily optimal for every scenario or dataset. Other "reasonable" values:
self.args.gamma_pos=1, self.args.gamma_neg=4
self.args.gamma_pos=1, self.args.gamma_neg=2
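To see how these regimes differ, the sketch below computes the per-class weights (the `asymmetric_w` term in ASLSingleLabel) for one made-up softmax output. The helper name `asl_weights` and the probabilities are hypothetical, chosen only for illustration.

```python
import torch


def asl_weights(probs, target, gamma_pos, gamma_neg):
    # Hypothetical helper following the ASLSingleLabel weighting for one sample:
    # target class -> (1 - p_t) ** gamma_pos, other classes -> p_k ** gamma_neg
    targets = torch.zeros_like(probs)
    targets[target] = 1.0
    base = targets * (1 - probs) + (1 - targets) * probs
    exponent = gamma_pos * targets + gamma_neg * (1 - targets)
    return base ** exponent


probs = torch.tensor([0.7, 0.2, 0.1])  # example softmax output; class 0 is the target
for gamma_pos, gamma_neg in [(0, 4), (1, 4), (1, 2)]:
    w = asl_weights(probs, 0, gamma_pos, gamma_neg)
    print(gamma_pos, gamma_neg, [round(x, 4) for x in w.tolist()])
```

With γ+ = 0 the target class keeps weight 1 (a plain cross-entropy term); γ+ = 1 scales it by 1 - p_t = 0.3. Lowering γ- from 4 to 2 raises the non-target weights from 0.0016 and 0.0001 to 0.04 and 0.01, so the negatives are suppressed less aggressively.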
All the best,
Tal
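Hello, many thanks to you and your team for your contributions in this area. I plan to apply this loss function to fine-grained classification. How should I apply it to this task? Could you publish an example train.py?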
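This is not the authors' train.py, but a minimal sketch of where the loss plugs into a standard PyTorch training loop. It assumes the criterion is called as `criterion(logits, labels)` with `(N, C)` logits and integer labels, which matches ASLSingleLabel's `forward(inputs, target)`. For a self-contained demo it uses `nn.CrossEntropyLoss` on toy data; with ASLSingleLabel you would pass something like `ASLSingleLabel(args)` where `args` carries `gamma_pos`/`gamma_neg` (for example a `types.SimpleNamespace`, an assumption about how `args` is built). `train_one_epoch` is a hypothetical name.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device="cpu"):
    # Hypothetical minimal training loop; not the authors' train.py.
    model.train()
    total, count = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)            # (N, C) raw class scores
        loss = criterion(logits, labels)  # e.g. ASLSingleLabel(args) from this thread
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item() * labels.size(0)
        count += labels.size(0)
    return total / count  # mean training loss over the epoch


# Toy stand-in for a fine-grained dataset: 200 feature vectors, 10 classes.
torch.manual_seed(0)
features = torch.randn(200, 32)
labels = (features @ torch.randn(32, 10)).argmax(dim=1)
loader = DataLoader(TensorDataset(features, labels), batch_size=32, shuffle=True)

model = nn.Linear(32, 10)
criterion = nn.CrossEntropyLoss()  # swap in ASLSingleLabel(args) here
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
losses = [train_one_epoch(model, loader, criterion, optimizer) for _ in range(20)]
print(losses[0], losses[-1])
```

In a real fine-grained setup, `model` would be an image backbone and `loader` an image dataset; the only ASL-specific change is the `criterion` line.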