I tried replacing BCE loss with DICE loss and my model wouldn't converge. When I looked closer I noticed that whilst the input and target are flattened, the mask isn't. So if you pass a mask that is the same shape as the target, the multiplication flat_input * mask broadcasts and un-flattens flat_input:
def _binary_class(self, input, target, mask=None):
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()
    flat_input = torch.sigmoid(flat_input) if self.with_logits else flat_input
    if mask is not None:
        mask = mask.float()
        flat_input = flat_input * mask
        flat_target = flat_target * mask
    else:
        mask = torch.ones_like(target)
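To see the failure mode, here is a minimal reproduction of the broadcasting I believe is happening; the (N, 1) target/mask shape is my assumption for the binary case, but any mask that is not 1-D will broadcast rather than multiply element-wise:

import torch

N = 4
flat_input = torch.sigmoid(torch.randn(N))  # flattened input, shape (N,)
mask = torch.ones(N, 1)                     # mask in target shape, left un-flattened: (N, 1)

# Broadcasting (N,) against (N, 1) yields an (N, N) matrix rather than an
# element-wise product, so every sum inside the loss is silently wrong.
print((flat_input * mask).shape)            # torch.Size([4, 4])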
I made the following change and my model started converging immediately:
if mask is not None:
    mask = mask.float()
    flat_input = flat_input * mask.view(-1)
    flat_target = flat_target * mask.view(-1)
else:
    mask = torch.ones_like(target)
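As a quick sanity check on the same toy shapes (again, the (N, 1) shape is just an assumption for illustration), flattening the mask first keeps the product 1-D:

import torch

N = 4
flat_input = torch.sigmoid(torch.randn(N))
mask = torch.ones(N, 1)

# mask.view(-1) has shape (N,), so the multiplication is element-wise and stays flat.
print((flat_input * mask.view(-1)).shape)   # torch.Size([4])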
That said, I think a better fix is to actually apply the mask, i.e. keep only the unmasked elements, rather than multiply the masked inputs/targets by zero.
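Here is a minimal sketch of what I mean, assuming a boolean-style mask and using indexing to drop the masked positions before computing the Dice terms; the function name, the smooth term and the exact Dice formulation are illustrative, not the library's actual code:

import torch

def masked_dice_loss(input, target, mask=None, smooth=1.0, with_logits=True):
    # Illustrative sketch: select the unmasked positions instead of zeroing them.
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()
    if with_logits:
        flat_input = torch.sigmoid(flat_input)
    if mask is not None:
        keep = mask.view(-1).bool()
        flat_input = flat_input[keep]
        flat_target = flat_target[keep]
    intersection = (flat_input * flat_target).sum()
    return 1 - (2 * intersection + smooth) / (flat_input.sum() + flat_target.sum() + smooth)

With this formulation the masked positions never enter the sums at all, which is what I mean by actually applying the mask.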