tusen-ai / RangeDet

Paper and Codes for “RangeDet: In Defense of Range View for LiDAR-based 3D Object Detection” (ICCV2021)
Apache License 2.0

About the varifocal loss function in your implementation #12

Closed zhby99 closed 2 years ago

zhby99 commented 2 years ago

The varifocal loss implemented in https://github.com/TuSimple/RangeDet/blob/main/rangedet/symbol/head/loss.py seems quite different from the formula given in the paper (which, in my opinion, is the standard varifocal loss). Will the implementation here give better results than the standard varifocal loss?

Specifically, I don't understand the forward term 2 very well, especially the minus_log part:

minus_log = minus_logits_mask - log_one_exp_minus_abs

where log_one_exp_minus_abs = log(1. + exp(-abs(logits))) and minus_logits_mask = -logits if logits >= 0 else 0
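For readers following along: as described above, minus_log appears to be a numerically stable way of writing log(1 - sigmoid(logits)), since log(1 - sigmoid(x)) = -log(1 + exp(x)) = -max(x, 0) - log(1 + exp(-|x|)). The following is a minimal scalar sketch in plain Python that checks this identity. It is not the repository's code; the function names `minus_log` and `naive_log_one_minus_sigmoid` are hypothetical helpers introduced only for this illustration.

```python
import math


def minus_log(logit):
    """Stable form of log(1 - sigmoid(logit)), built from the two terms in the question."""
    # minus_logits_mask: -logit where logit >= 0, else 0
    minus_logits_mask = -logit if logit >= 0 else 0.0
    # log(1 + exp(-|logit|)); the exponent is never positive, so exp cannot overflow
    log_one_exp_minus_abs = math.log1p(math.exp(-abs(logit)))
    return minus_logits_mask - log_one_exp_minus_abs


def naive_log_one_minus_sigmoid(logit):
    """Direct form; only safe for moderate logits (log(0) fails for large positive ones)."""
    return math.log(1.0 - 1.0 / (1.0 + math.exp(-logit)))


if __name__ == "__main__":
    for x in (-5.0, -0.5, 0.0, 0.5, 5.0):
        # The two columns should agree up to floating-point rounding.
        print(x, minus_log(x), naive_log_one_minus_sigmoid(x))
```

The stable form stays finite for very large logits, where the direct form would take the logarithm of zero.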

Abyssaledge commented 2 years ago

You can use the following version, which is more numerically stable and clearer:

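import mxnet as mx


def vari_focal_loss_stable(logit, label, loss_scale=1, alpha=1., gamma=2.0):
    '''
    Numerically stable varifocal loss. loss_scale is not used.
    '''
    # Stable BCE with logits: relu(x) - x * y + log(1 + exp(-|x|)).
    # mx.sym.relu stands in for the repo's X.relu helper here.
    bce = mx.sym.relu(logit) - logit * label
    bce = bce + mx.sym.Activation(
        -mx.sym.abs(logit),
        act_type='softrelu'
    )
    pred_score = mx.sym.sigmoid(logit)
    # Positives (label > 0): BCE weighted by the soft target itself.
    positive_mask = mx.sym.broadcast_greater(lhs=label, rhs=mx.sym.zeros((1, 1)))
    loss_positive = bce * label * positive_mask
    # Negatives (label == 0): BCE down-weighted by alpha * pred_score ** gamma.
    negative_mask = mx.sym.broadcast_equal(lhs=label, rhs=mx.sym.zeros((1, 1)))
    loss_negative = bce * alpha * mx.sym.power(pred_score, gamma) * negative_mask
    loss = loss_negative + loss_positive
    return loss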
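
To see how the function above relates to the textbook varifocal loss, here is a minimal scalar sketch in plain Python (no MXNet) that mirrors its masks and compares the result against the direct formula: -q * (q * log(p) + (1 - q) * log(1 - p)) for a positive target q, and -alpha * p^gamma * log(1 - p) for q = 0. The names `sigmoid`, `vari_focal_loss_scalar` and `vari_focal_loss_naive` are hypothetical helpers for this illustration, not part of the repository.

```python
import math


def sigmoid(x):
    """Sigmoid evaluated without overflowing exp for either sign of x."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def vari_focal_loss_scalar(logit, label, alpha=1.0, gamma=2.0):
    """Scalar mirror of the stable version: stable BCE times a per-case weight."""
    # Stable BCE with logits: relu(x) - x * y + log(1 + exp(-|x|))
    bce = max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))
    if label > 0:       # positive_mask: weight by the soft target
        return bce * label
    if label == 0:      # negative_mask: weight by alpha * p ** gamma
        return bce * alpha * sigmoid(logit) ** gamma
    return 0.0          # neither mask fires for negative labels


def vari_focal_loss_naive(logit, label, alpha=1.0, gamma=2.0):
    """Direct formula; fine for moderate logits, unstable when p hits 0 or 1."""
    p = sigmoid(logit)
    if label > 0:
        return -label * (label * math.log(p) + (1.0 - label) * math.log(1.0 - p))
    if label == 0:
        return -alpha * p ** gamma * math.log(1.0 - p)
    return 0.0


if __name__ == "__main__":
    for x in (-3.0, 0.0, 2.0):
        for q in (0.0, 0.3, 1.0):
            # Both columns should agree up to floating-point rounding.
            print(x, q, vari_focal_loss_scalar(x, q), vari_focal_loss_naive(x, q))
```

For moderate logits the two versions agree up to rounding, which suggests the stable version is an algebraic rewrite of the standard varifocal loss (with alpha defaulting to 1) rather than a different loss.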