mlyg / unified-focal-loss

Apache License 2.0

There is an error in the implementation of unified_focal_loss #3

Closed: MohamedAliRashad closed this issue 3 years ago.

MohamedAliRashad commented 3 years ago

The docstring does not match the function's parameters (it documents alpha and beta, but the signature takes delta), and asymmetric_focal_loss takes the wrong parameters:

```python
def unified_focal_loss(weight=0.5, delta=0.6, gamma=0.2):
    """
    :param weight: represents lambda parameter and controls weight given to Asymmetric Focal Tversky loss 
                   and Asymmetric Focal loss
    :param alpha: controls weight given to each class
    :param beta: controls relative weight of false positives and false negatives. Beta > 0.5 penalizes 
                  false negatives more than false positives.

    :param gamma: focal parameter controls the degree of background suppression and foreground enhancement
    """
    def loss_function(y_true,y_pred):
      # Obtain Asymmetric Focal Tversky loss
      asymmetric_ftl = asymmetric_focal_tversky_loss(delta=delta, gamma=gamma)(y_true,y_pred)
      # Obtain Asymmetric Focal loss
      asymmetric_fl = asymmetric_focal_loss(delta=delta, gamma=gamma)(y_true,y_pred)
      # return weighted sum of Asymmetric Focal loss and Asymmetric Focal Tversky loss
      if weight is not None:
        return (weight * asymmetric_ftl) + ((1-weight) * asymmetric_fl)
      else:
        return asymmetric_ftl + asymmetric_fl

    return loss_function
```

mlyg commented 3 years ago

Thank you very much for spotting that, MohamedAliRashad!

Indeed, I forgot to update the asymmetric Focal loss when I updated this repository. The corrected version should be up now.
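For illustration, here is a minimal, framework-free sketch of a consistent set of functions. It assumes a binary (background/foreground) problem where `y_true` is a flat list of 0/1 foreground labels and `y_pred` a flat list of foreground probabilities; the repository itself operates on Keras tensors. The component formulas follow the asymmetric losses from the Unified Focal loss paper as understood here, and `EPS` is an assumed clipping constant, so treat this as a sketch rather than the repository's corrected code. The point it shows: every documented parameter matches the signature, and delta/gamma are forwarded to component losses that actually accept them.

```python
import math

# Clipping/smoothing constant: an assumed value, playing the role of K.epsilon().
EPS = 1e-7


def asymmetric_focal_loss(delta, gamma):
    """Binary asymmetric Focal loss (illustrative sketch).

    :param delta: weight on the foreground term; the background term gets 1 - delta
    :param gamma: focal exponent, applied only to the background term
    """
    def loss_function(y_true, y_pred):
        total = 0.0
        for y, p in zip(y_true, y_pred):
            p = min(max(p, EPS), 1.0 - EPS)  # keep log() finite
            # Background: cross entropy scaled by (1 - p_background) ** gamma;
            # since p_background = 1 - p, that factor is p ** gamma.
            back_ce = -(1 - y) * math.log(1.0 - p)
            back = (1.0 - delta) * (p ** gamma) * back_ce
            # Foreground: plain cross entropy, no focal suppression.
            fore = delta * -y * math.log(p)
            total += back + fore
        return total / len(y_true)
    return loss_function


def asymmetric_focal_tversky_loss(delta, gamma):
    """Binary asymmetric Focal Tversky loss (illustrative sketch).

    :param delta: Tversky weight on false negatives; false positives get 1 - delta
    :param gamma: focal parameter, applied only to the foreground term (assumes 0 <= gamma < 1)
    """
    def tversky_index(labels, probs):
        tp = sum(t * q for t, q in zip(labels, probs))
        fn = sum(t * (1.0 - q) for t, q in zip(labels, probs))
        fp = sum((1 - t) * q for t, q in zip(labels, probs))
        return (tp + EPS) / (tp + delta * fn + (1.0 - delta) * fp + EPS)

    def loss_function(y_true, y_pred):
        p = [min(max(q, EPS), 1.0 - EPS) for q in y_pred]
        # Background channel uses inverted labels and probabilities.
        ti_back = tversky_index([1 - t for t in y_true], [1.0 - q for q in p])
        ti_fore = tversky_index(y_true, p)
        back = 1.0 - ti_back                     # background: plain Tversky loss
        fore = (1.0 - ti_fore) ** (1.0 - gamma)  # foreground: focal enhancement
        return (back + fore) / 2.0
    return loss_function


def unified_focal_loss(weight=0.5, delta=0.6, gamma=0.2):
    """
    :param weight: lambda; weight on the asymmetric Focal Tversky loss, with 1 - weight
                   on the asymmetric Focal loss (None gives an unweighted sum)
    :param delta: class weighting, forwarded to both component losses
    :param gamma: focal parameter, forwarded to both component losses
    """
    def loss_function(y_true, y_pred):
        asymmetric_ftl = asymmetric_focal_tversky_loss(delta=delta, gamma=gamma)(y_true, y_pred)
        asymmetric_fl = asymmetric_focal_loss(delta=delta, gamma=gamma)(y_true, y_pred)
        if weight is not None:
            return weight * asymmetric_ftl + (1 - weight) * asymmetric_fl
        return asymmetric_ftl + asymmetric_fl
    return loss_function
```

Calling `unified_focal_loss()(y_true, y_pred)` on such lists returns a single float, and predictions closer to the labels should generally give a lower value.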