ethz-asl / background_foreground_segmentation


Use of masked cross-entropy loss #25

Open francescomilano172 opened 3 years ago

francescomilano172 commented 3 years ago

For reference, I have found two main ways in which masked cross-entropy loss is used in our code, and wanted to point out some subtle issues that may cause it not to evaluate to what we would like, due to the way tf.keras.losses.SparseCategoricalCrossentropy is defined. Consider, for example, a batch of 8 image samples of size 200 x 300 containing in total 100000 non-valid pixels out of the 8 * 200 * 300 = 480000 total pixels.
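As a quick sanity check of these numbers (a sketch based on the hypothetical batch above, not code from the repo): if a boolean mask is passed as `sample_weight` with the default `SUM_OVER_BATCH_SIZE` reduction, the masked sum of per-pixel losses is divided by the total pixel count (480000) rather than by the number of valid pixels (380000), so the resulting "mean" is off by a constant factor:

```python
# Hypothetical numbers from the example above: 8 images of 200 x 300 pixels,
# 100000 of which are non-valid.
total_pixels = 8 * 200 * 300          # 480000
valid_pixels = total_pixels - 100000  # 380000

# With `SUM_OVER_BATCH_SIZE` and a boolean mask passed as `sample_weight`,
# the masked sum is divided by `total_pixels`; the desired mean over the
# valid pixels would divide by `valid_pixels` instead.
scale_factor = valid_pixels / total_pixels
print(total_pixels, valid_pixels, round(scale_factor, 4))  # 480000 380000 0.7917
```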

Small unit test to double-check the points above:

import numpy as np
import tensorflow as tf
from tensorflow.keras.losses import SparseCategoricalCrossentropy

def almost_equal(a, b):
  return tf.math.abs(a - b) < 1e-5

# Quick test:
# - Batch size = 2 containing 2 identical images of shape 2 x 3
# - 2 classes.
# - The network predicts the same unnormalized logits for the 2 images:
#   - Class 0:
#       0.5    0.3    0.2
#       0.4    0.7    0.7
#   - Class 1:
#       0.9    0.7    0.6
#       0.5    0.2    0.5
# - Ground truth labels (same for both images):
#         1      1      1
#         1      0      0
# - Mask:
#   - Sample 0:
#      True   False  True
#      True    True  True
#   - Sample 1:
#      True    True False
#      True   False  True
pred_y = tf.constant([[[[0.5, 0.9], [0.3, 0.7], [0.2, 0.6]],
                       [[0.4, 0.5], [0.7, 0.2], [0.7, 0.5]]]])
pred_y = tf.repeat(pred_y, 2, axis=0)
labels = tf.constant([[[1, 1, 1], [1, 0, 0]]])
labels = tf.repeat(labels, 2, axis=0)
# - Mask out pixel (0, 1) in sample 0 and pixels (0, 2) and (1, 1) in sample 1.
mask = np.ones(labels.shape, dtype=bool)
mask[0, 0, 1] = False
mask[1, 0, 2] = False
mask[1, 1, 1] = False
mask = tf.constant(mask)
# - Define losses.
loss_cse = SparseCategoricalCrossentropy(
    from_logits=False, reduction=tf.keras.losses.Reduction.SUM)
loss_cse_over_batches = SparseCategoricalCrossentropy(
    from_logits=False, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
loss_cse_with_logits = SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
loss_cse_with_logits_over_batches = SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
# - Apply softmax on predictions.
pred_y_sm = tf.nn.softmax(pred_y)

# Unit tests.
log = tf.math.log
total_num_pixels = labels.shape[0] * labels.shape[1] * labels.shape[2]
num_valid_pixels = tf.math.count_nonzero(mask)
assert (total_num_pixels == 12)
assert (num_valid_pixels == 9)
# - First test: ignore masks -> Consider all the pixels.
#   The loss should be equal to the following (the "2 *" comes from the two
#   samples being equal):
expected_loss = -2 * (log(pred_y_sm[0, 0, 0, 1]) + log(pred_y_sm[0, 0, 1, 1]) +
                      log(pred_y_sm[0, 0, 2, 1]) + log(pred_y_sm[0, 1, 0, 1]) +
                      log(pred_y_sm[0, 1, 1, 0]) + log(pred_y_sm[0, 1, 2, 0]))
assert (almost_equal(loss_cse_with_logits(labels, pred_y), expected_loss))
assert (almost_equal(loss_cse(labels, pred_y_sm), expected_loss))
assert (almost_equal(loss_cse_with_logits_over_batches(labels, pred_y),
                     expected_loss / total_num_pixels))
assert (almost_equal(loss_cse_over_batches(labels, pred_y_sm),
                     expected_loss / total_num_pixels))
# - Second test: with masks.
#   The loss should be equal to the same loss as above, but excluding pixels
#   (0, 1) in sample 0 and (0, 2) and (1, 1) in sample 1, which have labels 1,
#   1, and 0, respectively.
expected_loss_with_mask = expected_loss + (log(pred_y_sm[0, 0, 1, 1]) + log(
    pred_y_sm[1, 0, 2, 1]) + log(pred_y_sm[1, 1, 1, 0]))
assert (almost_equal(loss_cse_with_logits(labels, pred_y, sample_weight=mask),
                     expected_loss_with_mask))
assert (almost_equal(loss_cse(labels, pred_y_sm, sample_weight=mask),
                     expected_loss_with_mask))
# NOTE: also with the mask, the loss with `SUM_OVER_BATCH_SIZE` divides by the
# total number of pixels.
assert (almost_equal(
    loss_cse_with_logits_over_batches(labels, pred_y, sample_weight=mask),
    expected_loss_with_mask / total_num_pixels))
assert (almost_equal(
    loss_cse_over_batches(labels, pred_y_sm, sample_weight=mask),
    expected_loss_with_mask / total_num_pixels))
# - Last test: use `tf.boolean_mask`.
pred_y_masked = tf.boolean_mask(pred_y, mask)
pred_y_sm_masked = tf.boolean_mask(pred_y_sm, mask)
labels_masked = tf.boolean_mask(labels, mask)
# NOTE: if used only on the masked pixels, the standard cross-entropy loss
# (i.e., with `SUM_OVER_BATCH_SIZE` and without `sample_weight`) divides by the
# number of valid pixels.
num_valid_pixels = tf.cast(num_valid_pixels, dtype=tf.float32)
assert (almost_equal(
    loss_cse_with_logits_over_batches(labels_masked, pred_y_masked),
    expected_loss_with_mask / num_valid_pixels))
assert (almost_equal(loss_cse_over_batches(labels_masked, pred_y_sm_masked),
                     expected_loss_with_mask / num_valid_pixels))
hermannsblum commented 3 years ago

Wow, thanks for the investigation. The reduction should then probably be fixed.
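One way to fix the reduction while keeping `sample_weight` (a sketch under my assumptions, not code from the repo) is to rescale the `SUM_OVER_BATCH_SIZE` result by total pixels / valid pixels, which recovers the mean over the valid pixels only:

```python
import tensorflow as tf

def rescaled_masked_ce(labels, pred_y, mask):
  # `SUM_OVER_BATCH_SIZE` divides the masked sum by the *total* number of
  # pixels; multiplying by total / valid recovers the mean over valid pixels.
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
  loss = loss_fn(labels, pred_y, sample_weight=tf.cast(mask, tf.float32))
  total = tf.cast(tf.size(mask), tf.float32)
  valid = tf.cast(tf.math.count_nonzero(mask), tf.float32)
  return loss * total / valid
```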

Just to add, some other things that I noticed: