ShannonAI / dice_loss_for_NLP

The repo contains the code of the ACL2020 paper `Dice Loss for Data-imbalanced NLP Tasks`
Apache License 2.0

The mask-related code in the Dice loss function is wrong #8

Open · nikolakopoulos opened this issue 3 years ago

nikolakopoulos commented 3 years ago

Hello,

First of all, cool work! :)

Now let me get to the point:

I found the following bug in your code:

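    if mask is not None:
        # Here is the problem: flat_input and flat_target are already (N, C), one
        # column per class (flat_target is one-hot), while mask is (N,).
        # Broadcasting aligns trailing dimensions, so this raises a shape error
        # when N != C and silently masks classes instead of tokens when N == C.
        mask = mask.float()
        flat_input = flat_input * mask
        flat_target = flat_target * mask
    else:
        mask = torch.ones_like(target)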
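A minimal reproduction of both failure modes, assuming PyTorch and small illustrative shapes (N tokens, C classes). The helper `masked_wrong` is hypothetical; it just repeats the multiplication from the snippet above:

```python
import torch


def masked_wrong(flat: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Same pattern as the buggy code: (N, C) * (N,) broadcasts the mask
    # across the last (class) dimension, not the row (token) dimension.
    return flat * mask


# Case 1: N != C, so PyTorch cannot broadcast and raises an error.
flat = torch.ones(4, 3)                     # N = 4 tokens, C = 3 classes
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])   # one entry per token
try:
    masked_wrong(flat, mask)
    raised = False
except RuntimeError:
    raised = True
print("shape error raised:", raised)

# Case 2: N == C, so there is no error, but classes get zeroed instead of tokens.
square = torch.ones(3, 3)                   # N = 3 tokens, C = 3 classes
mask3 = torch.tensor([1.0, 0.0, 0.0])       # intent: keep only token 0
print(masked_wrong(square, mask3))          # columns 1 and 2 are zeroed, not rows
```

Case 2 is the more dangerous one, since a batch whose length happens to equal the number of classes produces a wrong loss without any error.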

An easy fix is to transpose, so that the mask lines up with the first (token) dimension:

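    if mask is not None:
        mask = mask.float()
        # Transpose to (C, N) so mask (N,) broadcasts along the token dimension,
        # then transpose back; flat_input * mask.unsqueeze(1) is equivalent.
        flat_input = (flat_input.t() * mask).t()
        flat_target = (flat_target.t() * mask).t()
    else:
        mask = torch.ones_like(target)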
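A minimal sketch of how the corrected masking behaves inside a soft Dice loss, assuming PyTorch. `apply_mask` and `masked_dice_loss` are hypothetical helpers, and the loss is a generic per-class soft Dice with a `smooth` term averaged over classes, not the repo's `DiceLoss`. With the mask applied along the token dimension, padding rows contribute nothing, so a padded batch should give the same loss as the same batch with the padding removed:

```python
import torch
import torch.nn.functional as F


def apply_mask(flat: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # The fix from this issue: (N, C) -> (C, N), scale by mask (N,), transpose back.
    # Equivalent to flat * mask.unsqueeze(1).
    return (flat.t() * mask).t()


def masked_dice_loss(logits, target, mask, smooth=1.0):
    # Hypothetical generic soft Dice: 1 - mean_c (2*sum(p*y) + s) / (sum(p) + sum(y) + s)
    num_classes = logits.shape[-1]
    probs = torch.softmax(logits, dim=1)                          # (N, C)
    onehot = F.one_hot(target, num_classes=num_classes).float()  # (N, C)
    mask = mask.float()
    probs = apply_mask(probs, mask)      # zero out padded tokens (rows)
    onehot = apply_mask(onehot, mask)
    intersection = (probs * onehot).sum(dim=0)                    # (C,)
    denom = probs.sum(dim=0) + onehot.sum(dim=0)                  # (C,)
    dice = (2 * intersection + smooth) / (denom + smooth)
    return 1 - dice.mean()


torch.manual_seed(0)
logits = torch.randn(5, 3)                 # 5 tokens, 3 classes
target = torch.tensor([0, 2, 1, 0, 0])
mask = torch.tensor([1, 1, 1, 0, 0])       # last two tokens are padding

padded = masked_dice_loss(logits, target, mask)
unpadded = masked_dice_loss(logits[:3], target[:3], torch.ones(3))
print(padded.item(), unpadded.item())      # expected to agree up to rounding
```

The transposed form and `mask.unsqueeze(1)` compute the same products; the latter skips the two transposes and makes the intended (N, 1) broadcast explicit.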