JayYip / m3tl

BERT for Multitask Learning
https://jayyip.github.io/m3tl/
Apache License 2.0
545 stars 125 forks source link

fix NaN loss by replacing with 0 #35

Closed claudiu-mihaila closed 4 years ago

claudiu-mihaila commented 4 years ago

A NaN loss is generated when a batch does not contain instances of every task. For a task with no instances in the batch, loss_multiplier is an empty tensor, and multiplying batch_loss by this empty multiplier produces NaN. Replacing the NaN with 0 fixes the issue.
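The idea can be sketched in a few lines. This is a minimal NumPy illustration of the mechanism, not the project's actual TensorFlow code; the function name `safe_task_loss` is hypothetical:

```python
import numpy as np

def safe_task_loss(per_example_loss, loss_multiplier):
    """Mean task loss for one batch; a NaN (e.g. from an empty
    task slice) is replaced with 0 so it does not poison the
    summed multi-task loss."""
    # np.mean over an empty product yields NaN (with a warning).
    loss = np.mean(per_example_loss * loss_multiplier)
    # Replace NaN with 0; finite values pass through unchanged.
    return float(np.nan_to_num(loss, nan=0.0))
```

Without the `nan_to_num` step, summing per-task losses across tasks would turn the whole multi-task loss into NaN whenever any single task is absent from the batch.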

JayYip commented 4 years ago

Hi, I only encountered NaN loss with seq2seq_text/seq2seq_tag problems in a multi-task scenario; cls and seq_tag worked fine in my experience. Did you encounter NaN losses in a cls problem, and does the commit make a difference?

claudiu-mihaila commented 4 years ago

Yes, I did encounter it while training a multi-task cls setup. I think the multi-task aspect was the cause: the empty tensor for a task missing from a batch produces a NaN loss. Replacing the NaN with 0 fixes the issue.

JayYip commented 4 years ago

Your fix looks good. Could you please apply the same patch to the other top layers? Thanks.