Open YeDeming opened 4 years ago
Hi, Thanks for your comments. For now, I will leave the code as it is, for reproducibility and consistency with the trained models. I will try to revisit and re-train the models in the future, when my computational resources allow it.
Hi Vid, as we discussed by email, I am recording the bug here.
https://github.com/vid-koci/bert-commonsense/blob/2fb175493e402c6c8fae8e3fdb70378b2d9b009d/main.py#L71-L72
The summed loss should be divided by the number of [MASK] tokens. As the code stands, the more [MASK] tokens an answer has, the larger its loss, so I am not sure whether the length of the correct answer introduces a bias on the test set.
https://github.com/vid-koci/bert-commonsense/blob/2fb175493e402c6c8fae8e3fdb70378b2d9b009d/main.py#L483
There is also a superfluous .mean().
Hello, could you elaborate on this issue?
If I understand correctly, you mean that an answer with more subtokens has less chance of being selected as the right answer. So basically what you are suggesting is
return torch.mean(masked_lm_loss,1) / number_of_masks
?
Thanks in advance!
That's correct. And the second bug is a redundant .mean() after loss_2 at the end of the line. It should be loss_2 without the .mean(), because .mean() averages over the batch, which shouldn't happen at this stage.
return torch.mean(masked_lm_loss,1) / number_of_masks
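A toy pure-Python sketch (made-up loss values, not the repository's actual numbers) of why the extra .mean() is harmful: candidate scoring needs one loss per example in the batch, and averaging over the batch collapses those per-example scores into a single scalar that can no longer be compared per example.

```python
def per_example_scores(batch_losses):
    """One normalized score per example; no batch-level averaging here.

    batch_losses: list of per-token loss lists, one inner list per example.
    Returns a list with one score per example, mirroring the shape-(batch,)
    tensor that the candidate-selection step needs.
    """
    return [sum(row) / len(row) for row in batch_losses]


# Hypothetical batch of two examples with per-token losses.
batch = [[0.2, 0.4], [0.9, 0.3]]

scores = per_example_scores(batch)
print(scores)  # two scores, one per example: [~0.3, ~0.6]

# The buggy extra .mean() would additionally average across the batch,
# leaving a single scalar and losing the per-example comparison:
collapsed = sum(scores) / len(scores)
print(collapsed)  # one scalar for the whole batch
```

The fix in the thread keeps the per-example vector at this stage and defers any batch-level reduction to where it actually belongs.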
@vid-koci @xiaoouwang
Just to clarify, should number_of_masks be calculated with torch.count_nonzero(masked_lm_labels==-1, dim=1)?
@mhillebrand I believe it should be masked_lm_labels>-1 rather than masked_lm_labels==-1, as the -1 entries are exactly the ones we are not interested in, and they are ignored by the CrossEntropy loss.
Ah, of course. Line 72 should be:
return torch.sum(masked_lm_loss, 1) / torch.count_nonzero(masked_lm_labels>-1, dim=1)
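To see why this normalization matters, here is a minimal pure-Python sketch (hypothetical per-token losses, not outputs from the repository's model). It mimics the corrected line: per-token losses are summed per example and divided by the number of positions whose label is > -1, since CrossEntropy with ignore_index contributes zero loss at the -1 positions.

```python
def summed_loss(token_losses):
    """The buggy scoring: a plain sum over the sequence."""
    return sum(token_losses)


def normalized_loss(token_losses, labels):
    """The fixed scoring: sum divided by the number of real [MASK] targets.

    Positions labeled -1 are ignored targets (loss 0), matching
    torch.sum(masked_lm_loss, 1) / torch.count_nonzero(labels > -1, dim=1).
    """
    num_masks = sum(1 for label in labels if label > -1)
    return sum(token_losses) / num_masks


# Hypothetical one-subtoken vs. three-subtoken candidate answers, each
# predicted equally well (per-token loss 0.5); -1 marks ignored positions.
loss_short, labels_short = [0.5, 0.0, 0.0], [7, -1, -1]
loss_long, labels_long = [0.5, 0.5, 0.5], [7, 7, 7]

print(summed_loss(loss_short), summed_loss(loss_long))          # 0.5 vs 1.5
print(normalized_loss(loss_short, labels_short),
      normalized_loss(loss_long, labels_long))                  # 0.5 vs 0.5
```

With the plain sum, the longer candidate looks worse even though every subtoken is predicted equally well; after normalization the two candidates tie, which is the behavior the thread argues for.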
Interesting. The accuracy drops by over 5% after fixing these two bugs.