ShannonAI / dice_loss_for_NLP

This repo contains the code of the ACL 2020 paper `Dice Loss for Data-imbalanced NLP Tasks`
Apache License 2.0

Dice loss for Token classification #11

Open Zhylkaaa opened 3 years ago

Zhylkaaa commented 3 years ago

I've been trying to use dice loss for token classification with 9 classes, after fixing a few errors in `_multiple_class`. For example, line 143 calls `flat_input_idx.view(-1, 1)`, which throws an error because the tensors are not contiguous. I used this instead: `loss_idx = self._compute_dice_loss(flat_input_idx.reshape(-1, 1), flat_target_idx.reshape(-1, 1))`
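For context, here is a minimal reproduction of that error with a hypothetical stand-in tensor (not the repo's actual `flat_input_idx`): `.view()` requires contiguous memory, while `.reshape()` copies when needed.

```python
import torch

# Transposing a contiguous tensor yields a non-contiguous view of the same data.
x = torch.arange(12).view(3, 4).t()
assert not x.is_contiguous()

try:
    x.view(-1, 1)          # raises RuntimeError on non-contiguous input
except RuntimeError:
    pass

y = x.reshape(-1, 1)       # reshape falls back to a copy, so it succeeds
```

Calling `.contiguous().view(-1, 1)` would work equally well; `.reshape()` is simply the more concise fix.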

Now I've tried to train a model with this, and it seems to me that the loss isn't changing at all. I don't know what I am doing wrong. https://github.com/Zhylkaaa/simpletransformers/blob/dice_loss/simpletransformers/ner/ner_model.py#L489 - this is where I am trying to integrate dice_loss.

I can prepare a minimal example if you want to take a look.

xiaoya-li commented 3 years ago

Hi, thanks for asking! We provide an example of integrating dice loss into a multi-class classification task. Please see https://github.com/ShannonAI/dice_loss_for_NLP/blob/master/tasks/tnews/train.py#L142 for more details.
To my understanding, the difference between this repo and yours is: we use a for loop (https://github.com/ShannonAI/dice_loss_for_NLP/blob/master/loss/dice_loss.py#L140) to calculate the multi-class loss, whereas https://github.com/Zhylkaaa/simpletransformers/blob/dice_loss/simpletransformers/ner/losses/dice_loss.py#L138 seems to miss that. I hope this helps.

Zhylkaaa commented 3 years ago

Thank you, I guess the main issue was that the optimiser is set to minimize the loss, while the score itself was meant to be maximized :) But I also have a question about the alpha parameter in dice loss. Doesn't it artificially lower the f1 score for a given class? I mean, if we have 2 instances of some class and one of them is "easy" (meaning its flat_input is ~1), we will end up in a situation where we artificially lower the loss (and, as a result, the effective learning rate) for the second example. What I suggest is: maybe we should also multiply the targets by the same factor as the inputs?
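To illustrate the concern, here is a rough per-class sketch of the self-adjusting dice coefficient from the paper, where a factor `(1 - p)^alpha` down-weights easy examples. This is a simplified illustration with invented names, not the repo's code; it applies the factor only on the prediction side, which is the asymmetry being questioned.

```python
def adjusted_dice(p, y, alpha=1.0, smooth=1.0):
    """p: predicted probabilities for one class; y: 0/1 targets.
    The (1 - p)^alpha factor scales predictions but not targets."""
    w = [(1 - pi) ** alpha for pi in p]
    inter = sum(wi * pi * yi for wi, pi, yi in zip(w, p, y))
    # targets are NOT scaled by w, so easy positives (p ~ 1) shrink the
    # numerator while sum(y) in the denominator stays the same
    denom = sum(wi * pi for wi, pi in zip(w, p)) + sum(y)
    return (2 * inter + smooth) / (denom + smooth)
```

With one easy positive (`p` close to 1), its weight is close to 0, so the dice coefficient drops below what the unweighted version (`alpha=0`) would give, which matches the "artificially lowered" behaviour described above.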