ShannonAI / dice_loss_for_NLP

This repo contains the code for the ACL 2020 paper `Dice Loss for Data-imbalanced NLP Tasks`
Apache License 2.0

Masking #18

Open david-waterworth opened 3 years ago

david-waterworth commented 3 years ago

I tried replacing BCE loss with Dice loss and my model wouldn't converge. When I looked closer I noticed that while the input and target are flattened, the mask isn't. So if you pass a mask with the same shape as the (un-flattened) target, the multiplication flat_input * mask broadcasts against the mask and effectively un-flattens flat_input:

def _binary_class(self, input, target, mask=None):
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()
    flat_input = torch.sigmoid(flat_input) if self.with_logits else flat_input

    if mask is not None:
        mask = mask.float()
        # BUG: mask still has its original (un-flattened) shape here,
        # so this multiplication broadcasts instead of masking element-wise
        flat_input = flat_input * mask
        flat_target = flat_target * mask
    else:
        mask = torch.ones_like(target)
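
To see the effect, here is a minimal sketch of the broadcasting behaviour, assuming (batch, 1)-shaped inputs and a mask of the same shape (the exact shapes in the original model may differ):

import torch

batch = 4
input = torch.randn(batch, 1)                     # logits, shape (batch, 1)
target = torch.randint(0, 2, (batch, 1)).float()  # labels, shape (batch, 1)
mask = torch.ones(batch, 1)                       # same shape as target, NOT flattened

flat_input = torch.sigmoid(input.view(-1))        # shape (batch,)
flat_target = target.view(-1)                     # shape (batch,)

print((flat_input * mask).shape)           # torch.Size([4, 4]) -- broadcast, no longer flat
print((flat_input * mask.view(-1)).shape)  # torch.Size([4])    -- element-wise, as intended

In a scenario like this the product silently becomes a (batch, batch) matrix, so the loss is computed over garbage instead of raising an error.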

I made the following change and my model started converging immediately

    if mask is not None:
        mask = mask.float()
        # flatten the mask so it lines up with flat_input / flat_target
        flat_input = flat_input * mask.view(-1)
        flat_target = flat_target * mask.view(-1)
    else:
        mask = torch.ones_like(target)

Although I think a better fix is to actually apply the mask, i.e. select only the unmasked positions rather than zeroing out the masked inputs/targets:

        mask = mask.view(-1).bool()      # boolean indexing requires a bool mask
        flat_input = flat_input[mask]    # keep only the unmasked positions
        flat_target = flat_target[mask]
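
Putting it together, a sketch of how the whole method might look with boolean-mask indexing. The dice computation at the end is a generic soft-dice form with a hypothetical self.smooth term, not necessarily the exact formulation this repo uses; the else branch that built an all-ones mask also becomes unnecessary if the mask is only used here:

def _binary_class(self, input, target, mask=None):
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()
    flat_input = torch.sigmoid(flat_input) if self.with_logits else flat_input

    if mask is not None:
        # drop masked positions entirely instead of multiplying them by zero
        mask = mask.view(-1).bool()
        flat_input = flat_input[mask]
        flat_target = flat_target[mask]

    # generic soft dice over the remaining positions (self.smooth is assumed)
    intersection = torch.sum(flat_input * flat_target)
    return 1 - (2 * intersection + self.smooth) / (flat_input.sum() + flat_target.sum() + self.smooth)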