umbertogriffo / focal-loss-keras

Binary and Categorical Focal loss implementation in Keras.

A fix so that you can weight specific pixels as in a 2D segmentation problem with only partially / weakly labeled pixels. #18

Open isaacgerg opened 3 years ago

isaacgerg commented 3 years ago

Let's assume a 2D multi-category segmentation problem where the label batches have shape [b, h, w, 1] and y_pred has shape [b, h, w, c], with c the number of classes. Now suppose you only have labels for some of the pixels. Call the corresponding mask w (not the width w in the shapes above): it has shape [b, h, w] with values in {0, 1}, indicating whether a label is absent or present for each pixel.

tf.keras computes weighted metrics as loss() * w, so the output of the loss function needs a shape that matches the mask.

To have tf.keras take advantage of this weight, you have to remove K.mean() from the end of the categorical loss. The final lines should be just:

# Sum over the class axis only, keeping one loss value per pixel
return K.sum(loss, axis=-1)

This results in an output of shape [b, h, w], the same shape as the mask w, so the two can be multiplied element-wise.
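
To illustrate the shapes involved, here is a minimal NumPy sketch of the idea rather than the repository's Keras code. It assumes one-hot labels of shape [b, h, w, c] and softmax probabilities as y_pred. The function names, the gamma and alpha defaults, and the normalisation by the number of labeled pixels are illustrative assumptions, and tf.keras may reduce the weighted loss differently.

```python
import numpy as np

def categorical_focal_loss_per_pixel(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """Hypothetical helper: focal loss summed over classes, not averaged.

    y_true: one-hot labels, shape [b, h, w, c] (assumption)
    y_pred: class probabilities, shape [b, h, w, c]
    returns: per-pixel loss, shape [b, h, w]
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)      # avoid log(0)
    cross_entropy = -y_true * np.log(y_pred)
    loss = alpha * (1.0 - y_pred) ** gamma * cross_entropy
    # Sum over the class axis only, keeping one value per pixel
    return loss.sum(axis=-1)

def masked_mean(per_pixel_loss, mask):
    """Hypothetical helper: average the loss over labeled pixels only."""
    weighted = per_pixel_loss * mask              # unlabeled pixels contribute 0
    return weighted.sum() / max(mask.sum(), 1.0)

# Toy data: 2 images of 4x5 pixels with 3 classes
rng = np.random.default_rng(0)
b, h, w, c = 2, 4, 5, 3
labels = rng.integers(0, c, size=(b, h, w))
y_true = np.eye(c)[labels]                        # one-hot, [b, h, w, c]
logits = rng.normal(size=(b, h, w, c))
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
mask = (rng.random((b, h, w)) < 0.5).astype(np.float64)  # 1 = label present

per_pixel = categorical_focal_loss_per_pixel(y_true, y_pred)
print(per_pixel.shape)                            # (2, 4, 5)
print(masked_mean(per_pixel, mask))
```

Because the per-pixel loss and the mask share the shape [b, h, w], the element-wise product zeroes out every unlabeled pixel before any reduction happens.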