umbertogriffo / focal-loss-keras

Binary and Categorical Focal loss implementation in Keras.

A cleaner pattern to make custom_objects simpler #7

Open isaacgerg opened 5 years ago

isaacgerg commented 5 years ago
import tensorflow as tf


class categorical_focal_loss:
    '''
    Softmax version of focal loss.

           m
      FL = sum  -alpha * (1 - p_o,c)^gamma * y_o,c * log(p_o,c)
          c=1

      where m = number of classes, c = class and o = observation

    Parameters:
      alpha -- the same as the weighting factor in balanced cross entropy
      gamma -- focusing parameter for modulating factor (1-p)

    Default value:
      gamma -- 2.0 as mentioned in the paper
      alpha -- 0.25 as mentioned in the paper

    References:
        Official paper: https://arxiv.org/pdf/1708.02002.pdf
        https://www.tensorflow.org/api_docs/python/tf/keras/backend/categorical_crossentropy

    Usage:
     model.compile(loss=[categorical_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer="adam")
    '''
    def __init__(self, gamma=2., alpha=.25):
        self._gamma = gamma
        self._alpha = alpha
        self.__name__ = 'categorical_focal_loss'

    def __int_shape(self, x):
        # Helper (not used by __call__): static shape via the Keras backend.
        return tf.keras.backend.int_shape(x)

    def __call__(self, y_true, y_pred):
        '''
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred: A tensor resulting from a softmax
        :return: Output tensor.
        '''

        # Scale predictions so that the class probas of each sample sum to 1
        y_pred /= tf.keras.backend.sum(y_pred, axis=-1, keepdims=True)

        # Clip the prediction value to prevent NaN's and Inf's
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.keras.backend.clip(y_pred, epsilon, 1. - epsilon)

        # Calculate Cross Entropy
        cross_entropy = -y_true * tf.keras.backend.log(y_pred)

        # Calculate Focal Loss
        loss = self._alpha * tf.keras.backend.pow(1 - y_pred, self._gamma) * cross_entropy

        # Sum the per-class losses to get one loss value per sample
        return tf.keras.backend.sum(loss, axis=-1)

With this pattern, I don't need dill when using load_model.
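For a quick sanity check of the math outside Keras, here is a small NumPy re-implementation of the same per-sample loss (the helper name `categorical_focal_loss_np` is made up for illustration; gamma and alpha defaults are taken from the docstring above):

```python
import numpy as np

def categorical_focal_loss_np(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    # Mirror the Keras version: renormalize, clip, then weight the CE terms.
    y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    cross_entropy = -y_true * np.log(y_pred)
    loss = alpha * (1.0 - y_pred) ** gamma * cross_entropy
    return loss.sum(axis=-1)  # one loss value per sample

y_true = np.array([[0.0, 1.0, 0.0]])
confident = categorical_focal_loss_np(y_true, np.array([[0.05, 0.90, 0.05]]))
wrong = categorical_focal_loss_np(y_true, np.array([[0.90, 0.05, 0.05]]))
# The (1 - p)^gamma factor sharply down-weights the already-correct prediction.
```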

umbertogriffo commented 4 years ago

Hi, @isaacgerg thanks for the suggestion. Could you open a PR?

isaacgerg commented 4 years ago

Unfortunately, I cannot due to my network settings.

dmonkoff commented 3 years ago

Hey, I've seen this implementation in a lot of projects and I feel that it is not right to set the alpha parameter that way. Alpha as a scalar makes sense in the binary case, since it's the weight of the positive samples, i.e. loss = -gt * alpha * (1 - pr)^gamma * log(pr) - (1 - gt) * (1 - alpha) * pr^gamma * log(1 - pr). Positive samples are weighted by alpha, negative samples are weighted by 1 - alpha, all good. If we look at multiple classes, there is no distinguished 'negative' class; you output a softmax vector of class probabilities. So, if we follow the same logic as in the binary case, each class should be weighted separately. If you use one value, you are basically just scaling the loss by that factor, and no class weighting is done. I might be missing something here, so please correct me if I'm wrong.

Edit: I've checked the function more thoroughly and you can pass alpha as a vector of per-class weights, so that was my misunderstanding. I would still add a better description of alpha to the docstring, and change the default alpha to something more appropriate, for example alpha=1, so all classes are weighted equally and no scaling is done. alpha=0.25 doesn't make much sense for the multiclass case.
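The scalar-vs-vector distinction can be shown numerically: with a scalar alpha every per-sample loss is rescaled by the same factor, while a per-class alpha vector actually reweights the classes. A hedged NumPy sketch (the helper name `focal_loss_np` is assumed for illustration):

```python
import numpy as np

def focal_loss_np(y_true, y_pred, gamma=2.0, alpha=1.0, eps=1e-7):
    # alpha may be a scalar or a vector with one weight per class.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    ce = -y_true * np.log(y_pred)
    return (np.asarray(alpha) * (1.0 - y_pred) ** gamma * ce).sum(axis=-1)

y_true = np.eye(3)                    # one sample per class
y_pred = np.full((3, 3), 1.0 / 3.0)   # identical uniform predictions

base = focal_loss_np(y_true, y_pred, alpha=1.0)
scalar = focal_loss_np(y_true, y_pred, alpha=0.25)             # uniform rescale
vector = focal_loss_np(y_true, y_pred, alpha=[1.0, 2.0, 0.5])  # per-class weights
```

Because `y_true` is one-hot, the alpha vector broadcasts so that each sample's loss is multiplied by the weight of its true class, which is the behavior one would want for class imbalance.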

umbertogriffo commented 3 years ago

Hey @dmonkoff, you are right; this is an older implementation. The issue has already been solved here. As I've already written there, in the latest version you need to specify α as an array whose size is consistent with the number of categories, each entry representing the weight of the corresponding category.

I'll take into account your suggestions, thank you very much!