artemmavrin / focal-loss

TensorFlow implementation of focal loss
https://focal-loss.readthedocs.io
Apache License 2.0

Sparse tensor input #2

Closed: bogdan-s closed this issue 3 years ago

bogdan-s commented 4 years ago

Hello, I'm trying to replace tf.losses.SparseCategoricalCrossentropy with your loss, but I think it doesn't accept sparse inputs. The model is a simple U-Net segmentation model with a softmax output. Can you point me in the right direction on how to solve this? Thank you

artemmavrin commented 4 years ago

BinaryFocalLoss is for binary tasks only, whereas tf.losses.SparseCategoricalCrossentropy is for multiclass tasks. They make different assumptions about the shape and meaning of the labels and predictions.
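Roughly, and just as an untested illustration of the shape difference (using this package's BinaryFocalLoss and TensorFlow's built-in loss):

import tensorflow as tf
from focal_loss import BinaryFocalLoss

# Binary task: labels are 0/1, predictions are one probability per example.
binary_loss = BinaryFocalLoss(gamma=2)
binary_loss([0, 1, 1], [0.1, 0.9, 0.8])

# Multiclass task: labels are integer class indices, predictions are a
# probability vector over the classes for each example.
multiclass_loss = tf.losses.SparseCategoricalCrossentropy()
multiclass_loss([0, 2], [[0.7, 0.2, 0.1], [0.1, 0.2, 0.7]])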

If you want to adapt focal loss to the multiclass setting, you can start with something like the following (not tested, and not including features like label smoothing or class weighting):

import tensorflow as tf


class SparseCategoricalFocalLoss(tf.keras.losses.Loss):
    def __init__(self, gamma: float, from_logits: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.from_logits = from_logits

    def get_config(self):
        # Include gamma and from_logits so the loss can be re-created
        # from its config (e.g., when saving/loading a model).
        config = super().get_config()
        config.update(gamma=self.gamma, from_logits=self.from_logits)
        return config

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, dtype=tf.int32)

        # Ordinary sparse categorical cross-entropy as the base loss.
        base_loss = tf.keras.backend.sparse_categorical_crossentropy(
            target=y_true, output=y_pred, from_logits=self.from_logits)

        # Recover class probabilities if we were given logits.
        if self.from_logits:
            probs = tf.nn.softmax(y_pred, axis=-1)
        else:
            probs = y_pred

        # Pick out the predicted probability of the true class for each
        # example; this indexing assumes y_true has shape (batch,).
        batch_size = tf.shape(y_true)[0]
        indices = tf.stack([tf.range(batch_size), y_true], axis=1)
        probs = tf.gather_nd(probs, indices)

        # Focal modulation: down-weight well-classified examples.
        focal_modulation = (1 - probs) ** self.gamma

        return focal_modulation * base_loss
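Note that the gather_nd indexing in call assumes y_true has shape (batch,). For pixel-wise segmentation labels like a U-Net's, you would need to flatten the labels and predictions first; something like this hypothetical helper (my naming, untested):

import tensorflow as tf

def flatten_for_sparse_loss(y_true, y_pred):
    # Hypothetical helper: collapse, e.g., (batch, height, width) labels
    # and (batch, height, width, classes) predictions to the rank-1 and
    # rank-2 shapes that the indexing in call expects.
    num_classes = tf.shape(y_pred)[-1]
    return tf.reshape(y_true, [-1]), tf.reshape(y_pred, [-1, num_classes])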

This should be interchangeable with tf.keras.losses.SparseCategoricalCrossentropy (up to bugs!). I didn't test it extensively, but a quick check looks right:

>>> # From probabilities (e.g., softmax output)
>>> loss = SparseCategoricalFocalLoss(gamma=2)
>>> y_true = [0, 1, 2]
>>> y_pred = [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]
>>> loss(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.008925741>
>>> # From logits
>>> loss = SparseCategoricalFocalLoss(gamma=2, from_logits=True)
>>> y_true = [0, 1, 2]
>>> y_pred = [[1.0, -1.0, -1.0], [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]]
>>> loss(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.010869335>
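One more sanity check worth running against the sketch above: with gamma=0 the focal modulation is identically 1, so the loss should match ordinary cross-entropy exactly:

import numpy as np
import tensorflow as tf

y_true = [0, 1, 2]
y_pred = [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]

# With gamma=0, (1 - p) ** 0 == 1, so the focal loss reduces to plain
# sparse categorical cross-entropy (both use the same mean reduction
# by default).
focal = SparseCategoricalFocalLoss(gamma=0)
ce = tf.keras.losses.SparseCategoricalCrossentropy()
assert np.isclose(focal(y_true, y_pred).numpy(),
                  ce(y_true, y_pred).numpy())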
artemmavrin commented 4 years ago

focal_loss.SparseCategoricalFocalLoss is now available. It should work as a replacement for tf.keras.losses.SparseCategoricalCrossentropy.
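For example, it should drop into model.compile the same way (a minimal sketch; the toy model here is just for illustration):

import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

# Toy model with a softmax output over 10 classes, for illustration only.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='softmax', input_shape=(4,)),
])

# Drop-in replacement for tf.keras.losses.SparseCategoricalCrossentropy.
model.compile(optimizer='adam', loss=SparseCategoricalFocalLoss(gamma=2))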

bogdan-s commented 4 years ago

Thank you. I'll look into it.