Let's assume a 2D multi-class segmentation problem where batches are of size [b, h, w, 1] and y_pred is of size [b, h, w, c], where c is the number of classes.
Now suppose you only have labels for some of the pixels. Call this mask w: it has size [b, h, w] and values in {0, 1}, acting as an indicator of whether a pixel's label is absent (0) or present (1).
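A minimal NumPy sketch of how such a mask could be built from a sparse label map. It assumes, purely for illustration, that unlabeled pixels are marked with -1; the names `IGNORE` and `make_targets` are hypothetical and not part of any library.

```python
import numpy as np

IGNORE = -1  # hypothetical marker for "no label at this pixel"

def make_targets(labels, num_classes):
    """Turn a sparse label map [b, h, w] into one-hot targets [b, h, w, c]
    and a {0, 1} mask of shape [b, h, w]."""
    mask = (labels != IGNORE).astype(np.float32)
    # Replace unlabeled pixels by class 0 so the indexing below is valid;
    # their one-hot rows are then zeroed out with the mask.
    safe = np.where(labels == IGNORE, 0, labels)
    one_hot = np.eye(num_classes, dtype=np.float32)[safe]
    one_hot *= mask[..., None]
    return one_hot, mask

labels = np.array([[[0, 2, IGNORE],
                    [1, IGNORE, 2]]])   # b=1, h=2, w=3
y_true, w = make_targets(labels, num_classes=3)
print(y_true.shape, w.shape)  # (1, 2, 3, 3) (1, 2, 3)
```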
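tf.keras computes weighted losses and metrics as loss() * w, i.e. it multiplies the per-element output of the loss by the weights before reducing it to a scalar.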
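To make loss() * w concrete, here is a NumPy sketch with made-up per-pixel loss values (the numbers are illustrative only). It shows the element-wise product and two ways of reducing it: dividing by the number of labeled pixels (a true masked mean) or by the total number of pixels. Which normalisation a given Keras version and reduction mode applies is worth checking; this sketch does not claim to reproduce Keras internals.

```python
import numpy as np

# Hypothetical per-pixel losses, shape [b, h, w] = [1, 2, 2]
per_pixel_loss = np.array([[[0.5, 1.0],
                            [2.0, 4.0]]], dtype=np.float32)
# Mask w: only the two pixels in the first row are labeled
w = np.array([[[1, 1],
               [0, 0]]], dtype=np.float32)

weighted = per_pixel_loss * w            # element-wise, shape stays [1, 2, 2]

# Mean over labeled pixels only: (0.5 + 1.0) / 2
masked_mean = weighted.sum() / w.sum()
# Mean over all pixels: unlabeled pixels add 0 but still count, 1.5 / 4
all_pixels_mean = weighted.sum() / weighted.size

print(masked_mean, all_pixels_mean)
```

In both cases the unlabeled pixels contribute nothing to the numerator; the two reductions differ only by the fraction of labeled pixels in the batch.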
For tf.keras to take advantage of this weight, you have to remove the K.mean() from the end of the categorical loss, so that the loss returns one value per pixel rather than a single number. The final line should be just:
```python
# Sum over the class axis only; no mean, so one loss value is kept per pixel
return K.sum(loss, axis=-1)
```
This results in an output of shape [b, h, w], the same shape as w, so the two can be multiplied element-wise.
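Why dropping K.mean() matters can be seen with a NumPy mirror of a categorical crossentropy loss. This is a sketch, not the Keras code itself; `pixelwise_cce` and `mean_cce` are hypothetical names, and `wd` is used for the image width so it does not clash with the mask w. The summed version keeps one value per pixel and lines up with w; the averaged version is already a scalar, so multiplying it by w can no longer select individual pixels.

```python
import numpy as np

EPS = 1e-7  # clip probabilities to avoid log(0)

def pixelwise_cce(y_true, y_pred):
    """Categorical crossentropy per pixel: [b, h, w, c] -> [b, h, w]."""
    y_pred = np.clip(y_pred, EPS, 1.0 - EPS)
    loss = -y_true * np.log(y_pred)
    return np.sum(loss, axis=-1)          # sum over classes only

def mean_cce(y_true, y_pred):
    """Same loss followed by a mean: collapses to a single scalar."""
    return np.mean(pixelwise_cce(y_true, y_pred))

b, h, wd, c = 2, 4, 5, 3
rng = np.random.default_rng(0)
logits = rng.normal(size=(b, h, wd, c))
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax
y_true = np.eye(c)[rng.integers(0, c, size=(b, h, wd))]                # one-hot
w = (rng.random((b, h, wd)) < 0.3).astype(np.float64)                   # sparse labels

per_pixel = pixelwise_cce(y_true, y_pred)
print(per_pixel.shape, w.shape)           # (2, 4, 5) (2, 4, 5)
print(np.ndim(mean_cce(y_true, y_pred)))  # 0: nothing left to weight per pixel

# Masked mean over labeled pixels; max() guards against a batch with no labels
masked_loss = (per_pixel * w).sum() / max(w.sum(), 1.0)
```

In tf.keras, w would then be supplied as per-pixel sample weights (for example through the `sample_weight` argument of `model.fit`); whether a given version accepts a [b, h, w] weight array directly is worth verifying.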