bogdan-s closed this issue 3 years ago.
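`BinaryFocalLoss` is for binary tasks only, whereas `tf.losses.SparseCategoricalCrossentropy` is for multiclass tasks. They make different assumptions about the structure of the labels and predictions: `BinaryFocalLoss` expects 0/1 labels with a single predicted probability (or logit) per example, while `SparseCategoricalCrossentropy` expects integer class indices with one predicted probability (or logit) per class.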
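To make the difference concrete, here is a plain-Python sketch of how the probability assigned to the true class (often written p_t) is read off in each setting. The helper names are hypothetical and not part of either library. In the binary case the prediction is a single probability of class 1; in the sparse multiclass case it is a vector of class probabilities indexed by an integer label.

```python
from typing import Sequence


def binary_p_t(y: int, p: float) -> float:
    """Binary case: y is 0 or 1, p is the predicted probability of class 1."""
    return p if y == 1 else 1.0 - p


def sparse_categorical_p_t(y: int, probs: Sequence[float]) -> float:
    """Multiclass case: y is a class index, probs holds one probability per class."""
    return probs[y]


print(binary_p_t(0, 0.25))  # 0.75: label 0, so p_t = 1 - P(class 1)
print(sparse_categorical_p_t(2, [0.1, 0.1, 0.8]))  # 0.8
```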
If you want to adapt focal loss to the multiclass setting, you can start with something like the following (not tested, and not including features like label smoothing or class weighting):
```python
import tensorflow as tf


class SparseCategoricalFocalLoss(tf.keras.losses.Loss):
    def __init__(self, gamma: float, from_logits: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.from_logits = from_logits

    def get_config(self):
        config = super().get_config()
        config.update(gamma=self.gamma, from_logits=self.from_logits)
        return config

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.dtypes.cast(y_true, dtype=tf.dtypes.int32)

        # Ordinary sparse categorical cross-entropy, one value per example
        base_loss = tf.keras.backend.sparse_categorical_crossentropy(
            target=y_true, output=y_pred, from_logits=self.from_logits)

        # Predicted probability of every class
        if self.from_logits:
            probs = tf.nn.softmax(y_pred, axis=-1)
        else:
            probs = y_pred

        # Probability of the true class for each example.
        # Note: this indexing assumes y_true has shape (batch_size,).
        batch_size = tf.shape(y_true)[0]
        indices = tf.stack([tf.range(batch_size), y_true], axis=1)
        probs = tf.gather_nd(probs, indices)

        # Down-weight well-classified examples by (1 - p_t)^gamma
        focal_modulation = (1 - probs) ** self.gamma
        return focal_modulation * base_loss
```
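Per example, `call` computes the quantity below, where y is the integer label, **p** is the vector of predicted class probabilities (the softmax of the logits **z** when `from_logits=True`), and γ is `gamma`. This ignores the small epsilon clipping Keras applies to probabilities. With γ = 0 the modulating factor is 1 and the expression reduces to ordinary sparse categorical cross-entropy; larger γ shrinks the loss of examples whose true-class probability p_y is already high. Under its default reduction, the Keras `Loss` base class then averages these per-example values over the batch.

$$
\mathrm{FL}(y, \mathbf{p}) = (1 - p_y)^{\gamma}\,\bigl(-\log p_y\bigr),
\qquad
\mathbf{p} = \mathrm{softmax}(\mathbf{z}) \ \text{when from\_logits=True}
$$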
This should be interchangeable with `tf.keras.losses.SparseCategoricalCrossentropy` (up to bugs!). While I didn't test it extensively, it looks like it might work:
```python
>>> # From probabilities (e.g., softmax output)
>>> loss = SparseCategoricalFocalLoss(gamma=2)
>>> y_true = [0, 1, 2]
>>> y_pred = [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]
>>> loss(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.008925741>
>>> # From logits
>>> loss = SparseCategoricalFocalLoss(gamma=2, from_logits=True)
>>> y_true = [0, 1, 2]
>>> y_pred = [[1.0, -1.0, -1.0], [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]]
>>> loss(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.010869335>
```
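As a sanity check on the two outputs above, the same arithmetic can be redone without TensorFlow. This plain-Python sketch (helper names are hypothetical) averages (1 − p_t)^γ · (−log p_t) over the batch, mirroring the Keras default reduction, and skips the epsilon clipping Keras applies.

```python
import math


def softmax(logits):
    """Plain (not max-shifted) softmax over a list of logits; fine for small values."""
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_focal_loss(labels, rows, gamma, from_logits=False):
    """Mean of (1 - p_t)^gamma * (-log p_t) over examples with integer labels."""
    terms = []
    for y, row in zip(labels, rows):
        probs = softmax(row) if from_logits else row
        p_t = probs[y]  # probability assigned to the true class
        terms.append((1.0 - p_t) ** gamma * -math.log(p_t))
    return sum(terms) / len(terms)


y_true = [0, 1, 2]
y_prob = [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]
y_logit = [[1.0, -1.0, -1.0], [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]]

print(mean_focal_loss(y_true, y_prob, gamma=2.0))                     # ≈ 0.00893
print(mean_focal_loss(y_true, y_logit, gamma=2.0, from_logits=True))  # ≈ 0.01087
```

Setting `gamma=0.0` gives plain cross-entropy, −log 0.8 ≈ 0.223 for the probability example, which is why the focal values are so much smaller for these confidently correct predictions.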
`focal_loss.SparseCategoricalFocalLoss` is now available. It should work as a replacement for `tf.keras.losses.SparseCategoricalCrossentropy`.
Thank you. I'll look into it.
Hello, I'm trying to replace `tf.losses.SparseCategoricalCrossentropy` with your loss, but I think it doesn't accept sparse inputs. The model is a simple U-Net segmentation model with a softmax output layer. Can you point me in the right direction on how to solve this? Thank you.
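If the hand-written class above is what is being used, its gather step builds indices with `tf.stack([tf.range(batch_size), y_true], axis=1)`, which only works when `y_true` has shape `(batch_size,)`. For a U-Net with a softmax output, labels are usually `(batch, height, width)` (sometimes with a trailing axis of size 1) and predictions are `(batch, height, width, num_classes)`, so that indexing step is a likely place for it to break. Below is a minimal NumPy sketch, under those shape assumptions, of selecting the true-class probability at every pixel with `np.take_along_axis`. The function name is hypothetical, and this is not the `focal_loss` package's implementation. In the TensorFlow class, a rank-agnostic alternative to the `tf.stack`/`tf.gather_nd` lines would be something like `tf.reduce_sum(tf.one_hot(y_true, depth=tf.shape(probs)[-1]) * probs, axis=-1)` (untested here).

```python
import numpy as np


def sparse_categorical_focal_loss_nd(y_true, probs, gamma=2.0, eps=1e-7):
    """Per-element focal loss for integer labels of any shape (hypothetical helper).

    y_true: integer class indices, shape (...), e.g. (batch, height, width).
            A trailing axis of size 1, e.g. (batch, height, width, 1), is squeezed.
    probs:  class probabilities (e.g. softmax output), shape (..., num_classes).
    Returns an array with the same shape as the squeezed y_true.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    probs = np.asarray(probs, dtype=np.float64)
    if y_true.ndim == probs.ndim and y_true.shape[-1] == 1:
        y_true = y_true[..., 0]
    # Probability of the true class at every position, shape (...)
    p_t = np.take_along_axis(probs, y_true[..., None], axis=-1)[..., 0]
    p_t = np.clip(p_t, eps, 1.0 - eps)  # avoid log(0)
    return (1.0 - p_t) ** gamma * -np.log(p_t)


# Usage: one 2x2 "image" with 3 classes; each pixel gives its true class 0.8
y_true = np.array([[[0, 1], [2, 0]]])  # shape (1, 2, 2)
probs = np.full((1, 2, 2, 3), 0.1)
np.put_along_axis(probs, y_true[..., None], 0.8, axis=-1)

loss = sparse_categorical_focal_loss_nd(y_true, probs, gamma=2.0)
print(loss.shape)   # (1, 2, 2)
print(loss.mean())  # ≈ 0.00893, the same per-pixel value as the first doctest above
```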