Closed claudiu-mihaila closed 4 years ago
Hi, I only encountered NaN loss with seq2seq_text/seq2seq_tag problems in the multi-task scenario; cls and seq_tag worked fine in my experience. Did you encounter NaN losses in a cls problem, and does the commit make a difference?
Yes, I did encounter it while doing a cls multi-task. I think the multi-task aspect was the problem: the empty tensor of a task missing from a batch produces a NaN loss. Replacing the NaN with 0 fixes the issue.
Your fix seems good. Could you please help to apply the same patch to other top layers? Thanks.
NaN loss is generated when a batch does not contain instances of every task. For a task not present in the batch, the loss_multiplier is an empty tensor, which produces NaN when the batch_loss is computed over the (empty) multiplied losses. Replacing the NaN with 0 fixes the issue.
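A minimal NumPy sketch of the failure mode described above and the proposed fix. The function and variable names here are illustrative, not taken from the repository; reducing over an empty tensor yields NaN, and mapping NaN to 0 means an absent task simply contributes nothing to the total loss:

```python
import numpy as np

def batch_loss(per_example_loss, loss_multiplier):
    # per_example_loss: losses for this task's examples in the batch.
    # loss_multiplier: per-example weights for this task; it is an
    # empty tensor when the batch contains no instances of the task.
    loss = np.mean(per_example_loss * loss_multiplier)  # mean of empty -> NaN
    # The fix: replace NaN with 0 so a missing task adds nothing.
    return float(np.nan_to_num(loss, nan=0.0))

# Task absent from the batch: empty tensors, loss is 0 instead of NaN.
print(batch_loss(np.array([]), np.array([])))          # 0.0
# Task present: loss is the weighted mean as usual.
print(batch_loss(np.array([1.0, 3.0]), np.array([1.0, 1.0])))  # 2.0
```

In a multi-task setup, the per-task losses are summed afterwards, so without this guard a single absent task turns the whole summed loss (and every gradient) into NaN.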