CasiaFan / tensorflow_retinanet

RetinaNet with Focal Loss implemented by Tensorflow

Focal Loss weighting #14

Open MikeyLev opened 5 years ago

MikeyLev commented 5 years ago

Hi @CasiaFan, can you explain the weighting you apply to the focal loss?

weighted_loss = ce * tf.pow(1-predictions_pt, gamma) * alpha_t * tf.expand_dims(weights, axis=2)

Is it something from the original paper? What are those weights?

CasiaFan commented 5 years ago

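@MikeyLev The first part of the expression, ce * tf.pow(1-predictions_pt, gamma) * alpha_t, is the focal loss definition from Section 3.2 of the RetinaNet paper. The remaining factor comes from the "Focal Loss" paragraph of Section 4.1, which describes how the total focal loss of an image is computed: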

The total focal loss of an image is computed as the sum of the focal loss over all ∼100k anchors, normalized by the number of anchors assigned to a ground-truth box.
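Written out, the two pieces look roughly like this. The first line is the α-balanced focal loss from the paper; the second is one way to write the quoted normalization, where N_pos (the number of anchors assigned to a ground-truth box) and the index i over anchors are notation chosen here for illustration, not symbols taken from the code:

```latex
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)

L_{\text{image}} = \frac{1}{N_{\text{pos}}} \sum_{i} \mathrm{FL}\left(p_{t,i}\right)
```

Assuming the names in the snippet mean what they suggest, ce plays the role of -log(p_t) and predictions_pt plays the role of p_t.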

So the weights tensor is the normalizer obtained from anchor assignment. It is applied inside the loss function only to keep the code simple. For a clearer picture, you could move it outside, to the place where the loss is computed and used.
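To illustrate what moving the normalizer outside could look like, here is a minimal plain-Python sketch. It is not the repository's code: the function names, the eps clamp, and the max(1, ...) guard against images with no positive anchors are all assumptions made for this example, and it loops over scalars instead of using TensorFlow tensors so the arithmetic is easy to follow.

```python
import math


def focal_loss_per_anchor(p, y, alpha=0.25, gamma=2.0, eps=1e-7):
    """Focal loss for one anchor/class pair.

    p is the predicted probability of the class, y is the 0/1 target.
    The eps clamp is an assumption added here to avoid log(0).
    """
    p_t = p if y == 1 else 1.0 - p
    p_t = min(max(p_t, eps), 1.0)
    alpha_t = alpha if y == 1 else 1.0 - alpha
    ce = -math.log(p_t)  # plain cross-entropy term
    return alpha_t * (1.0 - p_t) ** gamma * ce


def image_focal_loss(probs, labels, num_positive_anchors, alpha=0.25, gamma=2.0):
    """Sum the per-anchor focal loss, then normalize once, outside the sum."""
    total = sum(
        focal_loss_per_anchor(p, y, alpha, gamma) for p, y in zip(probs, labels)
    )
    # Hypothetical guard: treat an image with no positive anchors as having one.
    normalizer = max(1, num_positive_anchors)
    return total / normalizer
```

Because the normalizer is a single scalar per image, dividing the summed loss by it gives the same value as multiplying every anchor's loss by its reciprocal before summing, which is what the weights factor in the original expression appears to do.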