francescomilano172 opened this issue 3 years ago
Wow, thanks for the investigation. The reduction should then probably be fixed.
Just to add, some other things that I noticed: the loss can be constructed with `from_logits=True`, passing not the softmax output but the logits directly (the cross-entropy is then computed internally via `log_softmax`, which is more numerically stable). This can be done for both options.
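For concreteness, a minimal sketch of the two equivalent ways of constructing the loss (the tensors here are made up purely for illustration):

```python
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
labels = tf.constant([0, 1])

# Default: the loss expects probabilities, e.g. a softmax output.
ce_probs = tf.keras.losses.SparseCategoricalCrossentropy()
print(ce_probs(labels, tf.nn.softmax(logits)).numpy())

# With from_logits=True the raw logits are passed directly and the
# log-softmax is computed internally (numerically more stable).
ce_logits = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(ce_logits(labels, logits).numpy())  # same value, up to numerics
```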
For reference, I have found two main ways in our code in which the masked cross-entropy loss is used, and wanted to point out some subtle issues that may cause it not to evaluate to what we would like, due to the way in which `tf.keras.losses.SparseCategoricalCrossentropy` is defined. Let's consider for example a batch with 8 image samples of size 200 x 300, in which there are in total 100000 non-valid pixels out of the 8 * 200 * 300 = 480000 total pixels.

1. `ignorant_cross_entropy_loss`: the `sample_weight` argument with the boolean mask is passed when calling the instance of the loss here: https://github.com/ethz-asl/background_foreground_segmentation/blob/c9bb67a05ce03606861c8fc0e02e304d2b8a2e96/src/bfseg/utils/losses.py#L46. However, by default the class `tf.keras.losses.SparseCategoricalCrossentropy` has the `reduction` argument set to `reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` (cf. https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy#args). In practice, this means that the loss will take the sum of the log terms for the 380000 valid pixels and divide it by 480000, the total number of pixels in the batch. Since the number of masked pixels in general varies from one batch to the other, having a fixed scaling factor (1 / 480000 in the example) is probably not optimal. The alternative would be to set `reduction=tf.keras.losses.Reduction.SUM` in the constructor of `SparseCategoricalCrossentropy`, so that no scaling factor is used; then one can manually divide by the number of valid pixels (see the test sketch at the end of this comment).
2. `SparseCategoricalCrossentropy` (with `reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` and without `sample_weight`) is applied directly on the subset of the valid pixels, which automatically divides by the number of valid pixels, as done here: https://github.com/ethz-asl/background_foreground_segmentation/blob/e4c04bc7f5ab101770656c07b99b064caaa8c824/src/train_binary_segmodel_base.py#L264-L266

Small unit test to double-check the things above:
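A minimal sketch of what such a test could look like, with made-up labels and logits matching the shapes from the example above (`loss_2` below corresponds to the manual division proposed in point 1, `loss_3` to the valid-subset variant of point 2):

```python
import numpy as np
import tensorflow as tf

# Example from above: 8 samples of 200 x 300 pixels, of which 100000 are
# non-valid, out of 8 * 200 * 300 = 480000 total pixels.
num_classes = 2
labels = np.random.randint(0, num_classes, size=(8, 200, 300))
logits = np.random.randn(8, 200, 300, num_classes).astype(np.float32)
mask = np.ones(8 * 200 * 300, dtype=bool)
mask[:100000] = False
np.random.shuffle(mask)
mask = mask.reshape(8, 200, 300)

# 1a. Default SUM_OVER_BATCH_SIZE reduction with sample_weight: sums the
#     log terms of the 380000 valid pixels but divides by all 480000.
ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_1 = ce(labels, logits, sample_weight=mask.astype(np.float32))

# 1b. SUM reduction with sample_weight, manually divided by the number of
#     valid pixels.
ce_sum = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
loss_2 = ce_sum(labels, logits,
                sample_weight=mask.astype(np.float32)) / float(np.sum(mask))

# 2. Default reduction applied directly on the subset of valid pixels:
#    automatically divides by the number of valid pixels, matching loss_2.
loss_3 = ce(labels[mask], logits[mask])

# loss_1 is off by the fixed factor 380000 / 480000.
np.testing.assert_allclose(loss_1.numpy() * mask.size / np.sum(mask),
                           loss_2.numpy(), rtol=1e-4)
np.testing.assert_allclose(loss_2.numpy(), loss_3.numpy(), rtol=1e-4)
print(loss_1.numpy(), loss_2.numpy(), loss_3.numpy())
```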